dynamo_runtime/discovery/
mock.rs1use super::{
5 Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
6 DiscoverySpec, DiscoveryStream,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tokio_util::sync::CancellationToken;
12
13#[derive(Clone, Default)]
15pub struct SharedMockRegistry {
16 instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
17}
18
19impl SharedMockRegistry {
20 pub fn new() -> Self {
21 Self::default()
22 }
23}
24
25pub struct MockDiscovery {
28 instance_id: u64,
29 registry: SharedMockRegistry,
30}
31
32impl MockDiscovery {
33 pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
34 let instance_id = instance_id.unwrap_or_else(|| {
35 use std::sync::atomic::{AtomicU64, Ordering};
36 static COUNTER: AtomicU64 = AtomicU64::new(1);
37 COUNTER.fetch_add(1, Ordering::SeqCst)
38 });
39
40 Self {
41 instance_id,
42 registry,
43 }
44 }
45}
46
47fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
49 match (instance, query) {
50 (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
52 (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53 &inst.namespace == namespace
54 }
55 (
56 DiscoveryInstance::Endpoint(inst),
57 DiscoveryQuery::ComponentEndpoints {
58 namespace,
59 component,
60 },
61 ) => &inst.namespace == namespace && &inst.component == component,
62 (
63 DiscoveryInstance::Endpoint(inst),
64 DiscoveryQuery::Endpoint {
65 namespace,
66 component,
67 endpoint,
68 },
69 ) => {
70 &inst.namespace == namespace
71 && &inst.component == component
72 && &inst.endpoint == endpoint
73 }
74
75 (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77 (
78 DiscoveryInstance::Model {
79 namespace: inst_ns, ..
80 },
81 DiscoveryQuery::NamespacedModels { namespace },
82 ) => inst_ns == namespace,
83 (
84 DiscoveryInstance::Model {
85 namespace: inst_ns,
86 component: inst_comp,
87 ..
88 },
89 DiscoveryQuery::ComponentModels {
90 namespace,
91 component,
92 },
93 ) => inst_ns == namespace && inst_comp == component,
94 (
95 DiscoveryInstance::Model {
96 namespace: inst_ns,
97 component: inst_comp,
98 endpoint: inst_ep,
99 ..
100 },
101 DiscoveryQuery::EndpointModels {
102 namespace,
103 component,
104 endpoint,
105 },
106 ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
107
108 (
110 DiscoveryInstance::EventChannel {
111 namespace: inst_ns,
112 component: inst_comp,
113 topic: inst_topic,
114 ..
115 },
116 DiscoveryQuery::EventChannels(query),
117 ) => {
118 query.namespace.as_ref().is_none_or(|ns| ns == inst_ns)
119 && query.component.as_ref().is_none_or(|c| c == inst_comp)
120 && query.topic.as_ref().is_none_or(|t| t == inst_topic)
121 }
122
123 (
125 DiscoveryInstance::Endpoint(_),
126 DiscoveryQuery::AllModels
127 | DiscoveryQuery::NamespacedModels { .. }
128 | DiscoveryQuery::ComponentModels { .. }
129 | DiscoveryQuery::EndpointModels { .. }
130 | DiscoveryQuery::EventChannels(_),
131 ) => false,
132 (
133 DiscoveryInstance::Model { .. },
134 DiscoveryQuery::AllEndpoints
135 | DiscoveryQuery::NamespacedEndpoints { .. }
136 | DiscoveryQuery::ComponentEndpoints { .. }
137 | DiscoveryQuery::Endpoint { .. }
138 | DiscoveryQuery::EventChannels(_),
139 ) => false,
140 (
141 DiscoveryInstance::EventChannel { .. },
142 DiscoveryQuery::AllEndpoints
143 | DiscoveryQuery::NamespacedEndpoints { .. }
144 | DiscoveryQuery::ComponentEndpoints { .. }
145 | DiscoveryQuery::Endpoint { .. }
146 | DiscoveryQuery::AllModels
147 | DiscoveryQuery::NamespacedModels { .. }
148 | DiscoveryQuery::ComponentModels { .. }
149 | DiscoveryQuery::EndpointModels { .. },
150 ) => false,
151 }
152}
153
154#[async_trait]
155impl Discovery for MockDiscovery {
156 fn instance_id(&self) -> u64 {
157 self.instance_id
158 }
159
160 async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
161 let instance = spec.with_instance_id(self.instance_id);
162
163 self.registry
164 .instances
165 .lock()
166 .unwrap()
167 .push(instance.clone());
168
169 Ok(instance)
170 }
171
172 async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
173 let instance_id = instance.instance_id();
174
175 self.registry
176 .instances
177 .lock()
178 .unwrap()
179 .retain(|i| i.instance_id() != instance_id);
180
181 Ok(())
182 }
183
184 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
185 let instances = self.registry.instances.lock().unwrap();
186 Ok(instances
187 .iter()
188 .filter(|instance| matches_query(instance, &query))
189 .cloned()
190 .collect())
191 }
192
193 async fn list_and_watch(
194 &self,
195 query: DiscoveryQuery,
196 _cancel_token: Option<CancellationToken>,
197 ) -> Result<DiscoveryStream> {
198 use std::collections::HashSet;
199
200 let registry = self.registry.clone();
201
202 let stream = async_stream::stream! {
203 let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
204
205 loop {
206 let current: Vec<_> = {
207 let instances = registry.instances.lock().unwrap();
208 instances
209 .iter()
210 .filter(|instance| matches_query(instance, &query))
211 .cloned()
212 .collect()
213 };
214
215 let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
216
217 for instance in current {
219 let id = instance.id();
220 if known_instances.insert(id) {
221 yield Ok(DiscoveryEvent::Added(instance));
222 }
223 }
224
225 for id in known_instances.difference(¤t_ids).cloned().collect::<Vec<_>>() {
227 known_instances.remove(&id);
228 yield Ok(DiscoveryEvent::Removed(id));
229 }
230
231 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
232 }
233 };
234
235 Ok(Box::pin(stream))
236 }
237}
238
239#[cfg(test)]
240mod tests {
241 use super::*;
242 use futures::StreamExt;
243
244 #[tokio::test]
245 async fn test_mock_discovery_add_and_remove() {
246 let registry = SharedMockRegistry::new();
247 let client1 = MockDiscovery::new(Some(1), registry.clone());
248 let client2 = MockDiscovery::new(Some(2), registry.clone());
249
250 let spec = DiscoverySpec::Endpoint {
251 namespace: "test-ns".to_string(),
252 component: "test-comp".to_string(),
253 endpoint: "test-ep".to_string(),
254 transport: crate::component::TransportType::Nats("test-subject".to_string()),
255 };
256
257 let query = DiscoveryQuery::Endpoint {
258 namespace: "test-ns".to_string(),
259 component: "test-comp".to_string(),
260 endpoint: "test-ep".to_string(),
261 };
262
263 let mut stream = client1.list_and_watch(query.clone(), None).await.unwrap();
265
266 client1.register(spec.clone()).await.unwrap();
268
269 let event = stream.next().await.unwrap().unwrap();
270 match event {
271 DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
272 assert_eq!(inst.instance_id, 1);
273 }
274 _ => panic!("Expected Added event for instance-1"),
275 }
276
277 client2.register(spec.clone()).await.unwrap();
279
280 let event = stream.next().await.unwrap().unwrap();
281 match event {
282 DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
283 assert_eq!(inst.instance_id, 2);
284 }
285 _ => panic!("Expected Added event for instance-2"),
286 }
287
288 registry.instances.lock().unwrap().retain(|i| match i {
290 DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
291 DiscoveryInstance::Model { instance_id, .. } => *instance_id != 1,
292 DiscoveryInstance::EventChannel { instance_id, .. } => *instance_id != 1,
293 });
294
295 let event = stream.next().await.unwrap().unwrap();
296 match event {
297 DiscoveryEvent::Removed(id) => {
298 let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
299 assert_eq!(endpoint_id.instance_id, 1);
300 }
301 _ => panic!("Expected Removed event for instance-1"),
302 }
303 }
304}