use crate::error::Error;
use crate::wire::{OperationCode, PROTOCOL_VERSION, Request, ResponseHeader};
use serde::Serialize;
use std::io::{Read, Write};
/// Payload carried by the identification probe request: eight zero bytes.
const PROBE_PAYLOAD: [u8; 8] = [0x00; 8];
/// Identification details reported by a device that answered the probe sent by
/// [`identify`] with an acceptable protocol major version.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct HdmIdentity {
    /// Protocol version from the response header as `(major, minor)`.
    /// Only the major component is compared during identification.
    pub protocol_version: (u8, u8),
    /// Software version triple copied verbatim from the response header,
    /// in wire order.
    pub software_version: (u8, u8, u8),
    /// Status code from the response header. It is not required to signal
    /// success: an unauthorized reply (e.g. 403) still identifies a device.
    pub response_code: u16,
}
pub fn identify<T: Read + Write>(transport: &mut T) -> Result<HdmIdentity, Error> {
let request = Request {
op: OperationCode::ListOpsAndDeps,
payload: PROBE_PAYLOAD.to_vec(),
};
request.encode(transport)?;
let header = ResponseHeader::read(transport)?;
let protocol_version = header.protocol_version;
let [exp_major, _exp_minor] = PROTOCOL_VERSION;
if protocol_version.0 != exp_major {
return Err(Error::NotHdm { protocol_version });
}
Ok(HdmIdentity {
protocol_version,
software_version: header.software_version,
response_code: header.code,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::{MAGIC, RESPONSE_HEADER_LEN};
    use std::io::{self, Cursor};

    /// In-memory transport: records every byte written to it and serves reads
    /// from a pre-loaded response buffer.
    struct Loopback {
        written: Vec<u8>,
        incoming: Cursor<Vec<u8>>,
    }

    impl Loopback {
        fn new(incoming: Vec<u8>) -> Self {
            Self {
                written: Vec::new(),
                incoming: Cursor::new(incoming),
            }
        }
    }

    impl Read for Loopback {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.incoming.read(buf)
        }
    }

    impl Write for Loopback {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.written.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a response header with the given protocol version and status code.
    ///
    /// The software version is fixed at `0x02 0x02 0x10` (2.2.16), the code is
    /// big-endian, and the two trailing 2-byte fields are zeroed.
    fn response_header(protocol: [u8; 2], code: u16) -> Vec<u8> {
        let mut out = Vec::with_capacity(RESPONSE_HEADER_LEN);
        out.extend_from_slice(&protocol);
        out.extend_from_slice(&[0x02, 0x02, 0x10]);
        out.extend_from_slice(&code.to_be_bytes());
        out.extend_from_slice(&[0x00, 0x00]);
        out.extend_from_slice(&[0x00, 0x00]);
        out
    }

    #[test]
    fn identifies_hdm_from_unauthorized_response() {
        let mut transport = Loopback::new(response_header([0x00, 0x05], 403));
        let identity = identify(&mut transport).expect("should identify an HDM");
        assert_eq!(identity.protocol_version, (0, 5));
        assert_eq!(identity.software_version, (2, 2, 16));
        assert_eq!(identity.response_code, 403);
    }

    #[test]
    fn identifies_hdm_reporting_newer_protocol_minor() {
        let mut header = response_header([0x00, 0x07], 101);
        header[2] = 0x01;
        header[3] = 0x01;
        header[4] = 0x00;
        let mut transport = Loopback::new(header);
        let identity = identify(&mut transport).expect("a 0.7 device is still an HDM");
        assert_eq!(identity.protocol_version, (0, 7));
        assert_eq!(identity.software_version, (1, 1, 0));
        assert_eq!(identity.response_code, 101);
    }

    #[test]
    fn rejects_mismatched_protocol_major() {
        // A well-formed header whose minor matches the supported 0.5 but whose
        // major does not: only the major component gates identification.
        let mut transport = Loopback::new(response_header([0x01, 0x05], 101));
        let err = identify(&mut transport).expect_err("major 1 must not identify as an HDM");
        match err {
            Error::NotHdm { protocol_version } => {
                assert_eq!(protocol_version, (1, 5));
            }
            other => panic!("expected NotHdm, got {other:?}"),
        }
    }

    #[test]
    fn probe_frame_is_well_formed_op1() {
        let mut transport = Loopback::new(response_header([0x00, 0x05], 101));
        identify(&mut transport).expect("identify");
        let written = &transport.written;
        assert!(written.starts_with(&MAGIC));
        assert_eq!(&written[6..8], PROTOCOL_VERSION);
        assert_eq!(written[8], OperationCode::ListOpsAndDeps as u8);
        assert_eq!(written[9], 0);
        // The big-endian length field must announce exactly the probe payload.
        let announced_len = u16::from_be_bytes([written[10], written[11]]);
        assert_eq!(usize::from(announced_len), PROBE_PAYLOAD.len());
    }

    #[test]
    fn rejects_non_hdm_service() {
        let banner = b"SSH-2.0-OpenSSH_9.6\r\n".to_vec();
        let mut transport = Loopback::new(banner);
        let err = identify(&mut transport).expect_err("should reject non-HDM");
        match err {
            Error::NotHdm { protocol_version } => {
                assert_eq!(protocol_version, (b'S', b'S'));
            }
            other => panic!("expected NotHdm, got {other:?}"),
        }
    }

    #[test]
    fn silent_endpoint_is_transport_error() {
        let mut transport = Loopback::new(Vec::new());
        let err = identify(&mut transport).expect_err("should fail to read a header");
        assert!(matches!(err, Error::Transport(_)));
    }
}