1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8
/// One network endpoint that a service name resolved to.
///
/// Derives `PartialEq`/`Eq`/`Hash` so targets can be compared directly
/// (e.g. in `assert_eq!`) and collected into hash-based sets for deduplication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResolvedTarget {
    /// Host name or address, stored exactly as the provider supplied it.
    pub host: String,
    /// Port to connect on; `None` when the provider supplied no port.
    pub port: Option<u16>,
}

/// The outcome of a single service lookup.
///
/// Derives `PartialEq`/`Eq` so whole lookup results can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Resolved {
    /// The service name that was looked up.
    pub service_name: String,
    /// Endpoints found for `service_name`; empty when nothing resolved.
    pub addresses: Vec<ResolvedTarget>,
}
20
/// A source of network addresses for named services.
///
/// The `Send + Sync + 'static` bound allows implementations to be shared
/// across threads and stored as `Arc<dyn ServiceDiscovery>` trait objects.
/// `#[async_trait]` rewrites `lookup` so it is callable through such a trait
/// object.
#[async_trait]
pub trait ServiceDiscovery: Send + Sync + 'static {
    /// Resolves `service_name` to the addresses this provider knows about.
    ///
    /// The signature is infallible (it returns `Resolved`, not a `Result`).
    /// The implementations in this module report "nothing found" as a
    /// `Resolved` with an empty `addresses` list.
    async fn lookup(&self, service_name: &str) -> Resolved;
}
25
26#[derive(Default)]
27pub struct StaticDiscovery {
28 services: RwLock<HashMap<String, Vec<ResolvedTarget>>>,
29}
30
31impl StaticDiscovery {
32 pub fn new() -> Arc<Self> {
33 Arc::new(Self::default())
34 }
35
36 pub fn register(&self, name: impl Into<String>, target: ResolvedTarget) {
37 self.services.write().entry(name.into()).or_default().push(target);
38 }
39}
40
41#[async_trait]
42impl ServiceDiscovery for StaticDiscovery {
43 async fn lookup(&self, service_name: &str) -> Resolved {
44 Resolved {
45 service_name: service_name.into(),
46 addresses: self.services.read().get(service_name).cloned().unwrap_or_default(),
47 }
48 }
49}
50
51pub struct AggregateDiscovery {
54 providers: Vec<Arc<dyn ServiceDiscovery>>,
55}
56
57impl AggregateDiscovery {
58 pub fn new(providers: Vec<Arc<dyn ServiceDiscovery>>) -> Arc<Self> {
59 Arc::new(Self { providers })
60 }
61
62 pub fn provider_count(&self) -> usize {
63 self.providers.len()
64 }
65}
66
67#[async_trait]
68impl ServiceDiscovery for AggregateDiscovery {
69 async fn lookup(&self, service_name: &str) -> Resolved {
70 for p in &self.providers {
71 let r = p.lookup(service_name).await;
72 if !r.addresses.is_empty() {
73 return r;
74 }
75 }
76 Resolved { service_name: service_name.into(), addresses: Vec::new() }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ResolvedTarget` in test fixtures.
    fn target(host: &str, port: Option<u16>) -> ResolvedTarget {
        ResolvedTarget {
            host: host.to_string(),
            port,
        }
    }

    #[tokio::test]
    async fn static_discovery_resolves() {
        let discovery = StaticDiscovery::new();
        discovery.register("svc", target("1.2.3.4", Some(8080)));

        let resolved = discovery.lookup("svc").await;

        assert_eq!(resolved.addresses.len(), 1);
    }

    #[tokio::test]
    async fn aggregate_falls_through_to_second_provider_when_first_empty() {
        let blank = StaticDiscovery::new();
        let populated = StaticDiscovery::new();
        populated.register("svc", target("10.0.0.1", None));
        let aggregate = AggregateDiscovery::new(vec![blank, populated]);

        let resolved = aggregate.lookup("svc").await;

        assert_eq!(resolved.addresses.len(), 1);
        assert_eq!(resolved.addresses[0].host, "10.0.0.1");
    }

    #[tokio::test]
    async fn aggregate_returns_first_nonempty_provider() {
        let primary = StaticDiscovery::new();
        primary.register("svc", target("first", None));
        let secondary = StaticDiscovery::new();
        secondary.register("svc", target("second", None));
        let aggregate = AggregateDiscovery::new(vec![primary, secondary]);

        let resolved = aggregate.lookup("svc").await;

        assert_eq!(resolved.addresses.len(), 1);
        assert_eq!(resolved.addresses[0].host, "first");
    }

    #[tokio::test]
    async fn aggregate_empty_when_no_providers_resolve() {
        let left = StaticDiscovery::new();
        let right = StaticDiscovery::new();
        let aggregate = AggregateDiscovery::new(vec![left, right]);

        let resolved = aggregate.lookup("svc").await;

        assert!(resolved.addresses.is_empty());
        assert_eq!(resolved.service_name, "svc");
    }

    #[tokio::test]
    async fn aggregate_with_no_providers_resolves_empty() {
        let aggregate = AggregateDiscovery::new(Vec::new());
        assert_eq!(aggregate.provider_count(), 0);

        let resolved = aggregate.lookup("svc").await;

        assert!(resolved.addresses.is_empty());
    }
}
132}