1use crate::{
8 error::{DiscoveryError, Result},
9 service::ServiceInfo,
10 types::{ServiceType, ProtocolType},
11};
12use std::{
13 collections::HashMap,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17use tokio::sync::RwLock;
18use tracing::{debug, info, warn};
19
20#[derive(Debug, Clone)]
22pub struct ServiceEntry {
23 pub service: ServiceInfo,
25 pub timestamp: Instant,
27 pub is_local: bool,
29 pub ttl: Option<Duration>,
31 pub protocol: ProtocolType,
33}
34
35impl ServiceEntry {
36 pub fn new_local(service: ServiceInfo, protocol: ProtocolType) -> Self {
38 Self {
39 service,
40 timestamp: Instant::now(),
41 is_local: true,
42 ttl: None, protocol,
44 }
45 }
46
47 pub fn new_discovered(service: ServiceInfo, protocol: ProtocolType, ttl: Option<Duration>) -> Self {
49 Self {
50 service,
51 timestamp: Instant::now(),
52 is_local: false,
53 ttl,
54 protocol,
55 }
56 }
57
58 pub fn is_expired(&self) -> bool {
60 if let Some(ttl) = self.ttl {
61 self.timestamp.elapsed() > ttl
62 } else {
63 false
64 }
65 }
66
67 pub fn service_id(&self) -> String {
69 format!("{}:{}:{}", self.service.name(), self.service.service_type(), self.service.port())
70 }
71}
72
73#[derive(Debug, Clone, Default)]
75pub struct ServiceFilter {
76 pub service_types: Option<Vec<ServiceType>>,
78 pub protocols: Option<Vec<ProtocolType>>,
80 pub name_contains: Option<String>,
82 pub local_only: bool,
84 pub discovered_only: bool,
86 pub max_age: Option<Duration>,
88}
89
90
91
92impl ServiceFilter {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn with_service_types(mut self, types: Vec<ServiceType>) -> Self {
100 self.service_types = Some(types);
101 self
102 }
103
104 pub fn with_protocols(mut self, protocols: Vec<ProtocolType>) -> Self {
106 self.protocols = Some(protocols);
107 self
108 }
109
110 pub fn with_name_contains(mut self, name: String) -> Self {
112 self.name_contains = Some(name);
113 self
114 }
115
116 pub fn local_only(mut self) -> Self {
118 self.local_only = true;
119 self
120 }
121
122 pub fn discovered_only(mut self) -> Self {
124 self.discovered_only = true;
125 self
126 }
127
128 pub fn with_max_age(mut self, max_age: Duration) -> Self {
130 self.max_age = Some(max_age);
131 self
132 }
133
134 pub fn matches(&self, entry: &ServiceEntry) -> bool {
136 if entry.is_expired() {
138 return false;
139 }
140
141 if let Some(max_age) = self.max_age {
143 if entry.timestamp.elapsed() > max_age {
144 return false;
145 }
146 }
147
148 if self.local_only && !entry.is_local {
150 return false;
151 }
152 if self.discovered_only && entry.is_local {
153 return false;
154 }
155
156 if let Some(ref types) = self.service_types {
158 if !types.iter().any(|t| t.to_string() == entry.service.service_type().to_string()) {
159 return false;
160 }
161 }
162
163 if let Some(ref protocols) = self.protocols {
165 if !protocols.contains(&entry.protocol) {
166 return false;
167 }
168 }
169
170 if let Some(ref name) = self.name_contains {
172 if !entry.service.name().contains(name) {
173 return false;
174 }
175 }
176
177 true
178 }
179}
180
181pub struct ServiceRegistry {
183 services: Arc<RwLock<HashMap<String, ServiceEntry>>>,
185 default_ttl: Duration,
187 max_services: usize,
189}
190
191impl ServiceRegistry {
192 pub fn new() -> Self {
194 Self {
195 services: Arc::new(RwLock::new(HashMap::new())),
196 default_ttl: Duration::from_secs(300), max_services: 1000,
198 }
199 }
200
201 pub fn with_settings(default_ttl: Duration, max_services: usize) -> Self {
203 Self {
204 services: Arc::new(RwLock::new(HashMap::new())),
205 default_ttl,
206 max_services,
207 }
208 }
209
210 pub async fn register_local_service(&self, service: ServiceInfo, protocol: ProtocolType) -> Result<()> {
212 let entry = ServiceEntry::new_local(service, protocol);
213 let service_id = entry.service_id();
214
215 let mut services = self.services.write().await;
216 services.insert(service_id.clone(), entry);
217
218 info!("Registered local service: {}", service_id);
219 Ok(())
220 }
221
222 pub async fn unregister_local_service(&self, service_id: &str) -> Result<()> {
224 let mut services = self.services.write().await;
225 if services.remove(service_id).is_some() {
226 info!("Unregistered local service: {}", service_id);
227 Ok(())
228 } else {
229 warn!("Attempted to unregister unknown service: {}", service_id);
230 Err(DiscoveryError::service_not_found(service_id))
231 }
232 }
233
234 pub async fn add_discovered_service(&self, service: ServiceInfo, protocol: ProtocolType, ttl: Option<Duration>) -> Result<()> {
236 let ttl = ttl.unwrap_or(self.default_ttl);
237 let entry = ServiceEntry::new_discovered(service, protocol, Some(ttl));
238 let service_id = entry.service_id();
239
240 let mut services = self.services.write().await;
241
242 if services.len() >= self.max_services {
244 if let Some(oldest_expired) = self.find_oldest_expired(&services) {
246 services.remove(&oldest_expired);
247 } else {
248 warn!("Service registry at capacity, cannot add new service");
249 return Err(DiscoveryError::configuration("Service registry at capacity"));
250 }
251 }
252
253 services.insert(service_id.clone(), entry);
254 debug!("Added discovered service: {}", service_id);
255 Ok(())
256 }
257
258 pub async fn find_services(&self, filter: &ServiceFilter) -> Vec<ServiceInfo> {
260 let services = self.services.read().await;
261
262 services
263 .values()
264 .filter(|entry| filter.matches(entry))
265 .map(|entry| entry.service.clone())
266 .collect()
267 }
268
269 pub async fn get_local_services(&self) -> Vec<ServiceInfo> {
271 let filter = ServiceFilter::new().local_only();
272 self.find_services(&filter).await
273 }
274
275 pub async fn get_discovered_services(&self) -> Vec<ServiceInfo> {
277 let filter = ServiceFilter::new().discovered_only();
278 self.find_services(&filter).await
279 }
280
281 pub async fn get_services_by_type(&self, service_type: &ServiceType) -> Vec<ServiceInfo> {
283 let filter = ServiceFilter::new().with_service_types(vec![service_type.clone()]);
284 self.find_services(&filter).await
285 }
286
287 pub async fn get_services_by_protocol(&self, protocol: ProtocolType) -> Vec<ServiceInfo> {
289 let filter = ServiceFilter::new().with_protocols(vec![protocol]);
290 self.find_services(&filter).await
291 }
292
293 pub async fn is_local_service(&self, service_id: &str) -> bool {
295 let services = self.services.read().await;
296 services.get(service_id).map(|entry| entry.is_local).unwrap_or(false)
297 }
298
299 pub async fn contains_service(&self, service_id: &str) -> bool {
301 let services = self.services.read().await;
302 services.contains_key(service_id)
303 }
304
305 pub async fn cleanup_expired(&self) -> usize {
307 let mut services = self.services.write().await;
308 let initial_count = services.len();
309
310 services.retain(|_, entry| !entry.is_expired());
311
312 let removed_count = initial_count - services.len();
313 if removed_count > 0 {
314 debug!("Cleaned up {} expired services", removed_count);
315 }
316
317 removed_count
318 }
319
320 pub async fn stats(&self) -> RegistryStats {
322 let services = self.services.read().await;
323
324 let mut local_count = 0;
325 let mut discovered_count = 0;
326 let mut expired_count = 0;
327
328 for entry in services.values() {
329 if entry.is_local {
330 local_count += 1;
331 } else {
332 discovered_count += 1;
333 }
334
335 if entry.is_expired() {
336 expired_count += 1;
337 }
338 }
339
340 RegistryStats {
341 total_services: services.len(),
342 local_services: local_count,
343 discovered_services: discovered_count,
344 expired_services: expired_count,
345 }
346 }
347
348 fn find_oldest_expired(&self, services: &HashMap<String, ServiceEntry>) -> Option<String> {
350 services
351 .iter()
352 .filter(|(_, entry)| entry.is_expired())
353 .min_by_key(|(_, entry)| entry.timestamp)
354 .map(|(id, _)| id.clone())
355 }
356}
357
/// Point-in-time entry counts produced by [`ServiceRegistry::stats`].
#[derive(Debug, Clone)]
pub struct RegistryStats {
    /// Every stored entry, including expired ones not yet cleaned up.
    pub total_services: usize,
    /// Entries registered by this node.
    pub local_services: usize,
    /// Entries learned from the network.
    pub discovered_services: usize,
    /// Stored entries whose TTL has already elapsed.
    pub expired_services: usize,
}
370
371impl Default for ServiceRegistry {
372 fn default() -> Self {
373 Self::new()
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::time::{sleep, Duration};

    /// 127.0.0.1, shared by tests that do not care about the host address.
    fn loopback() -> IpAddr {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }

    #[tokio::test]
    async fn test_register_and_find_local_service() {
        let registry = ServiceRegistry::new();
        let svc = ServiceInfo::new("test", "_http._tcp", 8080, None)
            .unwrap()
            .with_address(loopback());

        registry
            .register_local_service(svc.clone(), ProtocolType::Mdns)
            .await
            .unwrap();

        let found = registry.get_local_services().await;
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name(), svc.name());
    }

    #[tokio::test]
    async fn test_discover_and_find_service() {
        let registry = ServiceRegistry::new();
        let remote_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
        let svc = ServiceInfo::new("discovered", "_http._tcp", 9090, None)
            .unwrap()
            .with_address(remote_addr);

        registry
            .add_discovered_service(svc.clone(), ProtocolType::Upnp, Some(Duration::from_secs(60)))
            .await
            .unwrap();

        let found = registry.get_discovered_services().await;
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name(), svc.name());
    }

    #[tokio::test]
    async fn test_service_filter() {
        let registry = ServiceRegistry::new();

        let web = ServiceInfo::new("web", "_http._tcp", 80, None)
            .unwrap()
            .with_address(loopback());
        let ssh = ServiceInfo::new("ssh", "_ssh._tcp", 22, None)
            .unwrap()
            .with_address(loopback());

        // One local mDNS service and one discovered UPnP service.
        registry
            .register_local_service(web.clone(), ProtocolType::Mdns)
            .await
            .unwrap();
        registry
            .add_discovered_service(ssh.clone(), ProtocolType::Upnp, Some(Duration::from_secs(60)))
            .await
            .unwrap();

        let http_type = ServiceType::new("_http._tcp").unwrap();
        let by_type = registry.get_services_by_type(&http_type).await;
        assert_eq!(by_type.len(), 1);
        assert_eq!(by_type[0].name(), "web");

        let by_protocol = registry.get_services_by_protocol(ProtocolType::Mdns).await;
        assert_eq!(by_protocol.len(), 1);
        assert_eq!(by_protocol[0].name(), "web");

        let locals = registry.get_local_services().await;
        assert_eq!(locals.len(), 1);
        assert_eq!(locals[0].name(), "web");
    }

    #[tokio::test]
    async fn test_service_expiration() {
        let registry = ServiceRegistry::new();
        let svc = ServiceInfo::new("temp", "_http._tcp", 8080, None)
            .unwrap()
            .with_address(loopback());

        registry
            .add_discovered_service(svc.clone(), ProtocolType::Mdns, Some(Duration::from_millis(50)))
            .await
            .unwrap();

        // Visible while the 50 ms TTL is still running.
        assert_eq!(registry.get_discovered_services().await.len(), 1);

        // Wait past the TTL: queries hide the entry, but it stays stored until cleanup.
        sleep(Duration::from_millis(100)).await;
        assert_eq!(registry.get_discovered_services().await.len(), 0);

        let removed = registry.cleanup_expired().await;
        assert_eq!(removed, 1);
    }
}