Skip to main content

pokeys_lib/
communication.rs

1//! Low-level communication protocol implementation
2
3use crate::error::{PoKeysError, Result};
4use crate::types::*;
5use std::time::Duration;
6
/// Protocol-level state shared by every request sent to a device.
///
/// Holds the rolling request ID that is stamped into outgoing packets, plus
/// the retry counts and socket timeout consulted by the transport helpers.
#[derive(Debug, Clone)]
pub struct Protocol {
    /// ID handed out to the most recently prepared request; wraps around on overflow.
    request_id: u8,
    /// Maximum number of send attempts made before a request is reported as failed.
    send_retries: u32,
    /// Configured number of read retries.
    read_retries: u32,
    /// Timeout passed to timed receive calls while waiting for a response.
    socket_timeout: Duration,
}

impl Default for Protocol {
    /// Starts with request ID 0, three send retries, three read retries and a
    /// one-second socket timeout.
    fn default() -> Self {
        Self {
            request_id: 0,
            send_retries: 3,
            read_retries: 3,
            socket_timeout: Duration::from_secs(1),
        }
    }
}
25
26impl Protocol {
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    pub fn set_retries_and_timeout(
32        &mut self,
33        send_retries: u32,
34        read_retries: u32,
35        timeout: Duration,
36    ) {
37        self.send_retries = send_retries;
38        self.read_retries = read_retries;
39        self.socket_timeout = timeout;
40    }
41
42    /// Calculate checksum for protocol data
43    pub fn calculate_checksum(data: &[u8]) -> u8 {
44        data.iter()
45            .take(CHECKSUM_LENGTH)
46            .fold(0u8, |acc, &x| acc.wrapping_add(x))
47    }
48
49    /// Prepare request packet
50    pub fn prepare_request(
51        &mut self,
52        request_type: u8,
53        param1: u8,
54        param2: u8,
55        param3: u8,
56        param4: u8,
57        display: Option<bool>,
58    ) -> [u8; REQUEST_BUFFER_SIZE] {
59        let mut request = [0u8; REQUEST_BUFFER_SIZE];
60
61        request[0] = REQUEST_HEADER; // 0xBB
62        request[1] = request_type;
63        request[2] = param1;
64        request[3] = param2;
65        request[4] = param3;
66        request[5] = param4;
67        request[6] = self.next_request_id();
68        request[7] = Self::calculate_checksum(&request);
69
70        if display.unwrap_or(false) {
71            println!("request: {request:02X?}");
72        }
73
74        request
75    }
76
77    /// Validate response packet
78    pub fn validate_response(&self, response: &[u8], expected_request_id: u8) -> Result<()> {
79        if response.len() < 8 {
80            return Err(PoKeysError::Protocol("Response too short".to_string()));
81        }
82
83        if response[0] != RESPONSE_HEADER {
84            return Err(PoKeysError::Protocol("Invalid response header".to_string()));
85        }
86
87        if response[6] != expected_request_id {
88            return Err(PoKeysError::Protocol("Request ID mismatch".to_string()));
89        }
90
91        let expected_checksum = Self::calculate_checksum(response);
92        if response[7] != expected_checksum {
93            return Err(PoKeysError::InvalidChecksum);
94        }
95
96        Ok(())
97    }
98
99    fn next_request_id(&mut self) -> u8 {
100        self.request_id = self.request_id.wrapping_add(1);
101        self.request_id
102    }
103}
104
105/// USB HID communication interface
106pub trait UsbHidInterface {
107    fn write(&mut self, data: &[u8]) -> Result<usize>;
108    fn read(&mut self, buffer: &mut [u8]) -> Result<usize>;
109    fn read_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize>;
110}
111
112impl<T: UsbHidInterface + ?Sized> UsbHidInterface for Box<T> {
113    fn write(&mut self, data: &[u8]) -> Result<usize> {
114        (**self).write(data)
115    }
116
117    fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
118        (**self).read(buffer)
119    }
120
121    fn read_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
122        (**self).read_timeout(buffer, timeout)
123    }
124}
125
126/// Network communication interface
127pub trait NetworkInterface {
128    fn send(&mut self, data: &[u8]) -> Result<usize>;
129    fn receive(&mut self, buffer: &mut [u8]) -> Result<usize>;
130    fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize>;
131}
132
133impl<T: NetworkInterface + ?Sized> NetworkInterface for Box<T> {
134    fn send(&mut self, data: &[u8]) -> Result<usize> {
135        (**self).send(data)
136    }
137
138    fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
139        (**self).receive(buffer)
140    }
141
142    fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
143        (**self).receive_timeout(buffer, timeout)
144    }
145}
146
147/// Communication manager that handles different connection types
148#[allow(dead_code)]
149pub struct CommunicationManager {
150    protocol: Protocol,
151    connection_type: DeviceConnectionType,
152}
153
154impl CommunicationManager {
155    pub fn new(connection_type: DeviceConnectionType) -> Self {
156        Self {
157            protocol: Protocol::new(),
158            connection_type,
159        }
160    }
161
162    pub fn set_retries_and_timeout(
163        &mut self,
164        send_retries: u32,
165        read_retries: u32,
166        timeout: Duration,
167    ) {
168        self.protocol
169            .set_retries_and_timeout(send_retries, read_retries, timeout);
170    }
171
172    /// Get the next request ID for manual packet construction
173    pub fn get_next_request_id(&mut self) -> u8 {
174        self.protocol.next_request_id()
175    }
176
177    /// Prepare a request with optional data payload
178    pub fn prepare_request_with_data(
179        &mut self,
180        request_type: u8,
181        param1: u8,
182        param2: u8,
183        param3: u8,
184        param4: u8,
185        data: Option<&[u8]>,
186    ) -> [u8; REQUEST_BUFFER_SIZE] {
187        let mut request =
188            self.protocol
189                .prepare_request(request_type, param1, param2, param3, param4, None);
190
191        // Add data payload if provided (starting at byte 8, which is protocol byte 9)
192        if let Some(payload) = data {
193            let data_len = std::cmp::min(payload.len(), 56); // Max 56 bytes of data (64 - 8 header bytes)
194            request[8..8 + data_len].copy_from_slice(&payload[0..data_len]);
195
196            // Recalculate checksum after adding data
197            request[7] = Protocol::calculate_checksum(&request);
198        }
199
200        request
201    }
202
203    /// Validate response packet
204    pub fn validate_response(&self, response: &[u8], expected_request_id: u8) -> Result<()> {
205        self.protocol
206            .validate_response(response, expected_request_id)
207    }
208
209    /// Send request via USB HID interface
210    pub fn send_usb_request<T: UsbHidInterface + ?Sized>(
211        &mut self,
212        interface: &mut T,
213        request_type: u8,
214        param1: u8,
215        param2: u8,
216        param3: u8,
217        param4: u8,
218    ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
219        let request =
220            self.protocol
221                .prepare_request(request_type, param1, param2, param3, param4, None);
222        let request_id = request[6];
223
224        let mut retries = 0;
225        while retries < self.protocol.send_retries {
226            // Prepare HID packet (add report ID byte at the beginning)
227            let mut hid_packet = [0u8; 65];
228            hid_packet[1..65].copy_from_slice(&request[..64]);
229
230            // Send request
231            match interface.write(&hid_packet) {
232                Ok(_) => {
233                    // Try to receive response
234                    let mut response = [0u8; RESPONSE_BUFFER_SIZE];
235                    let mut wait_count = 0;
236
237                    while wait_count < 50 {
238                        match interface.read_timeout(&mut response, Duration::from_millis(20)) {
239                            Ok(bytes_read) if bytes_read > 0 => {
240                                // Validate response
241                                match self.protocol.validate_response(&response, request_id) {
242                                    Ok(_) => return Ok(response),
243                                    Err(e) => {
244                                        log::warn!("Invalid response: {e}");
245                                        break;
246                                    }
247                                }
248                            }
249                            Ok(_) => {
250                                // No data received, continue waiting
251                                wait_count += 1;
252                            }
253                            Err(e) => {
254                                log::warn!("Read error: {e}");
255                                break;
256                            }
257                        }
258                    }
259                }
260                Err(e) => {
261                    log::warn!("Write error: {e}");
262                }
263            }
264
265            retries += 1;
266        }
267
268        Err(PoKeysError::Transfer(
269            "Failed to send USB request".to_string(),
270        ))
271    }
272
273    /// Send request via network interface
274    pub fn send_network_request<T: NetworkInterface + ?Sized>(
275        &mut self,
276        interface: &mut T,
277        request_type: u8,
278        param1: u8,
279        param2: u8,
280        param3: u8,
281        param4: u8,
282    ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
283        let request =
284            self.protocol
285                .prepare_request(request_type, param1, param2, param3, param4, None);
286        let request_id = request[6];
287
288        // println!("request: {request:02X?}");
289
290        let mut retries = 0;
291        while retries < self.protocol.send_retries {
292            // Send request
293            match interface.send(&request[..64]) {
294                Ok(_) => {
295                    // Try to receive response
296                    let mut response = [0u8; RESPONSE_BUFFER_SIZE];
297
298                    match interface.receive_timeout(&mut response, self.protocol.socket_timeout) {
299                        Ok(bytes_read) if bytes_read >= 8 => {
300                            // Validate response
301                            match self.protocol.validate_response(&response, request_id) {
302                                Ok(_) => return Ok(response),
303                                Err(e) => {
304                                    log::warn!("Invalid response: {e}");
305                                }
306                            }
307                        }
308                        Ok(_) => {
309                            log::warn!("Incomplete response received");
310                        }
311                        Err(e) => {
312                            log::warn!("Network receive error: {e}");
313                        }
314                    }
315                }
316                Err(e) => {
317                    log::warn!("Network send error: {e}");
318                }
319            }
320
321            retries += 1;
322        }
323
324        Err(PoKeysError::Transfer(
325            "Failed to send network request".to_string(),
326        ))
327    }
328
329    /// Send request without expecting a response
330    pub fn send_request_no_response<T: UsbHidInterface + ?Sized>(
331        &mut self,
332        interface: &mut T,
333        request_type: u8,
334        param1: u8,
335        param2: u8,
336        param3: u8,
337        param4: u8,
338    ) -> Result<()> {
339        let request =
340            self.protocol
341                .prepare_request(request_type, param1, param2, param3, param4, None);
342
343        // Prepare HID packet
344        let mut hid_packet = [0u8; 65];
345        hid_packet[1..65].copy_from_slice(&request[..64]);
346
347        interface.write(&hid_packet)?;
348        Ok(())
349    }
350
351    /// Send multi-part request for large data transfers
352    pub fn send_multipart_request<T: UsbHidInterface + ?Sized>(
353        &mut self,
354        interface: &mut T,
355        request_type: u8,
356        data: &[u8],
357    ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
358        // Implementation for multi-part data transfer
359        // This would be used for large data transfers like motion buffer updates
360
361        let request = self
362            .protocol
363            .prepare_request(request_type, 0, 0, 0, 0, None);
364        let request_id = request[6];
365
366        // Send initial request
367        let mut hid_packet = [0u8; 65];
368        hid_packet[1..65].copy_from_slice(&request[..64]);
369        interface.write(&hid_packet)?;
370
371        // Send data in chunks
372        for chunk in data.chunks(64) {
373            let mut data_packet = [0u8; 65];
374            data_packet[1..chunk.len() + 1].copy_from_slice(chunk);
375            interface.write(&data_packet)?;
376        }
377
378        // Receive response
379        let mut response = [0u8; RESPONSE_BUFFER_SIZE];
380        interface.read_timeout(&mut response, self.protocol.socket_timeout)?;
381
382        self.protocol.validate_response(&response, request_id)?;
383        Ok(response)
384    }
385
386    /// Send raw request via USB HID interface (for requests with data payloads)
387    pub fn send_usb_request_raw<T: UsbHidInterface + ?Sized>(
388        &mut self,
389        interface: &mut T,
390        request: &[u8; REQUEST_BUFFER_SIZE],
391    ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
392        let request_id = request[6];
393
394        let mut retries = 0;
395        while retries < self.protocol.send_retries {
396            // Prepare HID packet (add report ID byte at the beginning)
397            let mut hid_packet = [0u8; 65];
398            hid_packet[1..65].copy_from_slice(&request[..64]);
399
400            // Send request
401            match interface.write(&hid_packet) {
402                Ok(_) => {
403                    // Try to receive response
404                    let mut response = [0u8; RESPONSE_BUFFER_SIZE];
405                    let mut wait_count = 0;
406
407                    while wait_count < 50 {
408                        match interface.read_timeout(&mut response, Duration::from_millis(20)) {
409                            Ok(bytes_read) if bytes_read > 0 => {
410                                // Validate response
411                                match self.protocol.validate_response(&response, request_id) {
412                                    Ok(_) => return Ok(response),
413                                    Err(e) => {
414                                        log::warn!("Invalid response: {e}");
415                                        break;
416                                    }
417                                }
418                            }
419                            Ok(_) => {
420                                // No data received, continue waiting
421                                wait_count += 1;
422                            }
423                            Err(e) => {
424                                log::warn!("Read error: {e}");
425                                break;
426                            }
427                        }
428                    }
429                }
430                Err(e) => {
431                    log::warn!("Write error: {e}");
432                }
433            }
434
435            retries += 1;
436        }
437
438        Err(PoKeysError::Transfer(
439            "Failed to send USB request".to_string(),
440        ))
441    }
442
443    /// Send raw request via network interface (for requests with data payloads)
444    pub fn send_network_request_raw<T: NetworkInterface + ?Sized>(
445        &mut self,
446        interface: &mut T,
447        request: &[u8; REQUEST_BUFFER_SIZE],
448    ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
449        let request_id = request[6];
450
451        let mut retries = 0;
452        while retries < self.protocol.send_retries {
453            match interface.send(&request[..64]) {
454                Ok(_) => {
455                    let mut response = [0u8; RESPONSE_BUFFER_SIZE];
456                    match interface.receive(&mut response) {
457                        Ok(_) => match self.protocol.validate_response(&response, request_id) {
458                            Ok(_) => return Ok(response),
459                            Err(e) => {
460                                log::warn!("Invalid response: {e}");
461                            }
462                        },
463                        Err(e) => {
464                            log::warn!("Network receive error: {e}");
465                        }
466                    }
467                }
468                Err(e) => {
469                    log::warn!("Network send error: {e}");
470                }
471            }
472
473            retries += 1;
474        }
475
476        Err(PoKeysError::Transfer(
477            "Failed to send network request".to_string(),
478        ))
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// The checksum of a seven-byte header is the low 8 bits of the byte sum.
    #[test]
    fn test_checksum_calculation() {
        let header_bytes: [u8; 7] = [0xBB, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        let wide_sum: u32 = header_bytes.iter().map(|&b| u32::from(b)).sum();

        assert_eq!(
            Protocol::calculate_checksum(&header_bytes),
            (wide_sum & 0xFF) as u8
        );
    }

    /// A prepared request carries header, type, parameters, ID and checksum in order.
    #[test]
    fn test_request_preparation() {
        let mut protocol = Protocol::new();
        let request = protocol.prepare_request(0x10, 0x20, 0x30, 0x40, 0x50, None);

        // Header, request type, four parameters, then the first request ID (1).
        assert_eq!(
            request[..7],
            [REQUEST_HEADER, 0x10, 0x20, 0x30, 0x40, 0x50, 1]
        );
        assert_eq!(request[7], Protocol::calculate_checksum(&request));
    }

    /// Validation accepts a matching response and rejects a wrong ID or checksum.
    #[test]
    fn test_response_validation() {
        let protocol = Protocol::new();

        let mut response = [0u8; RESPONSE_BUFFER_SIZE];
        response[0] = RESPONSE_HEADER;
        response[6] = 1; // Request ID
        response[7] = Protocol::calculate_checksum(&response);

        assert!(protocol.validate_response(&response, 1).is_ok());
        // Wrong request ID
        assert!(protocol.validate_response(&response, 2).is_err());

        // Wrong checksum
        response[7] = 0xFF;
        assert!(protocol.validate_response(&response, 1).is_err());
    }

    /// "Reboot system" command, per PoKeys protocol spec:
    ///   byte 1 (header) = 0xBB
    ///   byte 2 (CMD)    = 0xF3
    ///   bytes 3-6       = reserved (0)
    ///   byte 7          = request ID
    ///   byte 8          = checksum of bytes 1-7
    #[test]
    fn test_reboot_request_format() {
        let mut protocol = Protocol::new();
        let request = protocol.prepare_request(0xF3, 0, 0, 0, 0, None);

        assert_eq!(request[..7], [REQUEST_HEADER, 0xF3, 0, 0, 0, 0, 1]);
        assert_eq!(request[7], Protocol::calculate_checksum(&request));

        // Payload bytes (9-64 in 1-based spec numbering) are unused.
        assert!(request[8..].iter().all(|&byte| byte == 0));
    }
}