mboot/mboot/protocols/
usb.rs1use std::{io, time::Duration};
6
7use crate::mboot::ResultComm;
8use color_print::cstr;
9use hidapi::{HidApi, HidDevice};
10use log::{debug, info};
11use std::fmt::Debug;
12
13use super::{CommunicationError, Protocol, ProtocolOpen};
14
/// HID report IDs used by the USB-HID transport.
mod report {
    /// Outgoing command report (sent for packet type `0xA4`).
    pub const CMD_OUT: u8 = 0x01;
    /// Outgoing data report (sent for packet type `0xA5`).
    pub const DATA_OUT: u8 = 0x02;
    /// Incoming command-response report.
    pub const CMD_IN: u8 = 0x03;
    /// Incoming data report.
    pub const DATA_IN: u8 = 0x04;
}

/// Size of the buffer used to receive a single HID input report.
const MAX_PACKET_SIZE: usize = 1024;

30#[derive(Debug)]
31pub struct USBProtocol {
32 interface: String,
33 device: HidDevice,
34 timeout_ms: i32,
35 polling_interval: Duration,
36}
37
38impl ProtocolOpen for USBProtocol {
39 fn open(identifier: &str) -> ResultComm<Self> {
40 Self::open_with_options(identifier, 0, Duration::from_secs(5), Duration::from_millis(1))
41 }
42
43 fn open_with_options(
44 identifier: &str,
45 _baudrate: u32, timeout: Duration,
47 polling_interval: Duration,
48 ) -> ResultComm<Self> {
49 let (vid, pid) = parse_usb_identifier(identifier)?;
51
52 let api =
54 HidApi::new().map_err(|e| CommunicationError::ParseError(format!("Failed to initialize HID API: {e}")))?;
55
56 let device = api
58 .open(vid, pid)
59 .map_err(|e| CommunicationError::ParseError(format!("Failed to open USB device: {e}")))?;
60
61 let timeout_ms = timeout.as_millis().try_into().unwrap_or(i32::MAX);
63
64 let usb_protocol = USBProtocol {
65 interface: identifier.to_owned(),
66 device,
67 timeout_ms,
68 polling_interval,
69 };
70
71 info!(
72 "Opened USB-HID device {} with {}ms timeout",
73 usb_protocol.interface,
74 timeout.as_millis()
75 );
76
77 Ok(usb_protocol)
78 }
79}
80
81impl Protocol for USBProtocol {
82 fn get_polling_interval(&self) -> Duration {
83 self.polling_interval
84 }
85
86 fn get_timeout(&self) -> Duration {
87 Duration::from_millis(self.timeout_ms.try_into().expect("negative timeout in USB"))
88 }
89
90 fn get_identifier(&self) -> &str {
91 &self.interface
92 }
93
94 fn read(&mut self, bytes: usize) -> ResultComm<Vec<u8>> {
95 let mut buf = vec![0u8; bytes];
96 self.read_usb(&mut buf)?;
97 Ok(buf)
98 }
99 fn write_packet_raw(&mut self, data: &[u8]) -> ResultComm<()> {
100 if data.len() < 6 || data[0] != 0x5A {
103 return Err(CommunicationError::InvalidHeader);
104 }
105
106 let cmd_type = data[1];
107 let data_len = u16::from_le_bytes([data[2], data[3]]) as usize;
108
109 if data.len() < 6 + data_len {
110 return Err(CommunicationError::InvalidData);
111 }
112
113 let cmd_data = &data[6..6 + data_len];
116
117 let report_id = match cmd_type {
119 0xA4 => report::CMD_OUT, 0xA5 => report::DATA_OUT, _ => return Err(CommunicationError::InvalidHeader),
122 };
123
124 let mut report = vec![0u8; 4 + cmd_data.len()]; report[0] = report_id;
129 report[1] = 0x00; report[2] = (cmd_data.len() & 0xFF) as u8;
131 report[3] = ((cmd_data.len() >> 8) & 0xFF) as u8;
132
133 report[4..4 + cmd_data.len()].copy_from_slice(cmd_data);
135
136 self.write_usb(&report)?;
138
139 Ok(())
140 }
141 fn read_packet_raw(&mut self, _: u8) -> ResultComm<Vec<u8>> {
144 let mut report = vec![0u8; MAX_PACKET_SIZE];
146 let size = self
147 .device
148 .read_timeout(&mut report, self.timeout_ms)
149 .map_err(|e| CommunicationError::IOError(io::Error::other(e.to_string())))?;
150
151 debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &report[..size]);
152
153 if size < 4 {
154 return Err(CommunicationError::InvalidHeader);
155 }
156
157 let report_id = report[0];
159 let packet_length = u16::from_le_bytes([report[2], report[3]]) as usize;
160
161 if packet_length == 0 {
162 return Err(CommunicationError::Aborted);
164 }
165
166 if report_id == report::CMD_IN {
168 let mut response = Vec::new();
170
171 response.extend_from_slice(&report[4..4 + packet_length]);
173
174 debug!("Constructed response: {response:02X?}");
175
176 return Ok(response);
177 } else if report_id == report::DATA_IN {
178 if size >= 4 + packet_length {
180 return Ok(report[4..4 + packet_length].to_vec());
181 }
182 }
183
184 if size > 4 {
186 Ok(report[4..size].to_vec())
187 } else {
188 Ok(Vec::new())
189 }
190 }
191}
192
193impl USBProtocol {
194 fn read_usb(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
195 match self.device.read(buf) {
196 Ok(size) => {
197 debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &buf[..size]);
198 Ok(())
199 }
200 Err(e) => Err(io::Error::other(e.to_string())),
201 }
202 }
203 fn write_usb(&self, buf: &[u8]) -> Result<(), io::Error> {
204 debug!("{}: {:02X?}", cstr!("<g!>TX"), buf);
205
206 match self.device.write(buf) {
207 Ok(written) => {
208 #[cfg(target_os = "windows")]
210 {
211 if written > 0 {
214 Ok(())
215 } else {
216 Err(io::Error::other("Failed to write to USB device"))
217 }
218 }
219 #[cfg(not(target_os = "windows"))]
220 {
221 if written == buf.len() {
223 Ok(())
224 } else {
225 Err(io::Error::other(format!(
226 "Failed to write all bytes: wrote {} of {}",
227 written,
228 buf.len()
229 )))
230 }
231 }
232 }
233 Err(e) => Err(io::Error::other(e.to_string())),
234 }
235 }
236}
237
238fn parse_usb_identifier(identifier: &str) -> ResultComm<(u16, u16)> {
241 if let Some(pos) = identifier.find([':', ',']) {
243 let vid_str = &identifier[..pos];
244 let pid_str = &identifier[pos + 1..];
245
246 let vid = parse_number_string(vid_str)
247 .map_err(|_| CommunicationError::ParseError(format!("Invalid VID: {vid_str}")))?;
248
249 let pid = parse_number_string(pid_str)
250 .map_err(|_| CommunicationError::ParseError(format!("Invalid PID: {pid_str}")))?;
251
252 Ok((vid, pid))
253 } else {
254 let vid = parse_number_string(identifier)
256 .map_err(|_| CommunicationError::ParseError(format!("Invalid USB identifier: {identifier}")))?;
257
258 Ok((vid, 0))
260 }
261}
262
/// Parses a `u16` from a user-supplied string, accepting decimal or hexadecimal.
///
/// Rules, applied to the whitespace-trimmed input:
/// 1. A `0x`/`0X` prefix forces hexadecimal (`"0x15A2"` -> `0x15A2`).
/// 2. Otherwise decimal is tried first (`"1234"` -> `1234`; note `"0073"` -> `73`).
/// 3. If decimal fails, the string is re-parsed as bare hexadecimal (`"1fc9"` -> `0x1FC9`).
///
/// # Errors
/// Returns the [`std::num::ParseIntError`] from the last attempted parse when the input
/// is neither valid decimal nor valid hexadecimal that fits in a `u16`.
fn parse_number_string(s: &str) -> Result<u16, std::num::ParseIntError> {
    let trimmed = s.trim();

    if let Some(hex) = trimmed.strip_prefix("0x").or_else(|| trimmed.strip_prefix("0X")) {
        return u16::from_str_radix(hex, 16);
    }

    // Any string containing `a`-`f` fails decimal parsing and lands in the hex
    // fallback, so no separate "looks like hex" check is needed.
    trimmed
        .parse::<u16>()
        .or_else(|_| u16::from_str_radix(trimmed, 16))
}