Skip to main content

atomr_discovery/
lib.rs

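//! atomr-discovery: resolves service names to network endpoints through pluggable backends, including a static in-memory registry and an ordered aggregate of providers.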
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8
/// A single network endpoint produced by a discovery lookup.
///
/// Derives `PartialEq`/`Eq`/`Hash` so targets can be compared, deduplicated,
/// or used as map/set keys by callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResolvedTarget {
    /// Host name or IP address, exactly as it was registered.
    pub host: String,
    /// Port for the endpoint, or `None` when the registering code did not supply one.
    pub port: Option<u16>,
}
14
15#[derive(Debug, Clone)]
16pub struct Resolved {
17    pub service_name: String,
18    pub addresses: Vec<ResolvedTarget>,
19}
20
21#[async_trait]
22pub trait ServiceDiscovery: Send + Sync + 'static {
23    async fn lookup(&self, service_name: &str) -> Resolved;
24}
25
26#[derive(Default)]
27pub struct StaticDiscovery {
28    services: RwLock<HashMap<String, Vec<ResolvedTarget>>>,
29}
30
31impl StaticDiscovery {
32    pub fn new() -> Arc<Self> {
33        Arc::new(Self::default())
34    }
35
36    pub fn register(&self, name: impl Into<String>, target: ResolvedTarget) {
37        self.services.write().entry(name.into()).or_default().push(target);
38    }
39}
40
41#[async_trait]
42impl ServiceDiscovery for StaticDiscovery {
43    async fn lookup(&self, service_name: &str) -> Resolved {
44        Resolved {
45            service_name: service_name.into(),
46            addresses: self.services.read().get(service_name).cloned().unwrap_or_default(),
47        }
48    }
49}
50
51/// Chain of discovery backends. `lookup` walks providers in order and
52/// returns the first non-empty resolution.
53pub struct AggregateDiscovery {
54    providers: Vec<Arc<dyn ServiceDiscovery>>,
55}
56
57impl AggregateDiscovery {
58    pub fn new(providers: Vec<Arc<dyn ServiceDiscovery>>) -> Arc<Self> {
59        Arc::new(Self { providers })
60    }
61
62    pub fn provider_count(&self) -> usize {
63        self.providers.len()
64    }
65}
66
67#[async_trait]
68impl ServiceDiscovery for AggregateDiscovery {
69    async fn lookup(&self, service_name: &str) -> Resolved {
70        for p in &self.providers {
71            let r = p.lookup(service_name).await;
72            if !r.addresses.is_empty() {
73                return r;
74            }
75        }
76        Resolved { service_name: service_name.into(), addresses: Vec::new() }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered target comes back intact, along with the requested name.
    #[tokio::test]
    async fn static_discovery_resolves() {
        let d = StaticDiscovery::new();
        d.register("svc", ResolvedTarget { host: "1.2.3.4".into(), port: Some(8080) });
        let r = d.lookup("svc").await;
        assert_eq!(r.service_name, "svc");
        assert_eq!(r.addresses.len(), 1);
        assert_eq!(r.addresses[0].host, "1.2.3.4");
        assert_eq!(r.addresses[0].port, Some(8080));
    }

    /// Looking up a name that was never registered yields an empty result
    /// that still carries the requested name.
    #[tokio::test]
    async fn static_discovery_unknown_service_is_empty() {
        let d = StaticDiscovery::new();
        d.register("svc", ResolvedTarget { host: "1.2.3.4".into(), port: None });
        let r = d.lookup("other").await;
        assert!(r.addresses.is_empty());
        assert_eq!(r.service_name, "other");
    }

    /// Repeated registrations under one name append, preserving order.
    #[tokio::test]
    async fn static_register_appends_in_order() {
        let d = StaticDiscovery::new();
        d.register("svc", ResolvedTarget { host: "a".into(), port: None });
        d.register(String::from("svc"), ResolvedTarget { host: "b".into(), port: Some(1) });
        let r = d.lookup("svc").await;
        let hosts: Vec<&str> = r.addresses.iter().map(|t| t.host.as_str()).collect();
        assert_eq!(hosts, ["a", "b"]);
        assert_eq!(r.addresses[1].port, Some(1));
    }

    #[tokio::test]
    async fn aggregate_falls_through_to_second_provider_when_first_empty() {
        let empty = StaticDiscovery::new();
        let full = StaticDiscovery::new();
        full.register("svc", ResolvedTarget { host: "10.0.0.1".into(), port: None });
        let agg = AggregateDiscovery::new(vec![empty, full]);
        let r = agg.lookup("svc").await;
        assert_eq!(r.addresses.len(), 1);
        assert_eq!(r.addresses[0].host, "10.0.0.1");
    }

    /// A provider that only knows other services counts as empty for this name.
    #[tokio::test]
    async fn aggregate_skips_provider_that_only_knows_other_services() {
        let a = StaticDiscovery::new();
        a.register("other", ResolvedTarget { host: "wrong".into(), port: None });
        let b = StaticDiscovery::new();
        b.register("svc", ResolvedTarget { host: "right".into(), port: None });
        let agg = AggregateDiscovery::new(vec![a, b]);
        let r = agg.lookup("svc").await;
        assert_eq!(r.addresses.len(), 1);
        assert_eq!(r.addresses[0].host, "right");
    }

    #[tokio::test]
    async fn aggregate_returns_first_nonempty_provider() {
        let a = StaticDiscovery::new();
        a.register("svc", ResolvedTarget { host: "first".into(), port: None });
        let b = StaticDiscovery::new();
        b.register("svc", ResolvedTarget { host: "second".into(), port: None });
        let agg = AggregateDiscovery::new(vec![a, b]);
        let r = agg.lookup("svc").await;
        assert_eq!(r.addresses.len(), 1);
        assert_eq!(r.addresses[0].host, "first");
    }

    #[tokio::test]
    async fn aggregate_empty_when_no_providers_resolve() {
        let a = StaticDiscovery::new();
        let b = StaticDiscovery::new();
        let agg = AggregateDiscovery::new(vec![a, b]);
        let r = agg.lookup("svc").await;
        assert!(r.addresses.is_empty());
        assert_eq!(r.service_name, "svc");
    }

    #[tokio::test]
    async fn aggregate_with_no_providers_resolves_empty() {
        let agg = AggregateDiscovery::new(Vec::new());
        assert_eq!(agg.provider_count(), 0);
        let r = agg.lookup("svc").await;
        assert!(r.addresses.is_empty());
        assert_eq!(r.service_name, "svc");
    }
}