1use crate::communication::NetworkInterface;
4use crate::error::{PoKeysError, Result};
5use crate::types::{NetworkDeviceInfo, NetworkDeviceSummary};
6use std::collections::HashSet;
7use std::io::{Read, Write};
8use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket};
9use std::time::{Duration, Instant};
10
/// `NetworkInterface` implementation that exchanges datagrams with a single device over UDP.
#[derive(Debug)]
pub struct UdpNetworkInterface {
    /// Local UDP socket used for both sending and receiving.
    socket: UdpSocket,
    /// Device address that every outgoing datagram is sent to.
    remote_addr: SocketAddr,
}
16
17impl UdpNetworkInterface {
18 pub fn new(remote_ip: [u8; 4], remote_port: u16) -> Result<Self> {
19 let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
20 let socket = UdpSocket::bind(local_addr)?;
21
22 let remote_addr = SocketAddr::new(
23 IpAddr::V4(Ipv4Addr::new(
24 remote_ip[0],
25 remote_ip[1],
26 remote_ip[2],
27 remote_ip[3],
28 )),
29 remote_port,
30 );
31
32 Ok(Self {
33 socket,
34 remote_addr,
35 })
36 }
37}
38
39impl NetworkInterface for UdpNetworkInterface {
40 fn send(&mut self, data: &[u8]) -> Result<usize> {
41 self.socket
42 .send_to(data, self.remote_addr)
43 .map_err(PoKeysError::Io)
44 }
45
46 fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
47 let (bytes_received, _) = self.socket.recv_from(buffer)?;
48 Ok(bytes_received)
49 }
50
51 fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
52 self.socket.set_read_timeout(Some(timeout))?;
53 let result = self.receive(buffer);
54 self.socket.set_read_timeout(None)?;
55 result
56 }
57}
58
/// `NetworkInterface` implementation that talks to a single device over a TCP stream.
#[derive(Debug)]
pub struct TcpNetworkInterface {
    /// Connected stream to the device.
    stream: TcpStream,
}
63
64impl TcpNetworkInterface {
65 pub fn new(remote_ip: [u8; 4], remote_port: u16) -> Result<Self> {
66 let remote_addr = SocketAddr::new(
67 IpAddr::V4(Ipv4Addr::new(
68 remote_ip[0],
69 remote_ip[1],
70 remote_ip[2],
71 remote_ip[3],
72 )),
73 remote_port,
74 );
75
76 let stream = TcpStream::connect(remote_addr)?;
77
78 Ok(Self { stream })
79 }
80}
81
82impl NetworkInterface for TcpNetworkInterface {
83 fn send(&mut self, data: &[u8]) -> Result<usize> {
84 self.stream.write(data).map_err(PoKeysError::Io)
85 }
86
87 fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
88 self.stream.read(buffer).map_err(PoKeysError::Io)
89 }
90
91 fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
92 self.stream.set_read_timeout(Some(timeout))?;
93 let result = self.receive(buffer);
94 self.stream.set_read_timeout(None)?;
95 result
96 }
97}
98
/// Broadcasts PoKeys discovery requests and collects the replies.
///
/// A single IPv4 UDP socket is created once and reused for every discovery round.
#[derive(Debug)]
pub struct NetworkDiscovery {
    /// Broadcast-enabled socket used for both the requests and the replies.
    socket: UdpSocket,
}
103
104impl NetworkDiscovery {
105 pub fn new() -> Result<Self> {
106 let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
108 let socket = UdpSocket::bind(local_addr)?;
109 socket.set_broadcast(true)?;
110
111 Ok(Self { socket })
112 }
113
114 pub fn discover_devices(&self, timeout_ms: u32) -> Result<Vec<NetworkDeviceSummary>> {
116 let mut devices = Vec::new();
117 let mut seen_serials = HashSet::new();
118
119 let broadcast_addresses = self.get_broadcast_addresses()?;
121
122 let discovery_packet = self.create_discovery_packet();
124
125 for &broadcast_addr in &broadcast_addresses {
126 let addr = SocketAddr::new(IpAddr::V4(broadcast_addr), 20055);
127 log::debug!("Sending discovery packet to {addr}");
128
129 if let Err(e) = self.socket.send_to(&discovery_packet, addr) {
130 log::warn!("Failed to send discovery packet to {addr}: {e}");
131 continue;
132 }
133 }
134
135 let general_broadcast =
137 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 20055);
138 if let Err(e) = self.socket.send_to(&discovery_packet, general_broadcast) {
139 log::warn!("Failed to send general broadcast: {e}");
140 }
141
142 let start_time = Instant::now();
144 let timeout = Duration::from_millis(timeout_ms as u64);
145
146 self.socket
148 .set_read_timeout(Some(Duration::from_millis(100)))?;
149
150 while start_time.elapsed() < timeout {
151 let mut buffer = [0u8; 1024];
152 match self.socket.recv_from(&mut buffer) {
153 Ok((bytes_received, sender_addr)) => {
154 log::debug!("Received {bytes_received} bytes from {sender_addr}");
155
156 if let Some(device) =
157 self.parse_discovery_response(&buffer[..bytes_received], sender_addr)
158 {
159 if seen_serials.insert(device.serial_number) {
161 log::debug!(
162 "Discovered PoKeys device: Serial {}, IP {}.{}.{}.{}, FW {}.{}",
163 device.serial_number,
164 device.ip_address[0],
165 device.ip_address[1],
166 device.ip_address[2],
167 device.ip_address[3],
168 device.firmware_version_major,
169 device.firmware_version_minor
170 );
171 devices.push(device);
172 }
173 }
174 }
175 Err(ref e)
176 if e.kind() == std::io::ErrorKind::WouldBlock
177 || e.kind() == std::io::ErrorKind::TimedOut =>
178 {
179 continue;
181 }
182 Err(e) => {
183 log::warn!("Error receiving discovery response: {e}");
184 continue;
185 }
186 }
187 }
188
189 log::info!(
190 "Network discovery completed, found {} devices",
191 devices.len()
192 );
193 Ok(devices)
194 }
195
196 fn get_broadcast_addresses(&self) -> Result<Vec<Ipv4Addr>> {
198 let mut addresses = Vec::new();
199
200 addresses.push(Ipv4Addr::new(255, 255, 255, 255)); addresses.push(Ipv4Addr::new(192, 168, 1, 255)); addresses.push(Ipv4Addr::new(192, 168, 0, 255)); addresses.push(Ipv4Addr::new(10, 0, 1, 255)); addresses.push(Ipv4Addr::new(172, 16, 0, 255)); Ok(addresses)
211 }
212
213 pub fn search_device(
215 &self,
216 serial_number: u32,
217 timeout_ms: u32,
218 ) -> Result<Option<NetworkDeviceSummary>> {
219 let devices = self.discover_devices(timeout_ms)?;
220
221 for device in devices {
222 if device.serial_number == serial_number {
223 return Ok(Some(device));
224 }
225 }
226
227 Ok(None)
228 }
229
230 fn create_discovery_packet(&self) -> Vec<u8> {
231 Vec::new()
234 }
235
236 fn parse_discovery_response(
237 &self,
238 data: &[u8],
239 sender_addr: SocketAddr,
240 ) -> Option<NetworkDeviceSummary> {
241 if data.len() != 14 && data.len() != 19 {
243 return None;
244 }
245
246 let _sender_ip = match sender_addr.ip() {
247 IpAddr::V4(ipv4) => ipv4.octets(),
248 _ => return None,
249 };
250
251 if data.len() == 14 {
252 let user_id = data[0];
261 let serial_number = ((data[1] as u32) << 8) | (data[2] as u32);
262
263 let firmware_version_encoded = data[3];
265 let firmware_revision = data[4]; let major_bits = (firmware_version_encoded >> 4) & 0x0F; let minor_bits = firmware_version_encoded & 0x0F; let decoded_major = 1 + major_bits;
270 let decoded_minor = minor_bits;
271
272 let device_ip = [data[5], data[6], data[7], data[8]];
273 let dhcp = data[9];
274 let host_ip = [data[10], data[11], data[12], data[13]];
275
276 Some(NetworkDeviceSummary {
277 serial_number,
278 ip_address: device_ip,
279 host_ip,
280 firmware_version_major: decoded_major,
281 firmware_version_minor: decoded_minor,
282 firmware_revision,
283 user_id,
284 dhcp,
285 hw_type: 0, use_udp: 1, })
288 } else {
289 let user_id = data[0];
300
301 let firmware_version_encoded = data[3];
303 let firmware_revision = data[4]; let major_bits = (firmware_version_encoded >> 4) & 0x0F; let minor_bits = firmware_version_encoded & 0x0F; let decoded_major = 1 + major_bits;
308 let decoded_minor = minor_bits;
309
310 let device_ip = [data[5], data[6], data[7], data[8]];
311 let dhcp = data[9];
312 let host_ip = [data[10], data[11], data[12], data[13]];
313 let serial_number = ((data[17] as u32) << 24)
314 | ((data[16] as u32) << 16)
315 | ((data[15] as u32) << 8)
316 | (data[14] as u32);
317 let hw_type = data[18];
318
319 Some(NetworkDeviceSummary {
320 serial_number,
321 ip_address: device_ip,
322 host_ip,
323 firmware_version_major: decoded_major,
324 firmware_version_minor: decoded_minor,
325 firmware_revision,
326 user_id,
327 dhcp,
328 hw_type,
329 use_udp: 1, })
331 }
332 }
333}
334
335pub struct NetworkDeviceConfig {
337 pub device_info: NetworkDeviceInfo,
338}
339
340impl Default for NetworkDeviceConfig {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
346impl NetworkDeviceConfig {
347 pub fn new() -> Self {
348 Self {
349 device_info: NetworkDeviceInfo {
350 ip_address_current: [0, 0, 0, 0],
351 ip_address_setup: [0, 0, 0, 0],
352 subnet_mask: [255, 255, 255, 0],
353 gateway_ip: [0, 0, 0, 0],
354 tcp_timeout: 1000,
355 additional_network_options: 0xA0,
356 dhcp: 0,
357 },
358 }
359 }
360
361 pub fn set_ip_address(&mut self, ip: [u8; 4]) {
363 self.device_info.ip_address_setup = ip;
364 }
365
366 pub fn set_subnet_mask(&mut self, mask: [u8; 4]) {
368 self.device_info.subnet_mask = mask;
369 }
370
371 pub fn set_default_gateway(&mut self, gateway: [u8; 4]) {
373 self.device_info.gateway_ip = gateway;
374 }
375
376 pub fn set_dhcp(&mut self, enable: bool) {
378 self.device_info.dhcp = if enable { 1 } else { 0 };
379 }
380
381 pub fn set_tcp_timeout(&mut self, timeout_ms: u16) {
383 self.device_info.tcp_timeout = timeout_ms;
384 }
385
386 pub fn set_network_options(
388 &mut self,
389 disable_discovery: bool,
390 disable_auto_config: bool,
391 disable_udp_config: bool,
392 ) {
393 let mut options = 0xA0u8; if disable_discovery {
396 options |= 0x01;
397 }
398 if disable_auto_config {
399 options |= 0x02;
400 }
401 if disable_udp_config {
402 options |= 0x04;
403 }
404
405 self.device_info.additional_network_options = options;
406 }
407}
408
409pub mod network_utils {
411 use super::*;
412
413 pub fn ip_to_string(ip: [u8; 4]) -> String {
415 format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3])
416 }
417
418 pub fn string_to_ip(ip_str: &str) -> Result<[u8; 4]> {
420 let parts: Vec<&str> = ip_str.split('.').collect();
421 if parts.len() != 4 {
422 return Err(PoKeysError::Parameter(
423 "Invalid IP address format".to_string(),
424 ));
425 }
426
427 let mut ip = [0u8; 4];
428 for (i, part) in parts.iter().enumerate() {
429 ip[i] = part
430 .parse::<u8>()
431 .map_err(|_| PoKeysError::Parameter("Invalid IP address octet".to_string()))?;
432 }
433
434 Ok(ip)
435 }
436
437 pub fn same_subnet(ip1: [u8; 4], ip2: [u8; 4], subnet_mask: [u8; 4]) -> bool {
439 for i in 0..4 {
440 if (ip1[i] & subnet_mask[i]) != (ip2[i] & subnet_mask[i]) {
441 return false;
442 }
443 }
444 true
445 }
446
447 pub fn network_address(ip: [u8; 4], subnet_mask: [u8; 4]) -> [u8; 4] {
449 [
450 ip[0] & subnet_mask[0],
451 ip[1] & subnet_mask[1],
452 ip[2] & subnet_mask[2],
453 ip[3] & subnet_mask[3],
454 ]
455 }
456
457 pub fn broadcast_address(ip: [u8; 4], subnet_mask: [u8; 4]) -> [u8; 4] {
459 [
460 ip[0] | (!subnet_mask[0]),
461 ip[1] | (!subnet_mask[1]),
462 ip[2] | (!subnet_mask[2]),
463 ip[3] | (!subnet_mask[3]),
464 ]
465 }
466}
467
468pub fn create_udp_connection(device: &NetworkDeviceSummary) -> Result<Box<dyn NetworkInterface>> {
472 let interface = UdpNetworkInterface::new(device.ip_address, 20055)?;
473 Ok(Box::new(interface))
474}
475
476pub fn create_tcp_connection(device: &NetworkDeviceSummary) -> Result<Box<dyn NetworkInterface>> {
478 let interface = TcpNetworkInterface::new(device.ip_address, 20055)?;
479 Ok(Box::new(interface))
480}
481
482pub fn discover_all_devices(timeout_ms: u32) -> Result<Vec<NetworkDeviceSummary>> {
484 let discovery = NetworkDiscovery::new()?;
485 discovery.discover_devices(timeout_ms)
486}
487
488pub fn find_device_by_serial(
490 serial_number: u32,
491 timeout_ms: u32,
492) -> Result<Option<NetworkDeviceSummary>> {
493 let discovery = NetworkDiscovery::new()?;
494 discovery.search_device(serial_number, timeout_ms)
495}
496
#[cfg(test)]
mod tests {
    use super::network_utils::*;
    use super::*;

    /// Creates a discovery instance for exercising its packet helpers.
    fn new_discovery() -> NetworkDiscovery {
        NetworkDiscovery::new().unwrap()
    }

    #[test]
    fn test_ip_string_conversion() {
        let original = [192, 168, 1, 100];
        let rendered = ip_to_string(original);
        assert_eq!(rendered, "192.168.1.100");
        assert_eq!(string_to_ip(&rendered).unwrap(), original);
    }

    #[test]
    fn test_invalid_ip_string() {
        let rejected = ["192.168.1", "192.168.1.256", "not.an.ip.address"];
        for candidate in rejected.iter() {
            assert!(string_to_ip(candidate).is_err());
        }
    }

    #[test]
    fn test_subnet_calculations() {
        let mask = [255, 255, 255, 0];
        let host_a = [192, 168, 1, 100];
        let host_b = [192, 168, 1, 200];
        let other_net = [192, 168, 2, 100];

        assert!(same_subnet(host_a, host_b, mask));
        assert!(!same_subnet(host_a, other_net, mask));
        assert_eq!(network_address(host_a, mask), [192, 168, 1, 0]);
        assert_eq!(broadcast_address(host_a, mask), [192, 168, 1, 255]);
    }

    #[test]
    fn test_network_device_config() {
        let mut cfg = NetworkDeviceConfig::new();

        cfg.set_ip_address([192, 168, 1, 100]);
        assert_eq!(cfg.device_info.ip_address_setup, [192, 168, 1, 100]);

        cfg.set_dhcp(true);
        assert_eq!(cfg.device_info.dhcp, 1);
        cfg.set_dhcp(false);
        assert_eq!(cfg.device_info.dhcp, 0);

        cfg.set_tcp_timeout(2000);
        assert_eq!(cfg.device_info.tcp_timeout, 2000);
    }

    #[test]
    fn test_network_options() {
        let mut cfg = NetworkDeviceConfig::new();

        cfg.set_network_options(true, false, false);
        let opts = cfg.device_info.additional_network_options;
        assert_eq!(opts & 0x01, 0x01);
        assert_eq!(opts & 0x02, 0x00);

        cfg.set_network_options(false, true, true);
        let opts = cfg.device_info.additional_network_options;
        assert_eq!(opts & 0x01, 0x00);
        assert_eq!(opts & 0x02, 0x02);
        assert_eq!(opts & 0x04, 0x04);
    }

    #[test]
    fn test_discovery_packet_format() {
        let packet = new_discovery().create_discovery_packet();
        assert!(packet.is_empty(), "PoKeys discovery packet must be empty");
    }

    #[test]
    fn test_discovery_response_parsing_14_bytes() {
        // user id, serial (BE u16), fw byte, revision, device IP, DHCP, host IP
        let response: [u8; 14] = [
            0x01, 0x12, 0x34, 0x12, 0x05, 192, 168, 1, 100, 0x01, 192, 168, 1, 1,
        ];
        let sender: SocketAddr = "192.168.1.100:20055".parse().unwrap();
        let device = new_discovery()
            .parse_discovery_response(&response, sender)
            .unwrap();

        assert_eq!(device.serial_number, 4660);
        assert_eq!(device.firmware_version_major, 2);
        assert_eq!(device.firmware_version_minor, 2);
        assert_eq!(device.firmware_revision, 5);
        assert_eq!(device.ip_address, [192, 168, 1, 100]);
        assert_eq!(device.dhcp, 1);
        assert_eq!(device.host_ip, [192, 168, 1, 1]);
        assert_eq!(device.hw_type, 0);
    }

    #[test]
    fn test_discovery_response_parsing_19_bytes() {
        // As the 14-byte form, plus serial (LE u32) and hardware type.
        let response: [u8; 19] = [
            0x02, 0x00, 0x00, 0x21, 0x01, 192, 168, 1, 101, 0x00, 192, 168, 1, 1, 0x78, 0x56,
            0x34, 0x12, 0x58,
        ];
        let sender: SocketAddr = "192.168.1.101:20055".parse().unwrap();
        let device = new_discovery()
            .parse_discovery_response(&response, sender)
            .unwrap();

        assert_eq!(device.serial_number, 0x12345678);
        assert_eq!(device.firmware_version_major, 3);
        assert_eq!(device.firmware_version_minor, 1);
        assert_eq!(device.firmware_revision, 1);
        assert_eq!(device.ip_address, [192, 168, 1, 101]);
        assert_eq!(device.dhcp, 0);
        assert_eq!(device.host_ip, [192, 168, 1, 1]);
        assert_eq!(device.hw_type, 0x58);
    }

    #[test]
    fn test_discovery_response_invalid_length() {
        let discovery = new_discovery();
        let sender: SocketAddr = "192.168.1.100:20055".parse().unwrap();
        let short_response = [0x01u8, 0x02, 0x03];
        let long_response = [0u8; 25];

        for payload in [&short_response[..], &long_response[..]].iter() {
            assert!(discovery.parse_discovery_response(payload, sender).is_none());
        }
    }
}