dynamo_runtime/discovery/
mock.rs1use super::{
5 Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
6 DiscoverySpec, DiscoveryStream,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tokio_util::sync::CancellationToken;
12
13#[derive(Clone, Default)]
15pub struct SharedMockRegistry {
16 instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
17}
18
19impl SharedMockRegistry {
20 pub fn new() -> Self {
21 Self::default()
22 }
23}
24
25pub struct MockDiscovery {
28 instance_id: u64,
29 registry: SharedMockRegistry,
30}
31
32impl MockDiscovery {
33 pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
34 let instance_id = instance_id.unwrap_or_else(|| {
35 use std::sync::atomic::{AtomicU64, Ordering};
36 static COUNTER: AtomicU64 = AtomicU64::new(1);
37 COUNTER.fetch_add(1, Ordering::SeqCst)
38 });
39
40 Self {
41 instance_id,
42 registry,
43 }
44 }
45}
46
47fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
49 match (instance, query) {
50 (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
52 (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53 &inst.namespace == namespace
54 }
55 (
56 DiscoveryInstance::Endpoint(inst),
57 DiscoveryQuery::ComponentEndpoints {
58 namespace,
59 component,
60 },
61 ) => &inst.namespace == namespace && &inst.component == component,
62 (
63 DiscoveryInstance::Endpoint(inst),
64 DiscoveryQuery::Endpoint {
65 namespace,
66 component,
67 endpoint,
68 },
69 ) => {
70 &inst.namespace == namespace
71 && &inst.component == component
72 && &inst.endpoint == endpoint
73 }
74
75 (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77 (
78 DiscoveryInstance::Model {
79 namespace: inst_ns, ..
80 },
81 DiscoveryQuery::NamespacedModels { namespace },
82 ) => inst_ns == namespace,
83 (
84 DiscoveryInstance::Model {
85 namespace: inst_ns,
86 component: inst_comp,
87 ..
88 },
89 DiscoveryQuery::ComponentModels {
90 namespace,
91 component,
92 },
93 ) => inst_ns == namespace && inst_comp == component,
94 (
95 DiscoveryInstance::Model {
96 namespace: inst_ns,
97 component: inst_comp,
98 endpoint: inst_ep,
99 ..
100 },
101 DiscoveryQuery::EndpointModels {
102 namespace,
103 component,
104 endpoint,
105 },
106 ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
107
108 (
110 DiscoveryInstance::Endpoint(_),
111 DiscoveryQuery::AllModels
112 | DiscoveryQuery::NamespacedModels { .. }
113 | DiscoveryQuery::ComponentModels { .. }
114 | DiscoveryQuery::EndpointModels { .. },
115 ) => false,
116 (
117 DiscoveryInstance::Model { .. },
118 DiscoveryQuery::AllEndpoints
119 | DiscoveryQuery::NamespacedEndpoints { .. }
120 | DiscoveryQuery::ComponentEndpoints { .. }
121 | DiscoveryQuery::Endpoint { .. },
122 ) => false,
123 }
124}
125
126#[async_trait]
127impl Discovery for MockDiscovery {
128 fn instance_id(&self) -> u64 {
129 self.instance_id
130 }
131
132 async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
133 let instance = spec.with_instance_id(self.instance_id);
134
135 self.registry
136 .instances
137 .lock()
138 .unwrap()
139 .push(instance.clone());
140
141 Ok(instance)
142 }
143
144 async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
145 let instance_id = instance.instance_id();
146
147 self.registry
148 .instances
149 .lock()
150 .unwrap()
151 .retain(|i| i.instance_id() != instance_id);
152
153 Ok(())
154 }
155
156 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
157 let instances = self.registry.instances.lock().unwrap();
158 Ok(instances
159 .iter()
160 .filter(|instance| matches_query(instance, &query))
161 .cloned()
162 .collect())
163 }
164
165 async fn list_and_watch(
166 &self,
167 query: DiscoveryQuery,
168 _cancel_token: Option<CancellationToken>,
169 ) -> Result<DiscoveryStream> {
170 use std::collections::HashSet;
171
172 let registry = self.registry.clone();
173
174 let stream = async_stream::stream! {
175 let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
176
177 loop {
178 let current: Vec<_> = {
179 let instances = registry.instances.lock().unwrap();
180 instances
181 .iter()
182 .filter(|instance| matches_query(instance, &query))
183 .cloned()
184 .collect()
185 };
186
187 let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
188
189 for instance in current {
191 let id = instance.id();
192 if known_instances.insert(id) {
193 yield Ok(DiscoveryEvent::Added(instance));
194 }
195 }
196
197 for id in known_instances.difference(¤t_ids).cloned().collect::<Vec<_>>() {
199 known_instances.remove(&id);
200 yield Ok(DiscoveryEvent::Removed(id));
201 }
202
203 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
204 }
205 };
206
207 Ok(Box::pin(stream))
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Two clients sharing one registry: a watcher opened before any registration must
    /// see an `Added` event per registration and a `Removed` event once an entry is
    /// dropped from the registry.
    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let registry = SharedMockRegistry::new();
        let client1 = MockDiscovery::new(Some(1), registry.clone());
        let client2 = MockDiscovery::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        // Open the watch while the registry is still empty.
        let mut stream = client1.list_and_watch(query.clone(), None).await.unwrap();

        // First registration, made through client 1.
        client1.register(spec.clone()).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-1");
        };
        assert_eq!(inst.instance_id, 1);

        // Second registration, made through client 2 against the same registry.
        client2.register(spec.clone()).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-2");
        };
        assert_eq!(inst.instance_id, 2);

        // Drop every entry owned by instance 1 directly from the registry.
        registry.instances.lock().unwrap().retain(|entry| {
            let owner = match entry {
                DiscoveryInstance::Endpoint(inst) => inst.instance_id,
                DiscoveryInstance::Model { instance_id, .. } => *instance_id,
            };
            owner != 1
        });

        let DiscoveryEvent::Removed(id) = stream.next().await.unwrap().unwrap() else {
            panic!("Expected Removed event for instance-1");
        };
        let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
        assert_eq!(endpoint_id.instance_id, 1);
    }
}