auto_discovery/
utils.rs

1//! Utility functions for the auto-discovery library
2
3use crate::{
4    error::{DiscoveryError, Result},
5    types::NetworkInterface,
6};
7use std::{
8    net::{IpAddr, Ipv4Addr, Ipv6Addr},
9    time::{Duration, SystemTime, UNIX_EPOCH},
10};
11use tracing::{debug, warn};
12
13/// Network utility functions
14pub mod network {
15    use super::*;
16
17    /// Get all available network interfaces on the system
18    pub fn get_network_interfaces() -> Result<Vec<NetworkInterface>> {
19        debug!("Enumerating network interfaces");
20
21        // This is a placeholder implementation
22        // In a real implementation, you would use platform-specific APIs
23        // or a cross-platform crate like `pnet` or `nix`
24
25        let mut interfaces = Vec::new();
26
27        // Create mock interfaces for demonstration
28        let localhost = NetworkInterface::new("lo")
29            .with_ipv4(Ipv4Addr::LOCALHOST)
30            .with_ipv6(Ipv6Addr::LOCALHOST)
31            .with_status(true, false);
32        interfaces.push(localhost);
33
34        #[cfg(target_os = "windows")]
35        {
36            let ethernet = NetworkInterface::new("Ethernet")
37                .with_ipv4("192.168.1.100".parse().unwrap())
38                .with_ipv6("fe80::1".parse().unwrap())
39                .with_status(true, true);
40            interfaces.push(ethernet);
41
42            let wifi = NetworkInterface::new("Wi-Fi")
43                .with_ipv4("192.168.1.101".parse().unwrap())
44                .with_ipv6("fe80::2".parse().unwrap())
45                .with_status(true, true);
46            interfaces.push(wifi);
47        }
48
49        #[cfg(target_os = "linux")]
50        {
51            let eth0 = NetworkInterface::new("eth0")
52                .with_ipv4("192.168.1.100".parse().unwrap())
53                .with_ipv6("fe80::1".parse().unwrap())
54                .with_status(true, true);
55            interfaces.push(eth0);
56
57            let wlan0 = NetworkInterface::new("wlan0")
58                .with_ipv4("192.168.1.101".parse().unwrap())
59                .with_ipv6("fe80::2".parse().unwrap())
60                .with_status(true, true);
61            interfaces.push(wlan0);
62        }
63
64        #[cfg(target_os = "macos")]
65        {
66            let en0 = NetworkInterface::new("en0")
67                .with_ipv4("192.168.1.100".parse().unwrap())
68                .with_ipv6("fe80::1".parse().unwrap())
69                .with_status(true, true);
70            interfaces.push(en0);
71
72            let en1 = NetworkInterface::new("en1")
73                .with_ipv4("192.168.1.101".parse().unwrap())
74                .with_ipv6("fe80::2".parse().unwrap())
75                .with_status(true, true);
76            interfaces.push(en1);
77        }
78
79        debug!("Found {} network interfaces", interfaces.len());
80        Ok(interfaces)
81    }
82
83    /// Get interfaces that support multicast
84    pub fn get_multicast_interfaces() -> Result<Vec<NetworkInterface>> {
85        let all_interfaces = get_network_interfaces()?;
86        let multicast_interfaces: Vec<NetworkInterface> = all_interfaces
87            .into_iter()
88            .filter(|iface| iface.is_up && iface.supports_multicast)
89            .collect();
90
91        debug!("Found {} multicast-capable interfaces", multicast_interfaces.len());
92        Ok(multicast_interfaces)
93    }
94
95    /// Check if an IP address is in a private range
96    pub fn is_private_ip(ip: &IpAddr) -> bool {
97        match ip {
98            IpAddr::V4(ipv4) => {
99                let octets = ipv4.octets();
100                // 10.0.0.0/8
101                octets[0] == 10 ||
102                // 172.16.0.0/12
103                (octets[0] == 172 && (octets[1] >= 16 && octets[1] <= 31)) ||
104                // 192.168.0.0/16
105                (octets[0] == 192 && octets[1] == 168) ||
106                // 127.0.0.0/8 (loopback)
107                octets[0] == 127
108            }
109            IpAddr::V6(ipv6) => {
110                // fe80::/10 (link-local)
111                let segments = ipv6.segments();
112                (segments[0] & 0xffc0) == 0xfe80 ||
113                // ::1 (loopback)
114                *ipv6 == Ipv6Addr::LOCALHOST ||
115                // fc00::/7 (unique local)
116                (segments[0] & 0xfe00) == 0xfc00
117            }
118        }
119    }
120
121    /// Check if an IP address is a loopback address
122    pub fn is_loopback_ip(ip: &IpAddr) -> bool {
123        match ip {
124            IpAddr::V4(ipv4) => ipv4.is_loopback(),
125            IpAddr::V6(ipv6) => ipv6.is_loopback(),
126        }
127    }
128
129    /// Get the local IP addresses for a given interface
130    pub fn get_interface_addresses(interface_name: &str) -> Result<Vec<IpAddr>> {
131        let interfaces = get_network_interfaces()?;
132        
133        for interface in interfaces {
134            if interface.name == interface_name {
135                return Ok(interface.all_addresses());
136            }
137        }
138
139        Err(DiscoveryError::other(format!(
140            "Interface '{interface_name}' not found"
141        )))
142    }
143
144    /// Check if a port is likely to be available for binding
145    pub async fn is_port_available(port: u16) -> bool {
146        use tokio::net::TcpListener;
147        
148        (TcpListener::bind(("127.0.0.1", port)).await).is_ok()
149    }
150
151    /// Find an available port in a given range
152    pub async fn find_available_port(start_port: u16, end_port: u16) -> Option<u16> {
153        for port in start_port..=end_port {
154            if is_port_available(port).await {
155                return Some(port);
156            }
157        }
158        None
159    }
160}
161
/// Time utility functions
pub mod time {
    use super::*;

    /// Time elapsed since the Unix epoch, or [`Duration::ZERO`] if the system clock reads
    /// earlier than the epoch.
    fn since_epoch() -> Duration {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
    }

    /// Get current timestamp as whole seconds since the Unix epoch.
    ///
    /// Returns `0` if the system clock is set before the epoch.
    pub fn current_timestamp() -> u64 {
        since_epoch().as_secs()
    }

    /// Get current timestamp as whole milliseconds since the Unix epoch.
    ///
    /// Returns `0` if the system clock is set before the epoch. Saturates at `u64::MAX`
    /// instead of silently truncating.
    pub fn current_timestamp_millis() -> u64 {
        // `as_millis` yields a u128; clamp before narrowing so the cast can never wrap.
        since_epoch().as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Convert a duration to a short human-readable string.
    ///
    /// - at least one hour: `"{h}h {m}m {s}s"` (sub-second part dropped)
    /// - at least one minute: `"{m}m {s}s"`
    /// - at least one second: `"{s}.{ms:03}s"`, e.g. `"5.000s"`
    /// - otherwise: `"{ms}ms"`
    ///
    /// Each component is truncated, not rounded.
    pub fn duration_to_string(duration: Duration) -> String {
        let total_secs = duration.as_secs();
        let hours = total_secs / 3600;
        let minutes = (total_secs % 3600) / 60;
        let seconds = total_secs % 60;
        let millis = duration.subsec_millis();

        if hours > 0 {
            format!("{hours}h {minutes}m {seconds}s")
        } else if minutes > 0 {
            format!("{minutes}m {seconds}s")
        } else if seconds > 0 {
            format!("{seconds}.{millis:03}s")
        } else {
            format!("{}ms", duration.as_millis())
        }
    }

    /// Check whether at least `duration` has passed since `since`.
    ///
    /// If `since` lies in the future, `SystemTime::elapsed` fails and the elapsed time is
    /// treated as zero. In that case the result is `true` only for a zero `duration`.
    pub fn has_elapsed(since: SystemTime, duration: Duration) -> bool {
        since.elapsed().unwrap_or(Duration::ZERO) >= duration
    }
}
206
207/// String utility functions
208pub mod string {
209    use super::*;
210    use std::collections::HashMap;
211
212    /// Sanitize a service name for use in network protocols
213    pub fn sanitize_service_name(name: &str) -> String {
214        name.chars()
215            .map(|c| {
216                if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' {
217                    c
218                } else {
219                    '_'
220                }
221            })
222            .collect()
223    }
224
225    /// Validate a service type string
226    pub fn validate_service_type(service_type: &str) -> Result<()> {
227        if service_type.is_empty() {
228            return Err(DiscoveryError::invalid_service("Service type cannot be empty"));
229        }
230
231        if !service_type.starts_with('_') {
232            return Err(DiscoveryError::invalid_service(
233                "Service type must start with underscore",
234            ));
235        }
236
237        if !service_type.contains("._tcp") && !service_type.contains("._udp") {
238            return Err(DiscoveryError::invalid_service(
239                "Service type must contain ._tcp or ._udp",
240            ));
241        }
242
243        Ok(())
244    }
245
246    /// Parse key-value pairs from a string (e.g., TXT record format)
247    pub fn parse_txt_record(txt_data: &str) -> HashMap<String, String> {
248        let mut attributes = HashMap::new();
249
250        for pair in txt_data.split(';') {
251            if let Some(eq_pos) = pair.find('=') {
252                let key = pair[..eq_pos].trim().to_string();
253                let value = pair[eq_pos + 1..].trim().to_string();
254                attributes.insert(key, value);
255            } else {
256                // Key without value
257                attributes.insert(pair.trim().to_string(), String::new());
258            }
259        }
260
261        attributes
262    }
263
264    /// Format key-value pairs as TXT record string
265    pub fn format_txt_record(attributes: &HashMap<String, String>) -> String {
266        attributes
267            .iter()
268            .map(|(k, v)| {
269                if v.is_empty() {
270                    k.clone()
271                } else {
272                    format!("{k}={v}")
273                }
274            })
275            .collect::<Vec<_>>()
276            .join(";")
277    }
278}
279
280/// Validation utility functions
281pub mod validation {
282    use super::*;
283
284    /// Validate that a port number is in a valid range
285    pub fn validate_port(port: u16) -> Result<()> {
286        if port == 0 {
287            return Err(DiscoveryError::invalid_service("Port cannot be zero"));
288        }
289        Ok(())
290    }
291
292    /// Validate that a timeout is reasonable
293    pub fn validate_timeout(timeout: Duration) -> Result<()> {
294        if timeout.is_zero() {
295            return Err(DiscoveryError::invalid_service("Timeout cannot be zero"));
296        }
297
298        if timeout > Duration::from_secs(300) {
299            warn!("Timeout is very long: {:?}", timeout);
300        }
301
302        Ok(())
303    }
304
305    /// Validate an IP address for service discovery
306    pub fn validate_ip_address(ip: &IpAddr) -> Result<()> {
307        match ip {
308            IpAddr::V4(ipv4) => {
309                if ipv4.is_unspecified() {
310                    return Err(DiscoveryError::invalid_service(
311                        "IPv4 address cannot be unspecified (0.0.0.0)",
312                    ));
313                }
314                if ipv4.is_broadcast() {
315                    return Err(DiscoveryError::invalid_service(
316                        "IPv4 address cannot be broadcast (255.255.255.255)",
317                    ));
318                }
319            }
320            IpAddr::V6(ipv6) => {
321                if ipv6.is_unspecified() {
322                    return Err(DiscoveryError::invalid_service(
323                        "IPv6 address cannot be unspecified (::)",
324                    ));
325                }
326            }
327        }
328        Ok(())
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_network_interfaces() {
        let result = network::get_network_interfaces();
        assert!(result.is_ok());
        let interfaces = result.unwrap();
        assert!(!interfaces.is_empty());

        // Should always have at least loopback
        assert!(interfaces.iter().any(|i| i.name == "lo"));
    }

    #[test]
    fn test_get_interface_addresses() {
        assert!(network::get_interface_addresses("lo").is_ok());
        assert!(network::get_interface_addresses("no-such-interface").is_err());
    }

    #[test]
    fn test_is_private_ip() {
        assert!(network::is_private_ip(&"192.168.1.1".parse().unwrap()));
        assert!(network::is_private_ip(&"10.0.0.1".parse().unwrap()));
        assert!(network::is_private_ip(&"172.16.0.1".parse().unwrap()));
        assert!(network::is_private_ip(&"127.0.0.1".parse().unwrap()));
        assert!(!network::is_private_ip(&"8.8.8.8".parse().unwrap()));
        assert!(!network::is_private_ip(&"172.32.0.1".parse().unwrap()));

        assert!(network::is_private_ip(&"::1".parse().unwrap()));
        assert!(network::is_private_ip(&"fe80::1".parse().unwrap()));
        assert!(network::is_private_ip(&"fd12:3456::1".parse().unwrap()));
        assert!(!network::is_private_ip(&"2001:4860:4860::8888".parse().unwrap()));
    }

    #[test]
    fn test_is_loopback_ip() {
        assert!(network::is_loopback_ip(&"127.0.0.1".parse().unwrap()));
        assert!(network::is_loopback_ip(&"::1".parse().unwrap()));
        assert!(!network::is_loopback_ip(&"10.0.0.1".parse().unwrap()));
    }

    #[test]
    fn test_current_timestamp() {
        let timestamp = time::current_timestamp();
        assert!(timestamp > 0);

        // `>=`, not `>`: the millisecond reading can land exactly on the second boundary.
        let timestamp_millis = time::current_timestamp_millis();
        assert!(timestamp_millis >= timestamp * 1000);
    }

    #[test]
    fn test_duration_to_string() {
        assert_eq!(time::duration_to_string(Duration::from_millis(500)), "500ms");
        assert_eq!(time::duration_to_string(Duration::from_secs(5)), "5.000s");
        assert_eq!(time::duration_to_string(Duration::from_secs(65)), "1m 5s");
        assert_eq!(time::duration_to_string(Duration::from_secs(3665)), "1h 1m 5s");
    }

    #[test]
    fn test_has_elapsed() {
        let past = SystemTime::now() - Duration::from_secs(10);
        assert!(time::has_elapsed(past, Duration::from_secs(5)));
        assert!(!time::has_elapsed(SystemTime::now(), Duration::from_secs(60)));
    }

    #[test]
    fn test_sanitize_service_name() {
        assert_eq!(string::sanitize_service_name("My Service!"), "My_Service_");
        assert_eq!(string::sanitize_service_name("test-service_1.0"), "test-service_1.0");
    }

    #[test]
    fn test_validate_service_type() {
        assert!(string::validate_service_type("_http._tcp").is_ok());
        assert!(string::validate_service_type("_http._tcp.local.").is_ok());
        assert!(string::validate_service_type("_myservice._udp").is_ok());
        assert!(string::validate_service_type("http._tcp").is_err());
        assert!(string::validate_service_type("_http").is_err());
        assert!(string::validate_service_type("").is_err());
    }

    #[test]
    fn test_parse_txt_record() {
        let txt = "version=1.0;protocol=HTTP;enabled";
        let attrs = string::parse_txt_record(txt);

        assert_eq!(attrs.get("version"), Some(&"1.0".to_string()));
        assert_eq!(attrs.get("protocol"), Some(&"HTTP".to_string()));
        assert_eq!(attrs.get("enabled"), Some(&"".to_string()));
    }

    #[test]
    fn test_format_txt_record() {
        let mut attrs = std::collections::HashMap::new();
        attrs.insert("version".to_string(), "1.0".to_string());
        attrs.insert("enabled".to_string(), "".to_string());

        let txt = string::format_txt_record(&attrs);
        assert!(txt.contains("version=1.0"));
        assert!(txt.contains("enabled"));
    }

    #[test]
    fn test_txt_record_round_trip() {
        let mut attrs = std::collections::HashMap::new();
        attrs.insert("version".to_string(), "1.0".to_string());
        attrs.insert("enabled".to_string(), "".to_string());

        let txt = string::format_txt_record(&attrs);
        assert_eq!(string::parse_txt_record(&txt), attrs);
    }

    #[tokio::test]
    async fn test_port_availability() {
        // Port 0 always binds: the OS substitutes an available ephemeral port.
        let available = network::is_port_available(0).await;
        assert!(available);
    }

    #[test]
    fn test_validate_port() {
        assert!(validation::validate_port(8080).is_ok());
        assert!(validation::validate_port(0).is_err());
    }

    #[test]
    fn test_validate_timeout() {
        assert!(validation::validate_timeout(Duration::from_secs(5)).is_ok());
        assert!(validation::validate_timeout(Duration::ZERO).is_err());
    }

    #[test]
    fn test_validate_ip_address() {
        assert!(validation::validate_ip_address(&"192.168.1.1".parse().unwrap()).is_ok());
        assert!(validation::validate_ip_address(&"0.0.0.0".parse().unwrap()).is_err());
        assert!(validation::validate_ip_address(&"255.255.255.255".parse().unwrap()).is_err());
        assert!(validation::validate_ip_address(&"::1".parse().unwrap()).is_ok());
        assert!(validation::validate_ip_address(&"::".parse().unwrap()).is_err());
    }
}