// auto_discovery/types.rs

//! Type definitions for the auto-discovery library.
2
3use crate::service::ServiceInfo;
4use crate::error::{DiscoveryError, Result};
5use serde::{Deserialize, Serialize};
6use std::{
7    collections::HashMap,
8    fmt,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr},
10    str::FromStr,
11};
12
13/// Represents a service type for discovery
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct ServiceType {
16    /// The service string without protocol (e.g., "_http", "_myservice")
17    service_name: String,
18    /// The protocol string (e.g., "_tcp", "_udp")
19    protocol: String,
20    /// Optional domain for the service
21    domain: Option<String>,
22}
23
24impl ServiceType {
25    /// Create a new service type with default TCP protocol
26    pub fn new<S: Into<String>>(service: S) -> Result<Self> {
27        let service_type_str = service.into();
28
29        if service_type_str.is_empty() {
30            return Err(DiscoveryError::invalid_service("Service type cannot be empty"));
31        }
32
33        // Handle UPnP URN format (urn:schemas-upnp-org:service:ContentDirectory:1)
34        if service_type_str.starts_with("urn:") {
35            return Ok(ServiceType {
36                service_name: service_type_str.clone(),
37                protocol: "".to_string(), // UPnP doesn't use traditional protocols
38                domain: None,
39            });
40        }
41
42        // Parse service type like "_http._tcp.local" or "_http._tcp"
43        let parts: Vec<&str> = service_type_str.split('.').collect();
44        
45        if parts.len() < 2 {
46            return Err(DiscoveryError::invalid_service(
47                "Service type must contain protocol (e.g., '._tcp')",
48            ));
49        }
50
51        // Extract service name (first part)
52        let service_name = parts[0].to_string();
53        
54        // Extract protocol (second part, should start with _)
55        let protocol_part = parts[1];
56        if !protocol_part.starts_with('_') {
57            return Err(DiscoveryError::invalid_service(
58                "Service type must contain protocol (e.g., '._tcp')",
59            ));
60        }
61        let protocol = format!(".{protocol_part}");
62        
63        // Extract domain if present (third part and beyond)
64        let domain = if parts.len() > 2 {
65            Some(parts[2..].join("."))
66        } else {
67            None
68        };
69
70        // Ensure service has leading underscore
71        let final_service_name = if service_name.starts_with('_') {
72            service_name
73        } else {
74            format!("_{service_name}")
75        };
76
77        Ok(ServiceType {
78            service_name: final_service_name,
79            protocol,
80            domain,
81        })
82    }
83
84    /// Create a new service type with specified protocol
85    pub fn with_protocol<S1: Into<String>, S2: Into<String>>(service: S1, protocol: S2) -> Result<Self> {
86        let mut protocol_str = protocol.into();
87        if &protocol_str[0..1] != "_" {
88            protocol_str = format!("_{protocol_str}");
89        }
90
91        Ok(ServiceType {
92            service_name: service.into(),
93            protocol: protocol_str,
94            domain: None,
95        })
96    }
97
98    /// Create a new service type with specified domain
99    pub fn with_domain<S: Into<String>>(service: S, domain: S) -> Result<Self> {
100        Ok(ServiceType {
101            service_name: service.into(),
102            protocol: "_tcp".to_string(),
103            domain: Some(domain.into()),
104        })
105    }
106
107    /// Get the service string
108    pub fn service_name(&self) -> &str {
109        &self.service_name
110    }
111
112    /// Get the protocol string
113    pub fn protocol(&self) -> &str {
114        &self.protocol
115    }
116
117    /// Get the domain if present
118    pub fn domain(&self) -> Option<&str> {
119        self.domain.as_deref()
120    }
121
122    /// Convert to a fully qualified service string
123    pub fn full_name(&self) -> String {
124        // For UPnP URNs, return the service name as-is since it's already complete
125        if self.service_name.starts_with("urn:") {
126            return self.service_name.clone();
127        }
128        
129        match &self.domain {
130            None => format!("{}_{}", self.service_name, self.protocol),
131            Some(domain) => format!("{}_{}.{}", self.service_name, self.protocol, domain),
132        }
133    }
134
135    /// Check if the service type is valid
136    pub fn is_valid(&self) -> bool {
137        !self.service_name.is_empty() && !self.protocol.is_empty()
138    }
139}
140
141impl FromStr for ServiceType {
142    type Err = DiscoveryError;
143
144    fn from_str(s: &str) -> Result<Self> {
145        ServiceType::new(s)
146    }
147}
148
149impl fmt::Display for ServiceType {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        if let Some(domain) = &self.domain {
152            write!(
153                f,
154                "{}{}.{}",
155                self.service_name, self.protocol, domain
156            )
157        } else {
158            write!(f, "{}{}", self.service_name, self.protocol)
159        }
160    }
161}
162
163impl From<ServiceType> for String {
164    fn from(service_type: ServiceType) -> Self {
165        service_type.to_string()
166    }
167}
168
169/// Protocol type for service discovery
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
171pub enum ProtocolType {
172    /// Multicast DNS
173    #[default]
174    Mdns,
175    /// DNS Service Discovery
176    DnsSd,
177    /// Universal Plug and Play
178    Upnp,
179}
180
181impl fmt::Display for ProtocolType {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        match self {
184            ProtocolType::Mdns => write!(f, "mDNS"),
185            ProtocolType::DnsSd => write!(f, "DNS-SD"),
186            ProtocolType::Upnp => write!(f, "UPnP"),
187        }
188    }
189}
190
191/// Network interface information
192#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
193pub struct NetworkInterface {
194    /// Interface name
195    pub name: String,
196    /// IPv4 addresses
197    pub ipv4_addresses: Vec<Ipv4Addr>,
198    /// IPv6 addresses
199    pub ipv6_addresses: Vec<Ipv6Addr>,
200    /// Whether the interface is active
201    pub is_up: bool,
202    /// Whether the interface supports multicast
203    pub supports_multicast: bool,
204}
205
206impl NetworkInterface {
207    /// Create a new network interface
208    pub fn new<S: Into<String>>(name: S) -> Self {
209        Self {
210            name: name.into(),
211            ipv4_addresses: Vec::new(),
212            ipv6_addresses: Vec::new(),
213            is_up: false,
214            supports_multicast: false,
215        }
216    }
217
218    /// Add an IPv4 address
219    pub fn with_ipv4(mut self, addr: Ipv4Addr) -> Self {
220        self.ipv4_addresses.push(addr);
221        self
222    }
223
224    /// Add an IPv6 address
225    pub fn with_ipv6(mut self, addr: Ipv6Addr) -> Self {
226        self.ipv6_addresses.push(addr);
227        self
228    }
229
230    /// Set interface status
231    pub fn with_status(mut self, is_up: bool, supports_multicast: bool) -> Self {
232        self.is_up = is_up;
233        self.supports_multicast = supports_multicast;
234        self
235    }
236
237    /// Get all IP addresses
238    pub fn all_addresses(&self) -> Vec<IpAddr> {
239        let mut addresses = Vec::new();
240        addresses.extend(self.ipv4_addresses.iter().map(|&addr| IpAddr::V4(addr)));
241        addresses.extend(self.ipv6_addresses.iter().map(|&addr| IpAddr::V6(addr)));
242        addresses
243    }
244}
245
/// Free-form service metadata: attribute name mapped to attribute value.
pub type ServiceAttributes = HashMap<String, String>;
248
249/// Filter for discovered services
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct DiscoveryFilter {
252    /// Service type filters
253    pub service_type_filters: Vec<ServiceType>,
254    /// Protocol type filters
255    pub protocol_filters: Vec<ProtocolType>,
256    /// Custom attribute filter patterns (key-value regex patterns)
257    pub attribute_patterns: Vec<(String, String)>,
258}
259
260impl DiscoveryFilter {
261    /// Create a new empty filter
262    pub fn new() -> Self {
263        Self {
264            service_type_filters: Vec::new(),
265            protocol_filters: Vec::new(),
266            attribute_patterns: Vec::new(),
267        }
268    }
269
270    /// Add a service type filter
271    pub fn with_service_type(mut self, service_type: ServiceType) -> Self {
272        self.service_type_filters.push(service_type);
273        self
274    }
275
276    /// Add a protocol filter
277    pub fn with_protocol(mut self, protocol: ProtocolType) -> Self {
278        self.protocol_filters.push(protocol);
279        self
280    }
281
282    /// Add an attribute pattern filter (key regex, value regex)
283    pub fn with_attribute_pattern(mut self, key_pattern: String, value_pattern: String) -> Self {
284        self.attribute_patterns.push((key_pattern, value_pattern));
285        self
286    }
287
288    /// Check if a service matches this filter
289    pub fn matches(&self, service: &ServiceInfo) -> bool {
290        // Check service type filters
291        if !self.service_type_filters.is_empty() 
292            && !self.service_type_filters.contains(&service.service_type) {
293            return false;
294        }
295
296        // Check protocol filters
297        if !self.protocol_filters.is_empty() 
298            && !self.protocol_filters.contains(&service.protocol_type) {
299            return false;
300        }
301
302        // Check attribute pattern filters
303        for (key_pattern, value_pattern) in &self.attribute_patterns {
304            let mut matches = false;
305            for (key, value) in &service.attributes {
306                // Simple string matching for now (could be enhanced with regex)
307                if key.contains(key_pattern) && value.contains(value_pattern) {
308                    matches = true;
309                    break;
310                }
311            }
312            if !matches {
313                return false;
314            }
315        }
316
317        true
318    }
319}
320
321impl Default for DiscoveryFilter {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A DNS-SD type without a domain parses into its labels and round-trips via Display.
    #[test]
    fn test_service_type() -> Result<()> {
        let parsed = ServiceType::new("_http._tcp")?;
        assert_eq!(parsed.service_name, "_http");
        assert_eq!(parsed.protocol, "._tcp");
        assert!(parsed.domain.is_none());
        assert_eq!(format!("{parsed}"), "_http._tcp");
        Ok(())
    }

    /// A trailing domain label is captured separately and kept in the Display output.
    #[test]
    fn test_service_type_with_domain() -> Result<()> {
        let parsed = ServiceType::new("_http._tcp.local")?;
        assert_eq!(parsed.service_name, "_http");
        assert_eq!(parsed.protocol, "._tcp");
        assert_eq!(parsed.domain.as_deref(), Some("local"));
        assert_eq!(format!("{parsed}"), "_http._tcp.local");
        Ok(())
    }

    /// Empty input and inputs without a protocol label are rejected.
    #[test]
    fn test_invalid_service_type() {
        for input in ["", "invalid", "_http"] {
            assert!(
                ServiceType::new(input).is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    /// A service-type filter accepts a service advertised with the same type.
    #[test]
    fn test_discovery_filter() -> Result<()> {
        use crate::service::ServiceInfo;

        let wanted: ServiceType = "_http._tcp".parse()?;
        let filter = DiscoveryFilter::new().with_service_type(wanted);

        let service = ServiceInfo::new(
            "Test Service",
            "_http._tcp",
            8080,
            Some(vec![("version", "1.0")]),
        )?;

        assert!(filter.matches(&service));
        Ok(())
    }

    /// mDNS is the default discovery protocol.
    #[test]
    fn test_protocol_type_default() {
        assert_eq!(ProtocolType::Mdns, ProtocolType::default());
    }
}