dynamo_runtime/discovery/
mock.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{
5    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
6    DiscoverySpec, DiscoveryStream,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tokio_util::sync::CancellationToken;
12
13/// Shared in-memory registry for mock discovery
14#[derive(Clone, Default)]
15pub struct SharedMockRegistry {
16    instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
17}
18
19impl SharedMockRegistry {
20    pub fn new() -> Self {
21        Self::default()
22    }
23}
24
25/// Mock implementation of Discovery for testing
26/// We can potentially remove this once we have KVStoreDiscovery fully tested
27pub struct MockDiscovery {
28    instance_id: u64,
29    registry: SharedMockRegistry,
30}
31
32impl MockDiscovery {
33    pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
34        let instance_id = instance_id.unwrap_or_else(|| {
35            use std::sync::atomic::{AtomicU64, Ordering};
36            static COUNTER: AtomicU64 = AtomicU64::new(1);
37            COUNTER.fetch_add(1, Ordering::SeqCst)
38        });
39
40        Self {
41            instance_id,
42            registry,
43        }
44    }
45}
46
47/// Helper function to check if an instance matches a discovery query
48fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
49    match (instance, query) {
50        // Endpoint matching
51        (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
52        (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53            &inst.namespace == namespace
54        }
55        (
56            DiscoveryInstance::Endpoint(inst),
57            DiscoveryQuery::ComponentEndpoints {
58                namespace,
59                component,
60            },
61        ) => &inst.namespace == namespace && &inst.component == component,
62        (
63            DiscoveryInstance::Endpoint(inst),
64            DiscoveryQuery::Endpoint {
65                namespace,
66                component,
67                endpoint,
68            },
69        ) => {
70            &inst.namespace == namespace
71                && &inst.component == component
72                && &inst.endpoint == endpoint
73        }
74
75        // Model matching
76        (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77        (
78            DiscoveryInstance::Model {
79                namespace: inst_ns, ..
80            },
81            DiscoveryQuery::NamespacedModels { namespace },
82        ) => inst_ns == namespace,
83        (
84            DiscoveryInstance::Model {
85                namespace: inst_ns,
86                component: inst_comp,
87                ..
88            },
89            DiscoveryQuery::ComponentModels {
90                namespace,
91                component,
92            },
93        ) => inst_ns == namespace && inst_comp == component,
94        (
95            DiscoveryInstance::Model {
96                namespace: inst_ns,
97                component: inst_comp,
98                endpoint: inst_ep,
99                ..
100            },
101            DiscoveryQuery::EndpointModels {
102                namespace,
103                component,
104                endpoint,
105            },
106        ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
107
108        // Cross-type matches return false
109        (
110            DiscoveryInstance::Endpoint(_),
111            DiscoveryQuery::AllModels
112            | DiscoveryQuery::NamespacedModels { .. }
113            | DiscoveryQuery::ComponentModels { .. }
114            | DiscoveryQuery::EndpointModels { .. },
115        ) => false,
116        (
117            DiscoveryInstance::Model { .. },
118            DiscoveryQuery::AllEndpoints
119            | DiscoveryQuery::NamespacedEndpoints { .. }
120            | DiscoveryQuery::ComponentEndpoints { .. }
121            | DiscoveryQuery::Endpoint { .. },
122        ) => false,
123    }
124}
125
126#[async_trait]
127impl Discovery for MockDiscovery {
128    fn instance_id(&self) -> u64 {
129        self.instance_id
130    }
131
132    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
133        let instance = spec.with_instance_id(self.instance_id);
134
135        self.registry
136            .instances
137            .lock()
138            .unwrap()
139            .push(instance.clone());
140
141        Ok(instance)
142    }
143
144    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
145        let instance_id = instance.instance_id();
146
147        self.registry
148            .instances
149            .lock()
150            .unwrap()
151            .retain(|i| i.instance_id() != instance_id);
152
153        Ok(())
154    }
155
156    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
157        let instances = self.registry.instances.lock().unwrap();
158        Ok(instances
159            .iter()
160            .filter(|instance| matches_query(instance, &query))
161            .cloned()
162            .collect())
163    }
164
165    async fn list_and_watch(
166        &self,
167        query: DiscoveryQuery,
168        _cancel_token: Option<CancellationToken>,
169    ) -> Result<DiscoveryStream> {
170        use std::collections::HashSet;
171
172        let registry = self.registry.clone();
173
174        let stream = async_stream::stream! {
175            let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
176
177            loop {
178                let current: Vec<_> = {
179                    let instances = registry.instances.lock().unwrap();
180                    instances
181                        .iter()
182                        .filter(|instance| matches_query(instance, &query))
183                        .cloned()
184                        .collect()
185                };
186
187                let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
188
189                // Emit Added events for new instances
190                for instance in current {
191                    let id = instance.id();
192                    if known_instances.insert(id) {
193                        yield Ok(DiscoveryEvent::Added(instance));
194                    }
195                }
196
197                // Emit Removed events for instances that are gone
198                for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
199                    known_instances.remove(&id);
200                    yield Ok(DiscoveryEvent::Removed(id));
201                }
202
203                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
204            }
205        };
206
207        Ok(Box::pin(stream))
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Two clients share one registry; a watcher must observe both
    /// registrations as `Added` and a subsequent removal as `Removed`.
    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let namespace = "test-ns".to_string();
        let component = "test-comp".to_string();
        let endpoint = "test-ep".to_string();

        let registry = SharedMockRegistry::new();
        let client1 = MockDiscovery::new(Some(1), registry.clone());
        let client2 = MockDiscovery::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: namespace.clone(),
            component: component.clone(),
            endpoint: endpoint.clone(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace,
            component,
            endpoint,
        };

        // Open the watch before anything is registered.
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // First registration shows up as an Added event for instance 1.
        client1.register(spec.clone()).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-1");
        };
        assert_eq!(inst.instance_id, 1);

        // Second registration shows up as an Added event for instance 2.
        client2.register(spec).await.unwrap();

        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) =
            stream.next().await.unwrap().unwrap()
        else {
            panic!("Expected Added event for instance-2");
        };
        assert_eq!(inst.instance_id, 2);

        // Drop everything owned by instance 1 straight from the registry.
        registry.instances.lock().unwrap().retain(|entry| {
            let owner = match entry {
                DiscoveryInstance::Endpoint(inst) => inst.instance_id,
                DiscoveryInstance::Model { instance_id, .. } => *instance_id,
            };
            owner != 1
        });

        let DiscoveryEvent::Removed(id) = stream.next().await.unwrap().unwrap() else {
            panic!("Expected Removed event for instance-1");
        };
        let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
        assert_eq!(endpoint_id.instance_id, 1);
    }
}