1use crate::service::ServiceInfo;
4use crate::error::{DiscoveryError, Result};
5use serde::{Deserialize, Serialize};
6use std::{
7 collections::HashMap,
8 fmt,
9 net::{IpAddr, Ipv4Addr, Ipv6Addr},
10 str::FromStr,
11};
12
/// A service type such as `_http._tcp` or `_http._tcp.local`, or a verbatim `urn:` type.
///
/// Construct it with [`ServiceType::new`] (also used by `FromStr`) or one of the
/// `with_*` constructors.
// NOTE(review): the derived `Deserialize` fills these fields directly and does not run the
// checks in `ServiceType::new`, so deserialized values are not validated.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ServiceType {
    /// Service label (e.g. `_http`); `ServiceType::new` adds a leading `_` if missing.
    /// Holds the whole string for `urn:` types.
    service_name: String,
    /// Protocol label; `ServiceType::new` stores it with a leading dot (e.g. `._tcp`).
    /// Empty for `urn:` types.
    protocol: String,
    /// Everything after the protocol label (e.g. `local`), if present.
    domain: Option<String>,
}
23
24impl ServiceType {
25 pub fn new<S: Into<String>>(service: S) -> Result<Self> {
27 let service_type_str = service.into();
28
29 if service_type_str.is_empty() {
30 return Err(DiscoveryError::invalid_service("Service type cannot be empty"));
31 }
32
33 if service_type_str.starts_with("urn:") {
35 return Ok(ServiceType {
36 service_name: service_type_str.clone(),
37 protocol: "".to_string(), domain: None,
39 });
40 }
41
42 let parts: Vec<&str> = service_type_str.split('.').collect();
44
45 if parts.len() < 2 {
46 return Err(DiscoveryError::invalid_service(
47 "Service type must contain protocol (e.g., '._tcp')",
48 ));
49 }
50
51 let service_name = parts[0].to_string();
53
54 let protocol_part = parts[1];
56 if !protocol_part.starts_with('_') {
57 return Err(DiscoveryError::invalid_service(
58 "Service type must contain protocol (e.g., '._tcp')",
59 ));
60 }
61 let protocol = format!(".{protocol_part}");
62
63 let domain = if parts.len() > 2 {
65 Some(parts[2..].join("."))
66 } else {
67 None
68 };
69
70 let final_service_name = if service_name.starts_with('_') {
72 service_name
73 } else {
74 format!("_{service_name}")
75 };
76
77 Ok(ServiceType {
78 service_name: final_service_name,
79 protocol,
80 domain,
81 })
82 }
83
84 pub fn with_protocol<S1: Into<String>, S2: Into<String>>(service: S1, protocol: S2) -> Result<Self> {
86 let mut protocol_str = protocol.into();
87 if &protocol_str[0..1] != "_" {
88 protocol_str = format!("_{protocol_str}");
89 }
90
91 Ok(ServiceType {
92 service_name: service.into(),
93 protocol: protocol_str,
94 domain: None,
95 })
96 }
97
98 pub fn with_domain<S: Into<String>>(service: S, domain: S) -> Result<Self> {
100 Ok(ServiceType {
101 service_name: service.into(),
102 protocol: "_tcp".to_string(),
103 domain: Some(domain.into()),
104 })
105 }
106
107 pub fn service_name(&self) -> &str {
109 &self.service_name
110 }
111
112 pub fn protocol(&self) -> &str {
114 &self.protocol
115 }
116
117 pub fn domain(&self) -> Option<&str> {
119 self.domain.as_deref()
120 }
121
122 pub fn full_name(&self) -> String {
124 if self.service_name.starts_with("urn:") {
126 return self.service_name.clone();
127 }
128
129 match &self.domain {
130 None => format!("{}_{}", self.service_name, self.protocol),
131 Some(domain) => format!("{}_{}.{}", self.service_name, self.protocol, domain),
132 }
133 }
134
135 pub fn is_valid(&self) -> bool {
137 !self.service_name.is_empty() && !self.protocol.is_empty()
138 }
139}
140
141impl FromStr for ServiceType {
142 type Err = DiscoveryError;
143
144 fn from_str(s: &str) -> Result<Self> {
145 ServiceType::new(s)
146 }
147}
148
149impl fmt::Display for ServiceType {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 if let Some(domain) = &self.domain {
152 write!(
153 f,
154 "{}{}.{}",
155 self.service_name, self.protocol, domain
156 )
157 } else {
158 write!(f, "{}{}", self.service_name, self.protocol)
159 }
160 }
161}
162
163impl From<ServiceType> for String {
164 fn from(service_type: ServiceType) -> Self {
165 service_type.to_string()
166 }
167}
168
/// Service discovery protocol. Defaults to [`ProtocolType::Mdns`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ProtocolType {
    /// Multicast DNS; displayed as `mDNS`.
    #[default]
    Mdns,
    /// DNS Service Discovery; displayed as `DNS-SD`.
    DnsSd,
    /// Universal Plug and Play; displayed as `UPnP`.
    Upnp,
}
180
181impl fmt::Display for ProtocolType {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 match self {
184 ProtocolType::Mdns => write!(f, "mDNS"),
185 ProtocolType::DnsSd => write!(f, "DNS-SD"),
186 ProtocolType::Upnp => write!(f, "UPnP"),
187 }
188 }
189}
190
/// A network interface together with its addresses and status flags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NetworkInterface {
    /// Interface name.
    pub name: String,
    /// IPv4 addresses of the interface.
    pub ipv4_addresses: Vec<Ipv4Addr>,
    /// IPv6 addresses of the interface.
    pub ipv6_addresses: Vec<Ipv6Addr>,
    /// Whether the interface is up; `NetworkInterface::new` initializes this to `false`.
    pub is_up: bool,
    /// Whether the interface supports multicast; `NetworkInterface::new` initializes this to `false`.
    pub supports_multicast: bool,
}
205
206impl NetworkInterface {
207 pub fn new<S: Into<String>>(name: S) -> Self {
209 Self {
210 name: name.into(),
211 ipv4_addresses: Vec::new(),
212 ipv6_addresses: Vec::new(),
213 is_up: false,
214 supports_multicast: false,
215 }
216 }
217
218 pub fn with_ipv4(mut self, addr: Ipv4Addr) -> Self {
220 self.ipv4_addresses.push(addr);
221 self
222 }
223
224 pub fn with_ipv6(mut self, addr: Ipv6Addr) -> Self {
226 self.ipv6_addresses.push(addr);
227 self
228 }
229
230 pub fn with_status(mut self, is_up: bool, supports_multicast: bool) -> Self {
232 self.is_up = is_up;
233 self.supports_multicast = supports_multicast;
234 self
235 }
236
237 pub fn all_addresses(&self) -> Vec<IpAddr> {
239 let mut addresses = Vec::new();
240 addresses.extend(self.ipv4_addresses.iter().map(|&addr| IpAddr::V4(addr)));
241 addresses.extend(self.ipv6_addresses.iter().map(|&addr| IpAddr::V6(addr)));
242 addresses
243 }
244}
245
/// Attribute map with string keys and string values.
pub type ServiceAttributes = HashMap<String, String>;
248
/// Criteria for selecting services; see `DiscoveryFilter::matches` for how they combine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryFilter {
    /// Accepted service types, compared by exact equality; empty accepts any type.
    pub service_type_filters: Vec<ServiceType>,
    /// Accepted protocols; empty accepts any protocol.
    pub protocol_filters: Vec<ProtocolType>,
    /// `(key_pattern, value_pattern)` substring pairs. Each pair must be satisfied by at
    /// least one attribute whose key contains `key_pattern` and whose value contains
    /// `value_pattern`.
    pub attribute_patterns: Vec<(String, String)>,
}
259
260impl DiscoveryFilter {
261 pub fn new() -> Self {
263 Self {
264 service_type_filters: Vec::new(),
265 protocol_filters: Vec::new(),
266 attribute_patterns: Vec::new(),
267 }
268 }
269
270 pub fn with_service_type(mut self, service_type: ServiceType) -> Self {
272 self.service_type_filters.push(service_type);
273 self
274 }
275
276 pub fn with_protocol(mut self, protocol: ProtocolType) -> Self {
278 self.protocol_filters.push(protocol);
279 self
280 }
281
282 pub fn with_attribute_pattern(mut self, key_pattern: String, value_pattern: String) -> Self {
284 self.attribute_patterns.push((key_pattern, value_pattern));
285 self
286 }
287
288 pub fn matches(&self, service: &ServiceInfo) -> bool {
290 if !self.service_type_filters.is_empty()
292 && !self.service_type_filters.contains(&service.service_type) {
293 return false;
294 }
295
296 if !self.protocol_filters.is_empty()
298 && !self.protocol_filters.contains(&service.protocol_type) {
299 return false;
300 }
301
302 for (key_pattern, value_pattern) in &self.attribute_patterns {
304 let mut matches = false;
305 for (key, value) in &service.attributes {
306 if key.contains(key_pattern) && value.contains(value_pattern) {
308 matches = true;
309 break;
310 }
311 }
312 if !matches {
313 return false;
314 }
315 }
316
317 true
318 }
319}
320
321impl Default for DiscoveryFilter {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `_service._proto` string parses into its parts and round-trips via `Display`.
    #[test]
    fn test_service_type() -> Result<()> {
        let parsed = ServiceType::new("_http._tcp")?;
        assert_eq!(parsed.service_name, "_http");
        assert_eq!(parsed.protocol, "._tcp");
        assert!(parsed.domain.is_none());
        assert_eq!(parsed.to_string(), "_http._tcp");
        Ok(())
    }

    /// A trailing domain label is captured and rendered back.
    #[test]
    fn test_service_type_with_domain() -> Result<()> {
        let parsed = ServiceType::new("_http._tcp.local")?;
        assert_eq!(parsed.service_name, "_http");
        assert_eq!(parsed.protocol, "._tcp");
        assert_eq!(parsed.domain.as_deref(), Some("local"));
        assert_eq!(parsed.to_string(), "_http._tcp.local");
        Ok(())
    }

    /// Empty strings and strings without a protocol label are rejected.
    #[test]
    fn test_invalid_service_type() {
        for bad in ["", "invalid", "_http"] {
            assert!(ServiceType::new(bad).is_err());
        }
    }

    /// A filter on `_http._tcp` accepts a service advertised with that type.
    #[test]
    fn test_discovery_filter() -> Result<()> {
        use crate::service::ServiceInfo;

        let http_tcp = ServiceType::new("_http._tcp")?;
        let filter = DiscoveryFilter::new().with_service_type(http_tcp);

        let service = ServiceInfo::new(
            "Test Service",
            "_http._tcp",
            8080,
            Some(vec![("version", "1.0")]),
        )?;

        assert!(filter.matches(&service));
        Ok(())
    }

    /// mDNS is the default protocol.
    #[test]
    fn test_protocol_type_default() {
        assert_eq!(ProtocolType::default(), ProtocolType::Mdns);
    }
}