// dynamo_runtime/discovery/mock.rs

use super::{
    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio_util::sync::CancellationToken;

12#[derive(Clone, Default)]
14pub struct SharedMockRegistry {
15 instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
16}
17
18impl SharedMockRegistry {
19 pub fn new() -> Self {
20 Self::default()
21 }
22}
23
24pub struct MockDiscovery {
27 instance_id: u64,
28 registry: SharedMockRegistry,
29}
30
31impl MockDiscovery {
32 pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
33 let instance_id = instance_id.unwrap_or_else(|| {
34 use std::sync::atomic::{AtomicU64, Ordering};
35 static COUNTER: AtomicU64 = AtomicU64::new(1);
36 COUNTER.fetch_add(1, Ordering::SeqCst)
37 });
38
39 Self {
40 instance_id,
41 registry,
42 }
43 }
44}
45
46fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
48 match (instance, query) {
49 (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
51 (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
52 &inst.namespace == namespace
53 }
54 (
55 DiscoveryInstance::Endpoint(inst),
56 DiscoveryQuery::ComponentEndpoints {
57 namespace,
58 component,
59 },
60 ) => &inst.namespace == namespace && &inst.component == component,
61 (
62 DiscoveryInstance::Endpoint(inst),
63 DiscoveryQuery::Endpoint {
64 namespace,
65 component,
66 endpoint,
67 },
68 ) => {
69 &inst.namespace == namespace
70 && &inst.component == component
71 && &inst.endpoint == endpoint
72 }
73
74 (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
76 (
77 DiscoveryInstance::Model {
78 namespace: inst_ns, ..
79 },
80 DiscoveryQuery::NamespacedModels { namespace },
81 ) => inst_ns == namespace,
82 (
83 DiscoveryInstance::Model {
84 namespace: inst_ns,
85 component: inst_comp,
86 ..
87 },
88 DiscoveryQuery::ComponentModels {
89 namespace,
90 component,
91 },
92 ) => inst_ns == namespace && inst_comp == component,
93 (
94 DiscoveryInstance::Model {
95 namespace: inst_ns,
96 component: inst_comp,
97 endpoint: inst_ep,
98 ..
99 },
100 DiscoveryQuery::EndpointModels {
101 namespace,
102 component,
103 endpoint,
104 },
105 ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
106
107 (
109 DiscoveryInstance::Endpoint(_),
110 DiscoveryQuery::AllModels
111 | DiscoveryQuery::NamespacedModels { .. }
112 | DiscoveryQuery::ComponentModels { .. }
113 | DiscoveryQuery::EndpointModels { .. },
114 ) => false,
115 (
116 DiscoveryInstance::Model { .. },
117 DiscoveryQuery::AllEndpoints
118 | DiscoveryQuery::NamespacedEndpoints { .. }
119 | DiscoveryQuery::ComponentEndpoints { .. }
120 | DiscoveryQuery::Endpoint { .. },
121 ) => false,
122 }
123}
124
125#[async_trait]
126impl Discovery for MockDiscovery {
127 fn instance_id(&self) -> u64 {
128 self.instance_id
129 }
130
131 async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
132 let instance = spec.with_instance_id(self.instance_id);
133
134 self.registry
135 .instances
136 .lock()
137 .unwrap()
138 .push(instance.clone());
139
140 Ok(instance)
141 }
142
143 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
144 let instances = self.registry.instances.lock().unwrap();
145 Ok(instances
146 .iter()
147 .filter(|instance| matches_query(instance, &query))
148 .cloned()
149 .collect())
150 }
151
152 async fn list_and_watch(
153 &self,
154 query: DiscoveryQuery,
155 _cancel_token: Option<CancellationToken>,
156 ) -> Result<DiscoveryStream> {
157 use std::collections::HashSet;
158
159 let registry = self.registry.clone();
160
161 let stream = async_stream::stream! {
162 let mut known_instances = HashSet::new();
163
164 loop {
165 let current: Vec<_> = {
166 let instances = registry.instances.lock().unwrap();
167 instances
168 .iter()
169 .filter(|instance| matches_query(instance, &query))
170 .cloned()
171 .collect()
172 };
173
174 let current_ids: HashSet<_> = current.iter().map(|i| {
175 match i {
176 DiscoveryInstance::Endpoint(inst) => inst.instance_id,
177 DiscoveryInstance::Model { instance_id, .. } => *instance_id,
178 }
179 }).collect();
180
181 for instance in current {
183 let id = match &instance {
184 DiscoveryInstance::Endpoint(inst) => inst.instance_id,
185 DiscoveryInstance::Model { instance_id, .. } => *instance_id,
186 };
187 if known_instances.insert(id) {
188 yield Ok(DiscoveryEvent::Added(instance));
189 }
190 }
191
192 for id in known_instances.difference(¤t_ids).cloned().collect::<Vec<_>>() {
194 yield Ok(DiscoveryEvent::Removed(id));
195 known_instances.remove(&id);
196 }
197
198 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
199 }
200 };
201
202 Ok(Box::pin(stream))
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;

    /// Upper bound on how long a test waits for the next watch event, so a regression
    /// fails the test instead of hanging it forever.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Awaits the next event from `stream`, panicking on timeout, end of stream, or error.
    async fn next_event(stream: &mut DiscoveryStream) -> DiscoveryEvent {
        tokio::time::timeout(EVENT_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for a discovery event")
            .expect("discovery stream ended unexpectedly")
            .expect("discovery stream yielded an error")
    }

    /// Two clients register the same endpoint spec through one shared registry. A watcher
    /// must see both additions and then the removal of the first instance.
    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let registry = SharedMockRegistry::new();
        let client1 = MockDiscovery::new(Some(1), registry.clone());
        let client2 = MockDiscovery::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        let mut stream = client1.list_and_watch(query.clone(), None).await.unwrap();

        client1.register(spec.clone()).await.unwrap();

        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 1);
            }
            _ => panic!("Expected Added event for instance-1"),
        }

        client2.register(spec.clone()).await.unwrap();

        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 2);
            }
            _ => panic!("Expected Added event for instance-2"),
        }

        // Simulate instance-1 going away by deleting it from the shared registry directly.
        registry.instances.lock().unwrap().retain(|i| match i {
            DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
            DiscoveryInstance::Model { instance_id, .. } => *instance_id != 1,
        });

        match next_event(&mut stream).await {
            DiscoveryEvent::Removed(instance_id) => {
                assert_eq!(instance_id, 1);
            }
            _ => panic!("Expected Removed event for instance-1"),
        }
    }
}