1use crate::error::{PoKeysError, Result};
4use crate::types::*;
5use std::time::Duration;
6
/// Framing state for PoKeys requests and responses.
///
/// Holds the rolling request ID written into byte 6 of every request, together
/// with the retry count and response timeout that [`CommunicationManager`]
/// uses when exchanging packets with a device.
#[derive(Debug)]
pub struct Protocol {
    /// Last request ID handed out; incremented (wrapping) before each new request,
    /// so the first request carries ID 1.
    request_id: u8,
    /// Number of send attempts made before a transfer is reported as failed.
    send_retries: u32,
    /// Stored read-retry count.
    /// NOTE(review): not read by any code visible in this file — confirm whether it is still needed.
    read_retries: u32,
    /// How long to wait for a response to a single request.
    socket_timeout: Duration,
}

impl Default for Protocol {
    /// Request ID counter at 0, 3 send attempts, 3 read retries and a 1000 ms timeout.
    fn default() -> Self {
        Self {
            request_id: 0,
            send_retries: 3,
            read_retries: 3,
            socket_timeout: Duration::from_millis(1000),
        }
    }
}
25
26impl Protocol {
27 pub fn new() -> Self {
28 Self::default()
29 }
30
31 pub fn set_retries_and_timeout(
32 &mut self,
33 send_retries: u32,
34 read_retries: u32,
35 timeout: Duration,
36 ) {
37 self.send_retries = send_retries;
38 self.read_retries = read_retries;
39 self.socket_timeout = timeout;
40 }
41
42 pub fn calculate_checksum(data: &[u8]) -> u8 {
44 data.iter()
45 .take(CHECKSUM_LENGTH)
46 .fold(0u8, |acc, &x| acc.wrapping_add(x))
47 }
48
49 pub fn prepare_request(
51 &mut self,
52 request_type: u8,
53 param1: u8,
54 param2: u8,
55 param3: u8,
56 param4: u8,
57 display: Option<bool>,
58 ) -> [u8; REQUEST_BUFFER_SIZE] {
59 let mut request = [0u8; REQUEST_BUFFER_SIZE];
60
61 request[0] = REQUEST_HEADER; request[1] = request_type;
63 request[2] = param1;
64 request[3] = param2;
65 request[4] = param3;
66 request[5] = param4;
67 request[6] = self.next_request_id();
68 request[7] = Self::calculate_checksum(&request);
69
70 if display.unwrap_or(false) {
71 println!("request: {request:02X?}");
72 }
73
74 request
75 }
76
77 pub fn validate_response(&self, response: &[u8], expected_request_id: u8) -> Result<()> {
79 if response.len() < 8 {
80 return Err(PoKeysError::Protocol("Response too short".to_string()));
81 }
82
83 if response[0] != RESPONSE_HEADER {
84 return Err(PoKeysError::Protocol("Invalid response header".to_string()));
85 }
86
87 if response[6] != expected_request_id {
88 return Err(PoKeysError::Protocol("Request ID mismatch".to_string()));
89 }
90
91 let expected_checksum = Self::calculate_checksum(response);
92 if response[7] != expected_checksum {
93 return Err(PoKeysError::InvalidChecksum);
94 }
95
96 Ok(())
97 }
98
99 fn next_request_id(&mut self) -> u8 {
100 self.request_id = self.request_id.wrapping_add(1);
101 self.request_id
102 }
103}
104
105pub trait UsbHidInterface {
107 fn write(&mut self, data: &[u8]) -> Result<usize>;
108 fn read(&mut self, buffer: &mut [u8]) -> Result<usize>;
109 fn read_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize>;
110}
111
112impl<T: UsbHidInterface + ?Sized> UsbHidInterface for Box<T> {
113 fn write(&mut self, data: &[u8]) -> Result<usize> {
114 (**self).write(data)
115 }
116
117 fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
118 (**self).read(buffer)
119 }
120
121 fn read_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
122 (**self).read_timeout(buffer, timeout)
123 }
124}
125
126pub trait NetworkInterface {
128 fn send(&mut self, data: &[u8]) -> Result<usize>;
129 fn receive(&mut self, buffer: &mut [u8]) -> Result<usize>;
130 fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize>;
131}
132
133impl<T: NetworkInterface + ?Sized> NetworkInterface for Box<T> {
134 fn send(&mut self, data: &[u8]) -> Result<usize> {
135 (**self).send(data)
136 }
137
138 fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
139 (**self).receive(buffer)
140 }
141
142 fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
143 (**self).receive_timeout(buffer, timeout)
144 }
145}
146
147#[allow(dead_code)]
149pub struct CommunicationManager {
150 protocol: Protocol,
151 connection_type: DeviceConnectionType,
152}
153
154impl CommunicationManager {
155 pub fn new(connection_type: DeviceConnectionType) -> Self {
156 Self {
157 protocol: Protocol::new(),
158 connection_type,
159 }
160 }
161
162 pub fn set_retries_and_timeout(
163 &mut self,
164 send_retries: u32,
165 read_retries: u32,
166 timeout: Duration,
167 ) {
168 self.protocol
169 .set_retries_and_timeout(send_retries, read_retries, timeout);
170 }
171
172 pub fn get_next_request_id(&mut self) -> u8 {
174 self.protocol.next_request_id()
175 }
176
177 pub fn prepare_request_with_data(
179 &mut self,
180 request_type: u8,
181 param1: u8,
182 param2: u8,
183 param3: u8,
184 param4: u8,
185 data: Option<&[u8]>,
186 ) -> [u8; REQUEST_BUFFER_SIZE] {
187 let mut request =
188 self.protocol
189 .prepare_request(request_type, param1, param2, param3, param4, None);
190
191 if let Some(payload) = data {
193 let data_len = std::cmp::min(payload.len(), 56); request[8..8 + data_len].copy_from_slice(&payload[0..data_len]);
195
196 request[7] = Protocol::calculate_checksum(&request);
198 }
199
200 request
201 }
202
203 pub fn validate_response(&self, response: &[u8], expected_request_id: u8) -> Result<()> {
205 self.protocol
206 .validate_response(response, expected_request_id)
207 }
208
209 pub fn send_usb_request<T: UsbHidInterface + ?Sized>(
211 &mut self,
212 interface: &mut T,
213 request_type: u8,
214 param1: u8,
215 param2: u8,
216 param3: u8,
217 param4: u8,
218 ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
219 let request =
220 self.protocol
221 .prepare_request(request_type, param1, param2, param3, param4, None);
222 let request_id = request[6];
223
224 let mut retries = 0;
225 while retries < self.protocol.send_retries {
226 let mut hid_packet = [0u8; 65];
228 hid_packet[1..65].copy_from_slice(&request[..64]);
229
230 match interface.write(&hid_packet) {
232 Ok(_) => {
233 let mut response = [0u8; RESPONSE_BUFFER_SIZE];
235 let mut wait_count = 0;
236
237 while wait_count < 50 {
238 match interface.read_timeout(&mut response, Duration::from_millis(20)) {
239 Ok(bytes_read) if bytes_read > 0 => {
240 match self.protocol.validate_response(&response, request_id) {
242 Ok(_) => return Ok(response),
243 Err(e) => {
244 log::warn!("Invalid response: {e}");
245 break;
246 }
247 }
248 }
249 Ok(_) => {
250 wait_count += 1;
252 }
253 Err(e) => {
254 log::warn!("Read error: {e}");
255 break;
256 }
257 }
258 }
259 }
260 Err(e) => {
261 log::warn!("Write error: {e}");
262 }
263 }
264
265 retries += 1;
266 }
267
268 Err(PoKeysError::Transfer(
269 "Failed to send USB request".to_string(),
270 ))
271 }
272
273 pub fn send_network_request<T: NetworkInterface + ?Sized>(
275 &mut self,
276 interface: &mut T,
277 request_type: u8,
278 param1: u8,
279 param2: u8,
280 param3: u8,
281 param4: u8,
282 ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
283 let request =
284 self.protocol
285 .prepare_request(request_type, param1, param2, param3, param4, None);
286 let request_id = request[6];
287
288 let mut retries = 0;
291 while retries < self.protocol.send_retries {
292 match interface.send(&request[..64]) {
294 Ok(_) => {
295 let mut response = [0u8; RESPONSE_BUFFER_SIZE];
297
298 match interface.receive_timeout(&mut response, self.protocol.socket_timeout) {
299 Ok(bytes_read) if bytes_read >= 8 => {
300 match self.protocol.validate_response(&response, request_id) {
302 Ok(_) => return Ok(response),
303 Err(e) => {
304 log::warn!("Invalid response: {e}");
305 }
306 }
307 }
308 Ok(_) => {
309 log::warn!("Incomplete response received");
310 }
311 Err(e) => {
312 log::warn!("Network receive error: {e}");
313 }
314 }
315 }
316 Err(e) => {
317 log::warn!("Network send error: {e}");
318 }
319 }
320
321 retries += 1;
322 }
323
324 Err(PoKeysError::Transfer(
325 "Failed to send network request".to_string(),
326 ))
327 }
328
329 pub fn send_request_no_response<T: UsbHidInterface + ?Sized>(
331 &mut self,
332 interface: &mut T,
333 request_type: u8,
334 param1: u8,
335 param2: u8,
336 param3: u8,
337 param4: u8,
338 ) -> Result<()> {
339 let request =
340 self.protocol
341 .prepare_request(request_type, param1, param2, param3, param4, None);
342
343 let mut hid_packet = [0u8; 65];
345 hid_packet[1..65].copy_from_slice(&request[..64]);
346
347 interface.write(&hid_packet)?;
348 Ok(())
349 }
350
351 pub fn send_multipart_request<T: UsbHidInterface + ?Sized>(
353 &mut self,
354 interface: &mut T,
355 request_type: u8,
356 data: &[u8],
357 ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
358 let request = self
362 .protocol
363 .prepare_request(request_type, 0, 0, 0, 0, None);
364 let request_id = request[6];
365
366 let mut hid_packet = [0u8; 65];
368 hid_packet[1..65].copy_from_slice(&request[..64]);
369 interface.write(&hid_packet)?;
370
371 for chunk in data.chunks(64) {
373 let mut data_packet = [0u8; 65];
374 data_packet[1..chunk.len() + 1].copy_from_slice(chunk);
375 interface.write(&data_packet)?;
376 }
377
378 let mut response = [0u8; RESPONSE_BUFFER_SIZE];
380 interface.read_timeout(&mut response, self.protocol.socket_timeout)?;
381
382 self.protocol.validate_response(&response, request_id)?;
383 Ok(response)
384 }
385
386 pub fn send_usb_request_raw<T: UsbHidInterface + ?Sized>(
388 &mut self,
389 interface: &mut T,
390 request: &[u8; REQUEST_BUFFER_SIZE],
391 ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
392 let request_id = request[6];
393
394 let mut retries = 0;
395 while retries < self.protocol.send_retries {
396 let mut hid_packet = [0u8; 65];
398 hid_packet[1..65].copy_from_slice(&request[..64]);
399
400 match interface.write(&hid_packet) {
402 Ok(_) => {
403 let mut response = [0u8; RESPONSE_BUFFER_SIZE];
405 let mut wait_count = 0;
406
407 while wait_count < 50 {
408 match interface.read_timeout(&mut response, Duration::from_millis(20)) {
409 Ok(bytes_read) if bytes_read > 0 => {
410 match self.protocol.validate_response(&response, request_id) {
412 Ok(_) => return Ok(response),
413 Err(e) => {
414 log::warn!("Invalid response: {e}");
415 break;
416 }
417 }
418 }
419 Ok(_) => {
420 wait_count += 1;
422 }
423 Err(e) => {
424 log::warn!("Read error: {e}");
425 break;
426 }
427 }
428 }
429 }
430 Err(e) => {
431 log::warn!("Write error: {e}");
432 }
433 }
434
435 retries += 1;
436 }
437
438 Err(PoKeysError::Transfer(
439 "Failed to send USB request".to_string(),
440 ))
441 }
442
443 pub fn send_network_request_raw<T: NetworkInterface + ?Sized>(
445 &mut self,
446 interface: &mut T,
447 request: &[u8; REQUEST_BUFFER_SIZE],
448 ) -> Result<[u8; RESPONSE_BUFFER_SIZE]> {
449 let request_id = request[6];
450
451 let mut retries = 0;
452 while retries < self.protocol.send_retries {
453 match interface.send(&request[..64]) {
454 Ok(_) => {
455 let mut response = [0u8; RESPONSE_BUFFER_SIZE];
456 match interface.receive(&mut response) {
457 Ok(_) => match self.protocol.validate_response(&response, request_id) {
458 Ok(_) => return Ok(response),
459 Err(e) => {
460 log::warn!("Invalid response: {e}");
461 }
462 },
463 Err(e) => {
464 log::warn!("Network receive error: {e}");
465 }
466 }
467 }
468 Err(e) => {
469 log::warn!("Network send error: {e}");
470 }
471 }
472
473 retries += 1;
474 }
475
476 Err(PoKeysError::Transfer(
477 "Failed to send network request".to_string(),
478 ))
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checksum_calculation() {
        let data = [0xBB, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        let checksum = Protocol::calculate_checksum(&data);
        let expected = 0xBB + 0x01 + 0x02 + 0x03 + 0x04 + 0x05 + 0x06;
        assert_eq!(checksum, expected as u8);
    }

    #[test]
    fn test_request_preparation() {
        let mut protocol = Protocol::new();
        let request = protocol.prepare_request(0x10, 0x20, 0x30, 0x40, 0x50, None);

        assert_eq!(request[0], REQUEST_HEADER);
        assert_eq!(request[1], 0x10);
        assert_eq!(request[2], 0x20);
        assert_eq!(request[3], 0x30);
        assert_eq!(request[4], 0x40);
        assert_eq!(request[5], 0x50);
        // The first request allocated by a fresh Protocol carries ID 1.
        assert_eq!(request[6], 1);
        let expected_checksum = Protocol::calculate_checksum(&request);
        assert_eq!(request[7], expected_checksum);
    }

    #[test]
    fn test_request_id_increments_and_wraps() {
        let mut protocol = Protocol::new();
        let first = protocol.prepare_request(0x10, 0, 0, 0, 0, None);
        let second = protocol.prepare_request(0x10, 0, 0, 0, 0, None);
        assert_eq!(first[6], 1);
        assert_eq!(second[6], 2);

        // The counter is a u8 and must wrap rather than overflow.
        protocol.request_id = u8::MAX;
        let wrapped = protocol.prepare_request(0x10, 0, 0, 0, 0, None);
        assert_eq!(wrapped[6], 0);
    }

    #[test]
    fn test_response_validation() {
        let protocol = Protocol::new();
        let mut response = [0u8; RESPONSE_BUFFER_SIZE];
        response[0] = RESPONSE_HEADER;
        response[6] = 1;
        response[7] = Protocol::calculate_checksum(&response);

        assert!(protocol.validate_response(&response, 1).is_ok());
        // Wrong request ID.
        assert!(protocol.validate_response(&response, 2).is_err());
        // Corrupted checksum.
        response[7] = 0xFF;
        assert!(protocol.validate_response(&response, 1).is_err());
    }

    #[test]
    fn test_short_response_rejected() {
        let protocol = Protocol::new();
        assert!(protocol.validate_response(&[RESPONSE_HEADER; 7], 1).is_err());
        assert!(protocol.validate_response(&[], 1).is_err());
    }

    #[test]
    fn test_response_header_mismatch_rejected() {
        let protocol = Protocol::new();
        let mut response = [0u8; RESPONSE_BUFFER_SIZE];
        response[0] = RESPONSE_HEADER.wrapping_add(1);
        response[6] = 1;
        response[7] = Protocol::calculate_checksum(&response);
        assert!(protocol.validate_response(&response, 1).is_err());
    }

    #[test]
    fn test_reboot_request_format() {
        let mut protocol = Protocol::new();
        let request = protocol.prepare_request(0xF3, 0, 0, 0, 0, None);

        assert_eq!(request[0], REQUEST_HEADER);
        assert_eq!(request[1], 0xF3);
        assert_eq!(request[2], 0);
        assert_eq!(request[3], 0);
        assert_eq!(request[4], 0);
        assert_eq!(request[5], 0);
        assert_eq!(request[6], 1);
        assert_eq!(request[7], Protocol::calculate_checksum(&request));

        // Everything after the 8-byte header must stay zero.
        assert!(request[8..].iter().all(|&byte| byte == 0));
    }
}