auto_discovery/
registry.rs

//! Service registry for managing discovered and registered services.
//!
//! This module provides a centralized registry that tracks service-discovery state
//! across different protocols. It holds both services registered locally and
//! services discovered on the network.
6
7use crate::{
8    error::{DiscoveryError, Result},
9    service::ServiceInfo,
10    types::{ServiceType, ProtocolType},
11};
12use std::{
13    collections::HashMap,
14    sync::Arc,
15    time::{Duration, Instant},
16};
17use tokio::sync::RwLock;
18use tracing::{debug, info, warn};
19
20/// Entry in the service registry with metadata
21#[derive(Debug, Clone)]
22pub struct ServiceEntry {
23    /// The service information
24    pub service: ServiceInfo,
25    /// When the service was registered/discovered
26    pub timestamp: Instant,
27    /// Whether this is a locally registered service
28    pub is_local: bool,
29    /// Time-to-live for the service entry
30    pub ttl: Option<Duration>,
31    /// The protocol that discovered/registered this service
32    pub protocol: ProtocolType,
33}
34
35impl ServiceEntry {
36    /// Create a new service entry for a locally registered service
37    pub fn new_local(service: ServiceInfo, protocol: ProtocolType) -> Self {
38        Self {
39            service,
40            timestamp: Instant::now(),
41            is_local: true,
42            ttl: None, // Local services don't expire
43            protocol,
44        }
45    }
46
47    /// Create a new service entry for a discovered service
48    pub fn new_discovered(service: ServiceInfo, protocol: ProtocolType, ttl: Option<Duration>) -> Self {
49        Self {
50            service,
51            timestamp: Instant::now(),
52            is_local: false,
53            ttl,
54            protocol,
55        }
56    }
57
58    /// Check if this service entry has expired
59    pub fn is_expired(&self) -> bool {
60        if let Some(ttl) = self.ttl {
61            self.timestamp.elapsed() > ttl
62        } else {
63            false
64        }
65    }
66
67    /// Get the service ID for indexing
68    pub fn service_id(&self) -> String {
69        format!("{}:{}:{}", self.service.name(), self.service.service_type(), self.service.port())
70    }
71}
72
73/// Filter for querying services from the registry
74#[derive(Debug, Clone, Default)]
75pub struct ServiceFilter {
76    /// Filter by service types
77    pub service_types: Option<Vec<ServiceType>>,
78    /// Filter by protocols
79    pub protocols: Option<Vec<ProtocolType>>,
80    /// Filter by service name (contains)
81    pub name_contains: Option<String>,
82    /// Include only local services
83    pub local_only: bool,
84    /// Include only discovered services
85    pub discovered_only: bool,
86    /// Maximum age of services to include
87    pub max_age: Option<Duration>,
88}
89
90
91
92impl ServiceFilter {
93    /// Create a new empty filter
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Filter by service types
99    pub fn with_service_types(mut self, types: Vec<ServiceType>) -> Self {
100        self.service_types = Some(types);
101        self
102    }
103
104    /// Filter by protocols
105    pub fn with_protocols(mut self, protocols: Vec<ProtocolType>) -> Self {
106        self.protocols = Some(protocols);
107        self
108    }
109
110    /// Filter by service name
111    pub fn with_name_contains(mut self, name: String) -> Self {
112        self.name_contains = Some(name);
113        self
114    }
115
116    /// Include only local services
117    pub fn local_only(mut self) -> Self {
118        self.local_only = true;
119        self
120    }
121
122    /// Include only discovered services
123    pub fn discovered_only(mut self) -> Self {
124        self.discovered_only = true;
125        self
126    }
127
128    /// Set maximum age of services
129    pub fn with_max_age(mut self, max_age: Duration) -> Self {
130        self.max_age = Some(max_age);
131        self
132    }
133
134    /// Check if a service entry matches this filter
135    pub fn matches(&self, entry: &ServiceEntry) -> bool {
136        // Check if expired
137        if entry.is_expired() {
138            return false;
139        }
140
141        // Check max age
142        if let Some(max_age) = self.max_age {
143            if entry.timestamp.elapsed() > max_age {
144                return false;
145            }
146        }
147
148        // Check local/discovered filter
149        if self.local_only && !entry.is_local {
150            return false;
151        }
152        if self.discovered_only && entry.is_local {
153            return false;
154        }
155
156        // Check service types
157        if let Some(ref types) = self.service_types {
158            if !types.iter().any(|t| t.to_string() == entry.service.service_type().to_string()) {
159                return false;
160            }
161        }
162
163        // Check protocols
164        if let Some(ref protocols) = self.protocols {
165            if !protocols.contains(&entry.protocol) {
166                return false;
167            }
168        }
169
170        // Check name contains
171        if let Some(ref name) = self.name_contains {
172            if !entry.service.name().contains(name) {
173                return false;
174            }
175        }
176
177        true
178    }
179}
180
181/// Centralized service registry for managing discovered and registered services
182pub struct ServiceRegistry {
183    /// All services indexed by service ID
184    services: Arc<RwLock<HashMap<String, ServiceEntry>>>,
185    /// Default TTL for discovered services
186    default_ttl: Duration,
187    /// Maximum number of services to store
188    max_services: usize,
189}
190
191impl ServiceRegistry {
192    /// Create a new service registry
193    pub fn new() -> Self {
194        Self {
195            services: Arc::new(RwLock::new(HashMap::new())),
196            default_ttl: Duration::from_secs(300), // 5 minutes
197            max_services: 1000,
198        }
199    }
200
201    /// Create a new service registry with custom settings
202    pub fn with_settings(default_ttl: Duration, max_services: usize) -> Self {
203        Self {
204            services: Arc::new(RwLock::new(HashMap::new())),
205            default_ttl,
206            max_services,
207        }
208    }
209
210    /// Register a local service
211    pub async fn register_local_service(&self, service: ServiceInfo, protocol: ProtocolType) -> Result<()> {
212        let entry = ServiceEntry::new_local(service, protocol);
213        let service_id = entry.service_id();
214        
215        let mut services = self.services.write().await;
216        services.insert(service_id.clone(), entry);
217        
218        info!("Registered local service: {}", service_id);
219        Ok(())
220    }
221
222    /// Unregister a local service
223    pub async fn unregister_local_service(&self, service_id: &str) -> Result<()> {
224        let mut services = self.services.write().await;
225        if services.remove(service_id).is_some() {
226            info!("Unregistered local service: {}", service_id);
227            Ok(())
228        } else {
229            warn!("Attempted to unregister unknown service: {}", service_id);
230            Err(DiscoveryError::service_not_found(service_id))
231        }
232    }
233
234    /// Add a discovered service
235    pub async fn add_discovered_service(&self, service: ServiceInfo, protocol: ProtocolType, ttl: Option<Duration>) -> Result<()> {
236        let ttl = ttl.unwrap_or(self.default_ttl);
237        let entry = ServiceEntry::new_discovered(service, protocol, Some(ttl));
238        let service_id = entry.service_id();
239        
240        let mut services = self.services.write().await;
241        
242        // Check if we're at capacity
243        if services.len() >= self.max_services {
244            // Remove oldest expired service
245            if let Some(oldest_expired) = self.find_oldest_expired(&services) {
246                services.remove(&oldest_expired);
247            } else {
248                warn!("Service registry at capacity, cannot add new service");
249                return Err(DiscoveryError::configuration("Service registry at capacity"));
250            }
251        }
252        
253        services.insert(service_id.clone(), entry);
254        debug!("Added discovered service: {}", service_id);
255        Ok(())
256    }
257
258    /// Find services matching the given filter
259    pub async fn find_services(&self, filter: &ServiceFilter) -> Vec<ServiceInfo> {
260        let services = self.services.read().await;
261        
262        services
263            .values()
264            .filter(|entry| filter.matches(entry))
265            .map(|entry| entry.service.clone())
266            .collect()
267    }
268
269    /// Get all locally registered services
270    pub async fn get_local_services(&self) -> Vec<ServiceInfo> {
271        let filter = ServiceFilter::new().local_only();
272        self.find_services(&filter).await
273    }
274
275    /// Get all discovered services
276    pub async fn get_discovered_services(&self) -> Vec<ServiceInfo> {
277        let filter = ServiceFilter::new().discovered_only();
278        self.find_services(&filter).await
279    }
280
281    /// Get services by type
282    pub async fn get_services_by_type(&self, service_type: &ServiceType) -> Vec<ServiceInfo> {
283        let filter = ServiceFilter::new().with_service_types(vec![service_type.clone()]);
284        self.find_services(&filter).await
285    }
286
287    /// Get services by protocol
288    pub async fn get_services_by_protocol(&self, protocol: ProtocolType) -> Vec<ServiceInfo> {
289        let filter = ServiceFilter::new().with_protocols(vec![protocol]);
290        self.find_services(&filter).await
291    }
292
293    /// Check if a service is registered locally
294    pub async fn is_local_service(&self, service_id: &str) -> bool {
295        let services = self.services.read().await;
296        services.get(service_id).map(|entry| entry.is_local).unwrap_or(false)
297    }
298
299    /// Check if a service exists in the registry
300    pub async fn contains_service(&self, service_id: &str) -> bool {
301        let services = self.services.read().await;
302        services.contains_key(service_id)
303    }
304
305    /// Clean up expired services
306    pub async fn cleanup_expired(&self) -> usize {
307        let mut services = self.services.write().await;
308        let initial_count = services.len();
309        
310        services.retain(|_, entry| !entry.is_expired());
311        
312        let removed_count = initial_count - services.len();
313        if removed_count > 0 {
314            debug!("Cleaned up {} expired services", removed_count);
315        }
316        
317        removed_count
318    }
319
320    /// Get registry statistics
321    pub async fn stats(&self) -> RegistryStats {
322        let services = self.services.read().await;
323        
324        let mut local_count = 0;
325        let mut discovered_count = 0;
326        let mut expired_count = 0;
327        
328        for entry in services.values() {
329            if entry.is_local {
330                local_count += 1;
331            } else {
332                discovered_count += 1;
333            }
334            
335            if entry.is_expired() {
336                expired_count += 1;
337            }
338        }
339        
340        RegistryStats {
341            total_services: services.len(),
342            local_services: local_count,
343            discovered_services: discovered_count,
344            expired_services: expired_count,
345        }
346    }
347
348    /// Find the oldest expired service for cleanup
349    fn find_oldest_expired(&self, services: &HashMap<String, ServiceEntry>) -> Option<String> {
350        services
351            .iter()
352            .filter(|(_, entry)| entry.is_expired())
353            .min_by_key(|(_, entry)| entry.timestamp)
354            .map(|(id, _)| id.clone())
355    }
356}
357
/// Snapshot of registry counters returned by `ServiceRegistry::stats`.
///
/// `expired_services` counts entries whose TTL has lapsed but which have not
/// been removed yet. Those entries are also included in `total_services` and
/// in `local_services` / `discovered_services`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RegistryStats {
    /// Total number of stored entries, including expired ones.
    pub total_services: usize,
    /// Number of locally registered entries.
    pub local_services: usize,
    /// Number of discovered entries.
    pub discovered_services: usize,
    /// Number of entries that have expired but were not yet cleaned up.
    pub expired_services: usize,
}
370
371impl Default for ServiceRegistry {
372    fn default() -> Self {
373        Self::new()
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_register_and_find_local_service() {
        let registry = ServiceRegistry::new();

        let service = ServiceInfo::new("test", "_http._tcp", 8080, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        registry.register_local_service(service.clone(), ProtocolType::Mdns).await.unwrap();

        let local_services = registry.get_local_services().await;
        assert_eq!(local_services.len(), 1);
        assert_eq!(local_services[0].name(), service.name());
    }

    #[tokio::test]
    async fn test_discover_and_find_service() {
        let registry = ServiceRegistry::new();

        let service = ServiceInfo::new("discovered", "_http._tcp", 9090, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)));

        registry
            .add_discovered_service(service.clone(), ProtocolType::Upnp, Some(Duration::from_secs(60)))
            .await
            .unwrap();

        let discovered_services = registry.get_discovered_services().await;
        assert_eq!(discovered_services.len(), 1);
        assert_eq!(discovered_services[0].name(), service.name());
    }

    #[tokio::test]
    async fn test_service_filter() {
        let registry = ServiceRegistry::new();

        let http_service = ServiceInfo::new("web", "_http._tcp", 80, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        let ssh_service = ServiceInfo::new("ssh", "_ssh._tcp", 22, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        registry.register_local_service(http_service.clone(), ProtocolType::Mdns).await.unwrap();
        registry
            .add_discovered_service(ssh_service.clone(), ProtocolType::Upnp, Some(Duration::from_secs(60)))
            .await
            .unwrap();

        // Test filter by type
        let http_services = registry.get_services_by_type(&ServiceType::new("_http._tcp").unwrap()).await;
        assert_eq!(http_services.len(), 1);
        assert_eq!(http_services[0].name(), "web");

        // Test filter by protocol
        let mdns_services = registry.get_services_by_protocol(ProtocolType::Mdns).await;
        assert_eq!(mdns_services.len(), 1);
        assert_eq!(mdns_services[0].name(), "web");

        // Test local only filter
        let local_services = registry.get_local_services().await;
        assert_eq!(local_services.len(), 1);
        assert_eq!(local_services[0].name(), "web");
    }

    #[tokio::test]
    async fn test_service_expiration() {
        let registry = ServiceRegistry::new();

        let service = ServiceInfo::new("temp", "_http._tcp", 8080, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        // Add service with very short TTL
        registry
            .add_discovered_service(service.clone(), ProtocolType::Mdns, Some(Duration::from_millis(50)))
            .await
            .unwrap();

        // Should find service immediately
        let services = registry.get_discovered_services().await;
        assert_eq!(services.len(), 1);

        // Wait for expiration
        sleep(Duration::from_millis(100)).await;

        // Should not find expired service
        let services = registry.get_discovered_services().await;
        assert_eq!(services.len(), 0);

        // Cleanup should remove expired service
        let removed = registry.cleanup_expired().await;
        assert_eq!(removed, 1);
    }

    #[tokio::test]
    async fn test_unregister_local_service() {
        let registry = ServiceRegistry::new();

        let service = ServiceInfo::new("local", "_http._tcp", 8081, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
        // Derive the ID the same way the registry does, rather than hard-coding
        // the service type's display form.
        let service_id = ServiceEntry::new_local(service.clone(), ProtocolType::Mdns).service_id();

        registry.register_local_service(service, ProtocolType::Mdns).await.unwrap();
        assert!(registry.contains_service(&service_id).await);
        assert!(registry.is_local_service(&service_id).await);

        registry.unregister_local_service(&service_id).await.unwrap();
        assert!(!registry.contains_service(&service_id).await);

        // A second removal must report the service as unknown.
        assert!(registry.unregister_local_service(&service_id).await.is_err());
    }

    #[tokio::test]
    async fn test_capacity_limit_rejects_new_services() {
        let registry = ServiceRegistry::with_settings(Duration::from_secs(60), 1);

        let first = ServiceInfo::new("first", "_http._tcp", 8000, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)));
        let second = ServiceInfo::new("second", "_http._tcp", 8001, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 11)));

        registry.add_discovered_service(first, ProtocolType::Mdns, None).await.unwrap();

        // The registry is full and holds no expired entry to evict.
        assert!(registry.add_discovered_service(second, ProtocolType::Mdns, None).await.is_err());
        assert_eq!(registry.get_discovered_services().await.len(), 1);
    }

    #[tokio::test]
    async fn test_name_filter_and_stats() {
        let registry = ServiceRegistry::new();

        let http_service = ServiceInfo::new("web", "_http._tcp", 80, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
        let ssh_service = ServiceInfo::new("ssh", "_ssh._tcp", 22, None)
            .unwrap()
            .with_address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        registry.register_local_service(http_service, ProtocolType::Mdns).await.unwrap();
        registry
            .add_discovered_service(ssh_service, ProtocolType::Upnp, Some(Duration::from_secs(60)))
            .await
            .unwrap();

        let filter = ServiceFilter::new().with_name_contains("we".to_string());
        let found = registry.find_services(&filter).await;
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name(), "web");

        let stats = registry.stats().await;
        assert_eq!(stats.total_services, 2);
        assert_eq!(stats.local_services, 1);
        assert_eq!(stats.discovered_services, 1);
        assert_eq!(stats.expired_services, 0);
    }
}