Skip to main content

pokeys_lib/
network.rs

1//! Network device support and discovery
2
3use crate::communication::NetworkInterface;
4use crate::error::{PoKeysError, Result};
5use crate::types::{NetworkDeviceInfo, NetworkDeviceSummary};
6use std::collections::HashSet;
7use std::io::{Read, Write};
8use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket};
9use std::time::{Duration, Instant};
10
/// UDP network interface implementation.
///
/// Holds a local UDP socket together with the address of the remote device
/// that outgoing datagrams are sent to.
#[derive(Debug)]
pub struct UdpNetworkInterface {
    /// Local socket used for both sending and receiving.
    socket: UdpSocket,
    /// Destination address for every outgoing datagram.
    remote_addr: SocketAddr,
}
16
17impl UdpNetworkInterface {
18    pub fn new(remote_ip: [u8; 4], remote_port: u16) -> Result<Self> {
19        let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
20        let socket = UdpSocket::bind(local_addr)?;
21
22        let remote_addr = SocketAddr::new(
23            IpAddr::V4(Ipv4Addr::new(
24                remote_ip[0],
25                remote_ip[1],
26                remote_ip[2],
27                remote_ip[3],
28            )),
29            remote_port,
30        );
31
32        Ok(Self {
33            socket,
34            remote_addr,
35        })
36    }
37}
38
39impl NetworkInterface for UdpNetworkInterface {
40    fn send(&mut self, data: &[u8]) -> Result<usize> {
41        self.socket
42            .send_to(data, self.remote_addr)
43            .map_err(PoKeysError::Io)
44    }
45
46    fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
47        let (bytes_received, _) = self.socket.recv_from(buffer)?;
48        Ok(bytes_received)
49    }
50
51    fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
52        self.socket.set_read_timeout(Some(timeout))?;
53        let result = self.receive(buffer);
54        self.socket.set_read_timeout(None)?;
55        result
56    }
57}
58
/// TCP network interface implementation.
///
/// Wraps a connected TCP stream to the remote device.
#[derive(Debug)]
pub struct TcpNetworkInterface {
    /// Connected stream used for both sending and receiving.
    stream: TcpStream,
}
63
64impl TcpNetworkInterface {
65    pub fn new(remote_ip: [u8; 4], remote_port: u16) -> Result<Self> {
66        let remote_addr = SocketAddr::new(
67            IpAddr::V4(Ipv4Addr::new(
68                remote_ip[0],
69                remote_ip[1],
70                remote_ip[2],
71                remote_ip[3],
72            )),
73            remote_port,
74        );
75
76        let stream = TcpStream::connect(remote_addr)?;
77
78        Ok(Self { stream })
79    }
80}
81
82impl NetworkInterface for TcpNetworkInterface {
83    fn send(&mut self, data: &[u8]) -> Result<usize> {
84        self.stream.write(data).map_err(PoKeysError::Io)
85    }
86
87    fn receive(&mut self, buffer: &mut [u8]) -> Result<usize> {
88        self.stream.read(buffer).map_err(PoKeysError::Io)
89    }
90
91    fn receive_timeout(&mut self, buffer: &mut [u8], timeout: Duration) -> Result<usize> {
92        self.stream.set_read_timeout(Some(timeout))?;
93        let result = self.receive(buffer);
94        self.stream.set_read_timeout(None)?;
95        result
96    }
97}
98
/// Network device discovery.
///
/// Owns the UDP socket used to send discovery datagrams and to collect the
/// replies.
#[derive(Debug)]
pub struct NetworkDiscovery {
    /// Socket used for sending probes and receiving replies.
    socket: UdpSocket,
}
103
104impl NetworkDiscovery {
105    pub fn new() -> Result<Self> {
106        // Bind to any available port for sending
107        let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
108        let socket = UdpSocket::bind(local_addr)?;
109        socket.set_broadcast(true)?;
110
111        Ok(Self { socket })
112    }
113
114    /// Discover PoKeys devices on the network
115    pub fn discover_devices(&self, timeout_ms: u32) -> Result<Vec<NetworkDeviceSummary>> {
116        let mut devices = Vec::new();
117        let mut seen_serials = HashSet::new();
118
119        // Get broadcast addresses to try
120        let broadcast_addresses = self.get_broadcast_addresses()?;
121
122        // Send discovery packets to all broadcast addresses
123        let discovery_packet = self.create_discovery_packet();
124
125        for &broadcast_addr in &broadcast_addresses {
126            let addr = SocketAddr::new(IpAddr::V4(broadcast_addr), 20055);
127            log::debug!("Sending discovery packet to {addr}");
128
129            if let Err(e) = self.socket.send_to(&discovery_packet, addr) {
130                log::warn!("Failed to send discovery packet to {addr}: {e}");
131                continue;
132            }
133        }
134
135        // Also send to general broadcast
136        let general_broadcast =
137            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 20055);
138        if let Err(e) = self.socket.send_to(&discovery_packet, general_broadcast) {
139            log::warn!("Failed to send general broadcast: {e}");
140        }
141
142        // Listen for responses
143        let start_time = Instant::now();
144        let timeout = Duration::from_millis(timeout_ms as u64);
145
146        // Set a short read timeout to allow checking for overall timeout
147        self.socket
148            .set_read_timeout(Some(Duration::from_millis(100)))?;
149
150        while start_time.elapsed() < timeout {
151            let mut buffer = [0u8; 1024];
152            match self.socket.recv_from(&mut buffer) {
153                Ok((bytes_received, sender_addr)) => {
154                    log::debug!("Received {bytes_received} bytes from {sender_addr}");
155
156                    if let Some(device) =
157                        self.parse_discovery_response(&buffer[..bytes_received], sender_addr)
158                    {
159                        // Avoid duplicate devices (same serial number)
160                        if seen_serials.insert(device.serial_number) {
161                            log::debug!(
162                                "Discovered PoKeys device: Serial {}, IP {}.{}.{}.{}, FW {}.{}",
163                                device.serial_number,
164                                device.ip_address[0],
165                                device.ip_address[1],
166                                device.ip_address[2],
167                                device.ip_address[3],
168                                device.firmware_version_major,
169                                device.firmware_version_minor
170                            );
171                            devices.push(device);
172                        }
173                    }
174                }
175                Err(ref e)
176                    if e.kind() == std::io::ErrorKind::WouldBlock
177                        || e.kind() == std::io::ErrorKind::TimedOut =>
178                {
179                    // Timeout, continue listening
180                    continue;
181                }
182                Err(e) => {
183                    log::warn!("Error receiving discovery response: {e}");
184                    continue;
185                }
186            }
187        }
188
189        log::info!(
190            "Network discovery completed, found {} devices",
191            devices.len()
192        );
193        Ok(devices)
194    }
195
196    /// Get broadcast addresses for all network interfaces
197    fn get_broadcast_addresses(&self) -> Result<Vec<Ipv4Addr>> {
198        let mut addresses = Vec::new();
199
200        // Add common broadcast addresses
201        addresses.push(Ipv4Addr::new(255, 255, 255, 255)); // General broadcast
202        addresses.push(Ipv4Addr::new(192, 168, 1, 255)); // Common home network
203        addresses.push(Ipv4Addr::new(192, 168, 0, 255)); // Common home network
204        addresses.push(Ipv4Addr::new(10, 0, 1, 255)); // Common corporate network
205        addresses.push(Ipv4Addr::new(172, 16, 0, 255)); // Common corporate network
206
207        // TODO: In a more complete implementation, we would enumerate actual network interfaces
208        // and calculate their broadcast addresses. For now, we use common ones.
209
210        Ok(addresses)
211    }
212
213    /// Search for specific device by serial number
214    pub fn search_device(
215        &self,
216        serial_number: u32,
217        timeout_ms: u32,
218    ) -> Result<Option<NetworkDeviceSummary>> {
219        let devices = self.discover_devices(timeout_ms)?;
220
221        for device in devices {
222            if device.serial_number == serial_number {
223                return Ok(Some(device));
224            }
225        }
226
227        Ok(None)
228    }
229
230    fn create_discovery_packet(&self) -> Vec<u8> {
231        // PoKeys network discovery uses an empty UDP packet
232        // The presence of any UDP packet on port 20055 triggers the device to respond
233        Vec::new()
234    }
235
236    fn parse_discovery_response(
237        &self,
238        data: &[u8],
239        sender_addr: SocketAddr,
240    ) -> Option<NetworkDeviceSummary> {
241        // PoKeys devices respond with either 14 bytes (older devices) or 19 bytes (58 series)
242        if data.len() != 14 && data.len() != 19 {
243            return None;
244        }
245
246        let _sender_ip = match sender_addr.ip() {
247            IpAddr::V4(ipv4) => ipv4.octets(),
248            _ => return None,
249        };
250
251        if data.len() == 14 {
252            // Older device format (14 bytes)
253            // Byte 0: User ID
254            // Bytes 1-2: Serial number (16-bit, big-endian)
255            // Bytes 3-4: Firmware version (major, minor)
256            // Bytes 5-8: IP address (from device response, not sender)
257            // Byte 9: DHCP flag
258            // Bytes 10-13: Host IP address
259
260            let user_id = data[0];
261            let serial_number = ((data[1] as u32) << 8) | (data[2] as u32);
262
263            // Decode firmware version: v(1+[bits 4-7]).(bits [0-3])
264            let firmware_version_encoded = data[3];
265            let firmware_revision = data[4]; // This might be revision or minor version
266
267            let major_bits = (firmware_version_encoded >> 4) & 0x0F; // Extract bits 4-7
268            let minor_bits = firmware_version_encoded & 0x0F; // Extract bits 0-3
269            let decoded_major = 1 + major_bits;
270            let decoded_minor = minor_bits;
271
272            let device_ip = [data[5], data[6], data[7], data[8]];
273            let dhcp = data[9];
274            let host_ip = [data[10], data[11], data[12], data[13]];
275
276            Some(NetworkDeviceSummary {
277                serial_number,
278                ip_address: device_ip,
279                host_ip,
280                firmware_version_major: decoded_major,
281                firmware_version_minor: decoded_minor,
282                firmware_revision,
283                user_id,
284                dhcp,
285                hw_type: 0, // Not available in 14-byte format
286                use_udp: 1, // Assume UDP for older devices
287            })
288        } else {
289            // 58 series device format (19 bytes)
290            // Byte 0: User ID
291            // Bytes 1-2: (unused in serial parsing)
292            // Bytes 3-4: Firmware version (encoded major, revision/minor)
293            // Bytes 5-8: IP address (from device response, not sender)
294            // Byte 9: DHCP flag
295            // Bytes 10-13: Host IP address
296            // Bytes 14-17: Serial number (32-bit, little-endian)
297            // Byte 18: Hardware type
298
299            let user_id = data[0];
300
301            // Decode firmware version: v(1+[bits 4-7]).(bits [0-3])
302            let firmware_version_encoded = data[3];
303            let firmware_revision = data[4]; // This might be revision or minor version
304
305            let major_bits = (firmware_version_encoded >> 4) & 0x0F; // Extract bits 4-7
306            let minor_bits = firmware_version_encoded & 0x0F; // Extract bits 0-3
307            let decoded_major = 1 + major_bits;
308            let decoded_minor = minor_bits;
309
310            let device_ip = [data[5], data[6], data[7], data[8]];
311            let dhcp = data[9];
312            let host_ip = [data[10], data[11], data[12], data[13]];
313            let serial_number = ((data[17] as u32) << 24)
314                | ((data[16] as u32) << 16)
315                | ((data[15] as u32) << 8)
316                | (data[14] as u32);
317            let hw_type = data[18];
318
319            Some(NetworkDeviceSummary {
320                serial_number,
321                ip_address: device_ip,
322                host_ip,
323                firmware_version_major: decoded_major,
324                firmware_version_minor: decoded_minor,
325                firmware_revision,
326                user_id,
327                dhcp,
328                hw_type,
329                use_udp: 1, // Default to UDP, could be determined by device type
330            })
331        }
332    }
333}
334
335/// Network device configuration
336pub struct NetworkDeviceConfig {
337    pub device_info: NetworkDeviceInfo,
338}
339
340impl Default for NetworkDeviceConfig {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
346impl NetworkDeviceConfig {
347    pub fn new() -> Self {
348        Self {
349            device_info: NetworkDeviceInfo {
350                ip_address_current: [0, 0, 0, 0],
351                ip_address_setup: [0, 0, 0, 0],
352                subnet_mask: [255, 255, 255, 0],
353                gateway_ip: [0, 0, 0, 0],
354                tcp_timeout: 1000,
355                additional_network_options: 0xA0,
356                dhcp: 0,
357            },
358        }
359    }
360
361    /// Configure device IP address
362    pub fn set_ip_address(&mut self, ip: [u8; 4]) {
363        self.device_info.ip_address_setup = ip;
364    }
365
366    /// Configure subnet mask
367    pub fn set_subnet_mask(&mut self, mask: [u8; 4]) {
368        self.device_info.subnet_mask = mask;
369    }
370
371    /// Configure default gateway
372    pub fn set_default_gateway(&mut self, gateway: [u8; 4]) {
373        self.device_info.gateway_ip = gateway;
374    }
375
376    /// Enable/disable DHCP
377    pub fn set_dhcp(&mut self, enable: bool) {
378        self.device_info.dhcp = if enable { 1 } else { 0 };
379    }
380
381    /// Set TCP timeout
382    pub fn set_tcp_timeout(&mut self, timeout_ms: u16) {
383        self.device_info.tcp_timeout = timeout_ms;
384    }
385
386    /// Configure network options
387    pub fn set_network_options(
388        &mut self,
389        disable_discovery: bool,
390        disable_auto_config: bool,
391        disable_udp_config: bool,
392    ) {
393        let mut options = 0xA0u8; // Base value
394
395        if disable_discovery {
396            options |= 0x01;
397        }
398        if disable_auto_config {
399            options |= 0x02;
400        }
401        if disable_udp_config {
402            options |= 0x04;
403        }
404
405        self.device_info.additional_network_options = options;
406    }
407}
408
409/// Network utilities
410pub mod network_utils {
411    use super::*;
412
413    /// Convert IP address from bytes to string
414    pub fn ip_to_string(ip: [u8; 4]) -> String {
415        format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3])
416    }
417
418    /// Convert IP address from string to bytes
419    pub fn string_to_ip(ip_str: &str) -> Result<[u8; 4]> {
420        let parts: Vec<&str> = ip_str.split('.').collect();
421        if parts.len() != 4 {
422            return Err(PoKeysError::Parameter(
423                "Invalid IP address format".to_string(),
424            ));
425        }
426
427        let mut ip = [0u8; 4];
428        for (i, part) in parts.iter().enumerate() {
429            ip[i] = part
430                .parse::<u8>()
431                .map_err(|_| PoKeysError::Parameter("Invalid IP address octet".to_string()))?;
432        }
433
434        Ok(ip)
435    }
436
437    /// Check if IP address is in the same subnet
438    pub fn same_subnet(ip1: [u8; 4], ip2: [u8; 4], subnet_mask: [u8; 4]) -> bool {
439        for i in 0..4 {
440            if (ip1[i] & subnet_mask[i]) != (ip2[i] & subnet_mask[i]) {
441                return false;
442            }
443        }
444        true
445    }
446
447    /// Calculate network address
448    pub fn network_address(ip: [u8; 4], subnet_mask: [u8; 4]) -> [u8; 4] {
449        [
450            ip[0] & subnet_mask[0],
451            ip[1] & subnet_mask[1],
452            ip[2] & subnet_mask[2],
453            ip[3] & subnet_mask[3],
454        ]
455    }
456
457    /// Calculate broadcast address
458    pub fn broadcast_address(ip: [u8; 4], subnet_mask: [u8; 4]) -> [u8; 4] {
459        [
460            ip[0] | (!subnet_mask[0]),
461            ip[1] | (!subnet_mask[1]),
462            ip[2] | (!subnet_mask[2]),
463            ip[3] | (!subnet_mask[3]),
464        ]
465    }
466}
467
468// Convenience functions for network operations
469
470/// Create UDP connection to PoKeys device
471pub fn create_udp_connection(device: &NetworkDeviceSummary) -> Result<Box<dyn NetworkInterface>> {
472    let interface = UdpNetworkInterface::new(device.ip_address, 20055)?;
473    Ok(Box::new(interface))
474}
475
476/// Create TCP connection to PoKeys device
477pub fn create_tcp_connection(device: &NetworkDeviceSummary) -> Result<Box<dyn NetworkInterface>> {
478    let interface = TcpNetworkInterface::new(device.ip_address, 20055)?;
479    Ok(Box::new(interface))
480}
481
482/// Discover all PoKeys devices on network
483pub fn discover_all_devices(timeout_ms: u32) -> Result<Vec<NetworkDeviceSummary>> {
484    let discovery = NetworkDiscovery::new()?;
485    discovery.discover_devices(timeout_ms)
486}
487
488/// Find specific PoKeys device by serial number
489pub fn find_device_by_serial(
490    serial_number: u32,
491    timeout_ms: u32,
492) -> Result<Option<NetworkDeviceSummary>> {
493    let discovery = NetworkDiscovery::new()?;
494    discovery.search_device(serial_number, timeout_ms)
495}
496
#[cfg(test)]
mod tests {
    use super::network_utils::*;
    use super::*;

    /// An address survives a round trip through its dotted-decimal form.
    #[test]
    fn test_ip_string_conversion() {
        let octets = [192, 168, 1, 100];

        let text = ip_to_string(octets);
        assert_eq!(text, "192.168.1.100");
        assert_eq!(string_to_ip(&text).unwrap(), octets);
    }

    /// Too few parts, an out-of-range octet and non-numeric parts are all rejected.
    #[test]
    fn test_invalid_ip_string() {
        for malformed in ["192.168.1", "192.168.1.256", "not.an.ip.address"] {
            assert!(string_to_ip(malformed).is_err());
        }
    }

    /// Subnet membership, network address and broadcast address for a /24 mask.
    #[test]
    fn test_subnet_calculations() {
        let mask = [255, 255, 255, 0];
        let host = [192, 168, 1, 100];
        let neighbour = [192, 168, 1, 200];
        let outsider = [192, 168, 2, 100];

        assert!(same_subnet(host, neighbour, mask));
        assert!(!same_subnet(host, outsider, mask));

        assert_eq!(network_address(host, mask), [192, 168, 1, 0]);
        assert_eq!(broadcast_address(host, mask), [192, 168, 1, 255]);
    }

    /// Each setter writes through to the wrapped `device_info`.
    #[test]
    fn test_network_device_config() {
        let mut config = NetworkDeviceConfig::new();

        config.set_ip_address([192, 168, 1, 100]);
        assert_eq!(config.device_info.ip_address_setup, [192, 168, 1, 100]);

        for (enable, stored) in [(true, 1), (false, 0)] {
            config.set_dhcp(enable);
            assert_eq!(config.device_info.dhcp, stored);
        }

        config.set_tcp_timeout(2000);
        assert_eq!(config.device_info.tcp_timeout, 2000);
    }

    /// The disable flags map onto bits 0x01, 0x02 and 0x04.
    #[test]
    fn test_network_options() {
        let mut config = NetworkDeviceConfig::new();

        config.set_network_options(true, false, false);
        let options = config.device_info.additional_network_options;
        assert_eq!(options & 0x01, 0x01);
        assert_eq!(options & 0x02, 0x00);

        config.set_network_options(false, true, true);
        let options = config.device_info.additional_network_options;
        assert_eq!(options & 0x01, 0x00);
        assert_eq!(options & 0x02, 0x02);
        assert_eq!(options & 0x04, 0x04);
    }

    #[test]
    fn test_discovery_packet_format() {
        let discovery = NetworkDiscovery::new().unwrap();

        // PoKeys discovery packet should be empty (0 bytes)
        let packet = discovery.create_discovery_packet();
        assert!(packet.is_empty(), "PoKeys discovery packet must be empty");
    }

    #[test]
    fn test_discovery_response_parsing_14_bytes() {
        let discovery = NetworkDiscovery::new().unwrap();
        let sender_addr = SocketAddr::from(([192, 168, 1, 100], 20055));

        // Simulate 14-byte response from older device
        let response = [
            0x01, // User ID
            0x12, 0x34, // Serial number (big-endian): 0x1234 = 4660
            0x12, // Firmware version byte: major = 1 + 0x1 = 2, minor = 0x2
            0x05, // Firmware revision
            192, 168, 1, 100,  // Device IP: 192.168.1.100
            0x01, // DHCP enabled
            192, 168, 1, 1, // Host IP: 192.168.1.1
        ];

        let device = discovery
            .parse_discovery_response(&response, sender_addr)
            .unwrap();

        assert_eq!(device.serial_number, 4660);
        assert_eq!(device.firmware_version_major, 2); // 1 + (0x12 >> 4)
        assert_eq!(device.firmware_version_minor, 2); // 0x12 & 0x0F
        assert_eq!(device.firmware_revision, 5); // data[4]
        assert_eq!(device.ip_address, [192, 168, 1, 100]);
        assert_eq!(device.dhcp, 1);
        assert_eq!(device.host_ip, [192, 168, 1, 1]);
        assert_eq!(device.hw_type, 0); // Not available in 14-byte format
    }

    #[test]
    fn test_discovery_response_parsing_19_bytes() {
        let discovery = NetworkDiscovery::new().unwrap();
        let sender_addr = SocketAddr::from(([192, 168, 1, 101], 20055));

        // Simulate 19-byte response from 58 series device
        let response = [
            0x02, // User ID
            0x00, 0x00, // Unused bytes
            0x21, // Firmware version byte: major = 1 + 0x2 = 3, minor = 0x1
            0x01, // Firmware revision
            192, 168, 1, 101,  // Device IP: 192.168.1.101
            0x00, // DHCP disabled
            192, 168, 1, 1, // Host IP: 192.168.1.1
            0x78, 0x56, 0x34, 0x12, // Serial number (little-endian): 0x12345678
            0x58, // Hardware type: 58 series
        ];

        let device = discovery
            .parse_discovery_response(&response, sender_addr)
            .unwrap();

        assert_eq!(device.serial_number, 0x12345678);
        assert_eq!(device.firmware_version_major, 3); // 1 + (0x21 >> 4)
        assert_eq!(device.firmware_version_minor, 1); // 0x21 & 0x0F
        assert_eq!(device.firmware_revision, 1); // data[4]
        assert_eq!(device.ip_address, [192, 168, 1, 101]);
        assert_eq!(device.dhcp, 0);
        assert_eq!(device.host_ip, [192, 168, 1, 1]);
        assert_eq!(device.hw_type, 0x58);
    }

    /// Replies that are neither 14 nor 19 bytes long are ignored.
    #[test]
    fn test_discovery_response_invalid_length() {
        let discovery = NetworkDiscovery::new().unwrap();
        let sender_addr = SocketAddr::from(([192, 168, 1, 100], 20055));

        let too_short = [0x01, 0x02, 0x03];
        let too_long = [0u8; 25];

        assert!(
            discovery
                .parse_discovery_response(&too_short, sender_addr)
                .is_none()
        );
        assert!(
            discovery
                .parse_discovery_response(&too_long, sender_addr)
                .is_none()
        );
    }
}
656}