mboot/mboot/protocols/
usb.rs

1// Copyright 2025 NXP
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5use std::{io, time::Duration};
6
7use crate::mboot::ResultComm;
8use color_print::cstr;
9use hidapi::{HidApi, HidDevice};
10use log::{debug, info};
11use std::fmt::Debug;
12
13use super::{CommunicationError, Protocol, ProtocolOpen};
14
/// HID report identifiers used by the USB-HID transport (per NXP documentation).
///
/// The first byte of every HID report carries one of these values; it tells
/// the receiver whether the payload is a command or data, and in which
/// direction it travels.
mod report {
    /// Host -> device: command packet.
    pub const CMD_OUT: u8 = 0x01;
    /// Host -> device: data packet.
    pub const DATA_OUT: u8 = 0x02;
    /// Device -> host: command response packet.
    pub const CMD_IN: u8 = 0x03;
    /// Device -> host: data packet.
    pub const DATA_IN: u8 = 0x04;
}
26
/// Maximum packet size for USB transfers, in bytes.
///
/// This is the size of the buffer allocated to receive a single HID report,
/// including its 4-byte report header.
const MAX_PACKET_SIZE: usize = 1024;
29
30#[derive(Debug)]
31pub struct USBProtocol {
32    interface: String,
33    device: HidDevice,
34    timeout_ms: i32,
35    polling_interval: Duration,
36}
37
38impl ProtocolOpen for USBProtocol {
39    fn open(identifier: &str) -> ResultComm<Self> {
40        Self::open_with_options(identifier, 0, Duration::from_secs(5), Duration::from_millis(1))
41    }
42
43    fn open_with_options(
44        identifier: &str,
45        _baudrate: u32, // Not used for USB
46        timeout: Duration,
47        polling_interval: Duration,
48    ) -> ResultComm<Self> {
49        // Parse the identifier which can be in format "vid:pid" or a path
50        let (vid, pid) = parse_usb_identifier(identifier)?;
51
52        // Initialize HidApi
53        let api =
54            HidApi::new().map_err(|e| CommunicationError::ParseError(format!("Failed to initialize HID API: {e}")))?;
55
56        // Find and open the device
57        let device = api
58            .open(vid, pid)
59            .map_err(|e| CommunicationError::ParseError(format!("Failed to open USB device: {e}")))?;
60
61        // Convert timeout to i32, clamping if necessary
62        let timeout_ms = timeout.as_millis().try_into().unwrap_or(i32::MAX);
63
64        let usb_protocol = USBProtocol {
65            interface: identifier.to_owned(),
66            device,
67            timeout_ms,
68            polling_interval,
69        };
70
71        info!(
72            "Opened USB-HID device {} with {}ms timeout",
73            usb_protocol.interface,
74            timeout.as_millis()
75        );
76
77        Ok(usb_protocol)
78    }
79}
80
81impl Protocol for USBProtocol {
82    fn get_polling_interval(&self) -> Duration {
83        self.polling_interval
84    }
85
86    fn get_timeout(&self) -> Duration {
87        Duration::from_millis(self.timeout_ms.try_into().expect("negative timeout in USB"))
88    }
89
90    fn get_identifier(&self) -> &str {
91        &self.interface
92    }
93
94    fn read(&mut self, bytes: usize) -> ResultComm<Vec<u8>> {
95        let mut buf = vec![0u8; bytes];
96        self.read_usb(&mut buf)?;
97        Ok(buf)
98    }
99    fn write_packet_raw(&mut self, data: &[u8]) -> ResultComm<()> {
100        // For USB-HID, we need to extract the command data from the UART framing
101        // UART frame format: [5A, cmd_type, len_lsb, len_msb, crc_lsb, crc_msb, ...data...]
102        if data.len() < 6 || data[0] != 0x5A {
103            return Err(CommunicationError::InvalidHeader);
104        }
105
106        let cmd_type = data[1];
107        let data_len = u16::from_le_bytes([data[2], data[3]]) as usize;
108
109        if data.len() < 6 + data_len {
110            return Err(CommunicationError::InvalidData);
111        }
112
113        // Extract the command data (without UART framing and CRC)
114        // Skip the UART header (4 bytes) and CRC (2 bytes)
115        let cmd_data = &data[6..6 + data_len];
116
117        // Determine report ID based on packet type
118        let report_id = match cmd_type {
119            0xA4 => report::CMD_OUT,  // Command packet
120            0xA5 => report::DATA_OUT, // Data packet
121            _ => return Err(CommunicationError::InvalidHeader),
122        };
123
124        // Create a generic HID report
125        let mut report = vec![0u8; 4 + cmd_data.len()]; // 4 bytes for header + data
126
127        // Set report header
128        report[0] = report_id;
129        report[1] = 0x00; // Padding (should be 0)
130        report[2] = (cmd_data.len() & 0xFF) as u8;
131        report[3] = ((cmd_data.len() >> 8) & 0xFF) as u8;
132
133        // Copy command data
134        report[4..4 + cmd_data.len()].copy_from_slice(cmd_data);
135
136        // Write the report
137        self.write_usb(&report)?;
138
139        Ok(())
140    }
141    //
142
143    fn read_packet_raw(&mut self, _: u8) -> ResultComm<Vec<u8>> {
144        // Read the initial response
145        let mut report = vec![0u8; MAX_PACKET_SIZE];
146        let size = self
147            .device
148            .read_timeout(&mut report, self.timeout_ms)
149            .map_err(|e| CommunicationError::IOError(io::Error::other(e.to_string())))?;
150
151        debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &report[..size]);
152
153        if size < 4 {
154            return Err(CommunicationError::InvalidHeader);
155        }
156
157        // Extract report ID and packet length
158        let report_id = report[0];
159        let packet_length = u16::from_le_bytes([report[2], report[3]]) as usize;
160
161        if packet_length == 0 {
162            // error!(cstr!("<r!>RX</>: Data aborted by sender!"));
163            return Err(CommunicationError::Aborted);
164        }
165
166        // Check if this is a command response (report ID 0x03)
167        if report_id == report::CMD_IN {
168            // For other command responses, extract the payload
169            let mut response = Vec::new();
170
171            // Extract the command tag and other fields
172            response.extend_from_slice(&report[4..4 + packet_length]);
173
174            debug!("Constructed response: {response:02X?}");
175
176            return Ok(response);
177        } else if report_id == report::DATA_IN {
178            // Data packet - extract the data portion
179            if size >= 4 + packet_length {
180                return Ok(report[4..4 + packet_length].to_vec());
181            }
182        }
183
184        // For other packet types, just return the data portion
185        if size > 4 {
186            Ok(report[4..size].to_vec())
187        } else {
188            Ok(Vec::new())
189        }
190    }
191}
192
193impl USBProtocol {
194    fn read_usb(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
195        match self.device.read(buf) {
196            Ok(size) => {
197                debug!("{}: Read {} bytes: {:02X?}", cstr!("<r!>RX"), size, &buf[..size]);
198                Ok(())
199            }
200            Err(e) => Err(io::Error::other(e.to_string())),
201        }
202    }
203    fn write_usb(&self, buf: &[u8]) -> Result<(), io::Error> {
204        debug!("{}: {:02X?}", cstr!("<g!>TX"), buf);
205
206        match self.device.write(buf) {
207            Ok(written) => {
208                // Platform-specific validation
209                #[cfg(target_os = "windows")]
210                {
211                    // Windows HID might report different sizes due to report descriptors
212                    // As long as write succeeded (written > 0), consider it successful
213                    if written > 0 {
214                        Ok(())
215                    } else {
216                        Err(io::Error::other("Failed to write to USB device"))
217                    }
218                }
219                #[cfg(not(target_os = "windows"))]
220                {
221                    // On other platforms, we expect the exact byte count
222                    if written == buf.len() {
223                        Ok(())
224                    } else {
225                        Err(io::Error::other(format!(
226                            "Failed to write all bytes: wrote {} of {}",
227                            written,
228                            buf.len()
229                        )))
230                    }
231                }
232            }
233            Err(e) => Err(io::Error::other(e.to_string())),
234        }
235    }
236}
237
238// Helper functions
239
240fn parse_usb_identifier(identifier: &str) -> ResultComm<(u16, u16)> {
241    // Check if the identifier contains a separator (either ':' or ',')
242    if let Some(pos) = identifier.find([':', ',']) {
243        let vid_str = &identifier[..pos];
244        let pid_str = &identifier[pos + 1..];
245
246        let vid = parse_number_string(vid_str)
247            .map_err(|_| CommunicationError::ParseError(format!("Invalid VID: {vid_str}")))?;
248
249        let pid = parse_number_string(pid_str)
250            .map_err(|_| CommunicationError::ParseError(format!("Invalid PID: {pid_str}")))?;
251
252        Ok((vid, pid))
253    } else {
254        // Try to parse as a single value (VID only)
255        let vid = parse_number_string(identifier)
256            .map_err(|_| CommunicationError::ParseError(format!("Invalid USB identifier: {identifier}")))?;
257
258        // Use 0 as default PID, which will match any device with the specified VID
259        Ok((vid, 0))
260    }
261}
262
/// Parses a 16-bit number written in decimal or hexadecimal.
///
/// Surrounding whitespace is ignored. Interpretation rules:
/// * a `0x` / `0X` prefix forces hexadecimal;
/// * otherwise the text is read as decimal when that succeeds;
/// * otherwise it is read as unprefixed hexadecimal (e.g. `"1fc9"`, `"ff"`).
///
/// Note that an all-digit string such as `"1234"` is therefore decimal.
///
/// # Errors
/// Returns the `ParseIntError` of the hexadecimal attempt when no
/// interpretation fits in a `u16`.
fn parse_number_string(s: &str) -> Result<u16, std::num::ParseIntError> {
    let text = s.trim();

    let prefixed = text.strip_prefix("0x").or_else(|| text.strip_prefix("0X"));

    match prefixed {
        Some(hex_digits) => u16::from_str_radix(hex_digits, 16),
        // Anything containing a-f fails the decimal parse and lands in the
        // hexadecimal fallback, so unprefixed hex needs no separate case.
        None => text.parse::<u16>().or_else(|_| u16::from_str_radix(text, 16)),
    }
}