use std::{io, time::Duration};
use crate::mboot::ResultComm;
use color_print::cstr;
use hidapi::{HidApi, HidDevice};
use log::{debug, info};
use std::fmt::Debug;
use super::{CommunicationError, Protocol, ProtocolOpen};
/// HID report IDs used by this transport; the ID is the first byte of every report.
mod report {
/// Output report ID used for framed packets of type `0xA4` (see `write_packet_raw`).
pub const CMD_OUT: u8 = 0x01;
/// Output report ID used for framed packets of type `0xA5` (see `write_packet_raw`).
pub const DATA_OUT: u8 = 0x02;
/// Input report ID that `read_packet_raw` treats as a command response.
pub const CMD_IN: u8 = 0x03;
/// Input report ID that `read_packet_raw` treats as a data report.
pub const DATA_IN: u8 = 0x04;
}
/// Size of the buffer `read_packet_raw` allocates to receive a single input report.
const MAX_PACKET_SIZE: usize = 1024;
/// USB-HID transport: an open hidapi device together with the settings it was opened with.
#[derive(Debug)]
pub struct USBProtocol {
/// Identifier string the device was opened with; returned by `get_identifier`.
interface: String,
/// Open hidapi device handle used for all reads and writes.
device: HidDevice,
/// Read timeout in milliseconds handed to hidapi `read_timeout`; saturated at `i32::MAX`
/// when opened.
timeout_ms: i32,
/// Interval returned by `get_polling_interval`.
polling_interval: Duration,
}
impl ProtocolOpen for USBProtocol {
    /// Opens the device named by `identifier` with a 5 s timeout and 1 ms polling interval.
    ///
    /// The baud rate has no meaning for USB-HID, so `0` is passed through.
    fn open(identifier: &str) -> ResultComm<Self> {
        let default_timeout = Duration::from_secs(5);
        let default_polling = Duration::from_millis(1);
        Self::open_with_options(identifier, 0, default_timeout, default_polling)
    }

    /// Opens the USB-HID device identified by `identifier` (`"VID:PID"`, `"VID,PID"` or `"VID"`).
    ///
    /// `_baudrate` is ignored. `timeout` is stored as whole milliseconds, saturating at
    /// `i32::MAX`; `polling_interval` is stored as given.
    ///
    /// # Errors
    /// Returns `CommunicationError::ParseError` if the identifier cannot be parsed, the HID
    /// API cannot be initialised, or the device cannot be opened.
    fn open_with_options(
        identifier: &str,
        _baudrate: u32,
        timeout: Duration,
        polling_interval: Duration,
    ) -> ResultComm<Self> {
        let (vid, pid) = parse_usb_identifier(identifier)?;

        // NOTE(review): HID init/open failures are reported as `ParseError` rather than an
        // I/O error; kept as-is because callers may depend on the variant.
        let hid = HidApi::new().map_err(|err| {
            CommunicationError::ParseError(format!("Failed to initialize HID API: {err}"))
        })?;
        let handle = hid.open(vid, pid).map_err(|err| {
            CommunicationError::ParseError(format!("Failed to open USB device: {err}"))
        })?;

        let millis = i32::try_from(timeout.as_millis()).unwrap_or(i32::MAX);

        info!(
            "Opened USB-HID device {} with {}ms timeout",
            identifier,
            timeout.as_millis()
        );

        Ok(Self {
            interface: identifier.to_owned(),
            device: handle,
            timeout_ms: millis,
            polling_interval,
        })
    }
}
impl Protocol for USBProtocol {
    /// Returns the polling interval configured at open time.
    fn get_polling_interval(&self) -> Duration {
        self.polling_interval
    }

    /// Returns the read timeout configured at open time.
    fn get_timeout(&self) -> Duration {
        // `timeout_ms` is derived from an unsigned millisecond count when the device is
        // opened, so the conversion back to an unsigned value cannot fail.
        Duration::from_millis(self.timeout_ms.try_into().expect("negative timeout in USB"))
    }

    /// Returns the identifier string the device was opened with.
    fn get_identifier(&self) -> &str {
        &self.interface
    }

    /// Reads one HID report into a zero-initialised buffer of `bytes` bytes.
    ///
    /// The returned vector is always `bytes` long; `read_usb` does not report how many
    /// bytes arrived, so any bytes the device did not send remain zero.
    fn read(&mut self, bytes: usize) -> ResultComm<Vec<u8>> {
        let mut buf = vec![0u8; bytes];
        self.read_usb(&mut buf)?;
        Ok(buf)
    }

    /// Converts a framed packet into a HID output report and sends it.
    ///
    /// Expected `data` layout: `[0x5A, type, len_lo, len_hi, b4, b5, payload...]`, where the
    /// little-endian `len` gives the payload size. Bytes 4..6 are not forwarded
    /// (presumably a CRC — TODO confirm). Type `0xA4` is sent with report ID `CMD_OUT`,
    /// `0xA5` with `DATA_OUT`. The emitted report is
    /// `[report_id, 0x00, len_lo, len_hi, payload...]`.
    ///
    /// # Errors
    /// - `InvalidHeader` if `data` is shorter than 6 bytes, does not start with `0x5A`, or
    ///   has an unsupported type byte.
    /// - `InvalidData` if `data` is shorter than the header plus the declared length.
    /// - Any I/O error from the underlying HID write.
    fn write_packet_raw(&mut self, data: &[u8]) -> ResultComm<()> {
        if data.len() < 6 || data[0] != 0x5A {
            return Err(CommunicationError::InvalidHeader);
        }
        let cmd_type = data[1];
        let data_len = usize::from(u16::from_le_bytes([data[2], data[3]]));
        let Some(cmd_data) = data.get(6..6 + data_len) else {
            return Err(CommunicationError::InvalidData);
        };
        let report_id = match cmd_type {
            0xA4 => report::CMD_OUT,
            0xA5 => report::DATA_OUT,
            _ => return Err(CommunicationError::InvalidHeader),
        };

        let mut report = Vec::with_capacity(4 + cmd_data.len());
        report.push(report_id);
        report.push(0x00);
        // `cmd_data` is exactly `data_len` bytes, so the framed little-endian length bytes
        // are also the correct length for the HID report.
        report.extend_from_slice(&data[2..4]);
        report.extend_from_slice(cmd_data);
        self.write_usb(&report)?;
        Ok(())
    }

    /// Reads one HID input report (waiting at most the configured timeout) and returns its
    /// payload.
    ///
    /// Report layout: `[report_id, _, len_lo, len_hi, payload...]`. For `CMD_IN` and
    /// `DATA_IN` reports whose declared length fits in the bytes actually received, exactly
    /// that many payload bytes are returned. Otherwise — an unknown report ID, or a
    /// declared length larger than what was received — every received byte after the
    /// 4-byte header is returned.
    ///
    /// # Errors
    /// - `IOError` if the HID read fails.
    /// - `InvalidHeader` if fewer than 4 bytes were received.
    /// - `Aborted` if the declared payload length is zero.
    fn read_packet_raw(&mut self, _: u8) -> ResultComm<Vec<u8>> {
        let mut report = vec![0u8; MAX_PACKET_SIZE];
        let size = self
            .device
            .read_timeout(&mut report, self.timeout_ms)
            .map_err(|e| CommunicationError::IOError(io::Error::other(e.to_string())))?;
        debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &report[..size]);
        if size < 4 {
            return Err(CommunicationError::InvalidHeader);
        }
        let report_id = report[0];
        let packet_length = usize::from(u16::from_le_bytes([report[2], report[3]]));
        if packet_length == 0 {
            return Err(CommunicationError::Aborted);
        }

        // The length field comes from the device and can be anything up to 65535; only
        // trust it when it stays within the bytes actually received, otherwise slicing
        // would either panic (past the buffer) or return bytes the device never sent.
        let payload_end = 4 + packet_length;
        if payload_end <= size {
            match report_id {
                report::CMD_IN => {
                    let response = report[4..payload_end].to_vec();
                    debug!("Constructed response: {response:02X?}");
                    return Ok(response);
                }
                report::DATA_IN => return Ok(report[4..payload_end].to_vec()),
                _ => {}
            }
        }

        // `size >= 4` here, so this is empty when only the header was received.
        Ok(report[4..size].to_vec())
    }
}
impl USBProtocol {
    /// Reads a single HID input report into `buf`, waiting at most the configured timeout.
    ///
    /// The number of bytes received is only logged; bytes of `buf` the device did not fill
    /// keep their previous contents.
    ///
    /// # Errors
    /// - An [`io::Error`] wrapping the hidapi error if the read fails.
    /// - An [`io::Error`] of kind [`io::ErrorKind::TimedOut`] if `buf` is non-empty and no
    ///   report arrived before the timeout elapsed.
    fn read_usb(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
        // Use the timed read (as `read_packet_raw` does) so an unresponsive device cannot
        // block the caller indefinitely despite the timeout configured at open time.
        let size = self
            .device
            .read_timeout(buf, self.timeout_ms)
            .map_err(|e| io::Error::other(e.to_string()))?;
        debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &buf[..size]);
        // hidapi reports "timeout expired with no report" as a successful zero-byte read;
        // surface it instead of handing the caller an untouched buffer.
        if size == 0 && !buf.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "Timed out waiting for USB device",
            ));
        }
        Ok(())
    }

    /// Sends one HID output report.
    ///
    /// On Windows any non-zero byte count is accepted; on other platforms the whole of
    /// `buf` must be reported as written.
    ///
    /// # Errors
    /// Returns an [`io::Error`] if hidapi fails or reports an unacceptable byte count.
    fn write_usb(&self, buf: &[u8]) -> Result<(), io::Error> {
        debug!("{}: {:02X?}", cstr!("<g!>TX"), buf);
        let written = self
            .device
            .write(buf)
            .map_err(|e| io::Error::other(e.to_string()))?;
        // NOTE(review): presumably Windows reports the padded output-report length rather
        // than `buf.len()`, hence the weaker check there — confirm.
        if cfg!(target_os = "windows") {
            if written > 0 {
                Ok(())
            } else {
                Err(io::Error::other("Failed to write to USB device"))
            }
        } else if written == buf.len() {
            Ok(())
        } else {
            Err(io::Error::other(format!(
                "Failed to write all bytes: wrote {} of {}",
                written,
                buf.len()
            )))
        }
    }
}
/// Parses a USB identifier into a `(vid, pid)` pair.
///
/// Accepted forms are `"VID:PID"` and `"VID,PID"`, split at the first `:` or `,`. A bare
/// `"VID"` is also accepted, in which case the PID is returned as `0` (presumably treated
/// as "any PID" by hidapi — confirm). Each number is parsed with `parse_number_string`.
///
/// # Errors
/// Returns `CommunicationError::ParseError` naming the part that failed to parse.
fn parse_usb_identifier(identifier: &str) -> ResultComm<(u16, u16)> {
    let Some((vid_part, pid_part)) = identifier.split_once([':', ',']) else {
        let lone_vid = parse_number_string(identifier).map_err(|_| {
            CommunicationError::ParseError(format!("Invalid USB identifier: {identifier}"))
        })?;
        return Ok((lone_vid, 0));
    };

    let vid = parse_number_string(vid_part)
        .map_err(|_| CommunicationError::ParseError(format!("Invalid VID: {vid_part}")))?;
    let pid = parse_number_string(pid_part)
        .map_err(|_| CommunicationError::ParseError(format!("Invalid PID: {pid_part}")))?;
    Ok((vid, pid))
}
/// Parses a `u16` written in either decimal or hexadecimal.
///
/// Surrounding whitespace is ignored. A `0x`/`0X` prefix forces hexadecimal. Otherwise the
/// text is tried as decimal first and, if that fails, as hexadecimal. So `"1fc9"` yields
/// `0x1FC9`, but a digits-only string such as `"0135"` is read as decimal `135`.
///
/// # Errors
/// Returns the [`std::num::ParseIntError`] from the last parse attempted.
fn parse_number_string(s: &str) -> Result<u16, std::num::ParseIntError> {
    let text = s.trim();
    if let Some(digits) = text.strip_prefix("0x").or_else(|| text.strip_prefix("0X")) {
        return u16::from_str_radix(digits, 16);
    }
    // Hex text containing a-f can never parse as decimal, so the decimal attempt plus the
    // hexadecimal fallback also covers letter-bearing values such as "1fc9".
    text.parse::<u16>()
        .or_else(|_| u16::from_str_radix(text, 16))
}