Skip to main content

dynamo_runtime/discovery/
mock.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{
5    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
6    DiscoverySpec, DiscoveryStream,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tokio_util::sync::CancellationToken;
12
13/// Shared in-memory registry for mock discovery
14#[derive(Clone, Default)]
15pub struct SharedMockRegistry {
16    instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
17}
18
19impl SharedMockRegistry {
20    pub fn new() -> Self {
21        Self::default()
22    }
23}
24
25/// Mock implementation of Discovery for testing
26/// We can potentially remove this once we have KVStoreDiscovery fully tested
27pub struct MockDiscovery {
28    instance_id: u64,
29    registry: SharedMockRegistry,
30}
31
32impl MockDiscovery {
33    pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
34        let instance_id = instance_id.unwrap_or_else(|| {
35            use std::sync::atomic::{AtomicU64, Ordering};
36            static COUNTER: AtomicU64 = AtomicU64::new(1);
37            COUNTER.fetch_add(1, Ordering::SeqCst)
38        });
39
40        Self {
41            instance_id,
42            registry,
43        }
44    }
45}
46
47/// Helper function to check if an instance matches a discovery query
48fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
49    match (instance, query) {
50        // Endpoint matching
51        (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
52        (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53            &inst.namespace == namespace
54        }
55        (
56            DiscoveryInstance::Endpoint(inst),
57            DiscoveryQuery::ComponentEndpoints {
58                namespace,
59                component,
60            },
61        ) => &inst.namespace == namespace && &inst.component == component,
62        (
63            DiscoveryInstance::Endpoint(inst),
64            DiscoveryQuery::Endpoint {
65                namespace,
66                component,
67                endpoint,
68            },
69        ) => {
70            &inst.namespace == namespace
71                && &inst.component == component
72                && &inst.endpoint == endpoint
73        }
74
75        // Model matching
76        (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77        (
78            DiscoveryInstance::Model {
79                namespace: inst_ns, ..
80            },
81            DiscoveryQuery::NamespacedModels { namespace },
82        ) => inst_ns == namespace,
83        (
84            DiscoveryInstance::Model {
85                namespace: inst_ns,
86                component: inst_comp,
87                ..
88            },
89            DiscoveryQuery::ComponentModels {
90                namespace,
91                component,
92            },
93        ) => inst_ns == namespace && inst_comp == component,
94        (
95            DiscoveryInstance::Model {
96                namespace: inst_ns,
97                component: inst_comp,
98                endpoint: inst_ep,
99                ..
100            },
101            DiscoveryQuery::EndpointModels {
102                namespace,
103                component,
104                endpoint,
105            },
106        ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
107
108        // EventChannel matching - unified query
109        (
110            DiscoveryInstance::EventChannel {
111                namespace: inst_ns,
112                component: inst_comp,
113                topic: inst_topic,
114                ..
115            },
116            DiscoveryQuery::EventChannels(query),
117        ) => {
118            query.namespace.as_ref().is_none_or(|ns| ns == inst_ns)
119                && query.component.as_ref().is_none_or(|c| c == inst_comp)
120                && query.topic.as_ref().is_none_or(|t| t == inst_topic)
121        }
122
123        // Cross-type matches return false
124        (
125            DiscoveryInstance::Endpoint(_),
126            DiscoveryQuery::AllModels
127            | DiscoveryQuery::NamespacedModels { .. }
128            | DiscoveryQuery::ComponentModels { .. }
129            | DiscoveryQuery::EndpointModels { .. }
130            | DiscoveryQuery::EventChannels(_),
131        ) => false,
132        (
133            DiscoveryInstance::Model { .. },
134            DiscoveryQuery::AllEndpoints
135            | DiscoveryQuery::NamespacedEndpoints { .. }
136            | DiscoveryQuery::ComponentEndpoints { .. }
137            | DiscoveryQuery::Endpoint { .. }
138            | DiscoveryQuery::EventChannels(_),
139        ) => false,
140        (
141            DiscoveryInstance::EventChannel { .. },
142            DiscoveryQuery::AllEndpoints
143            | DiscoveryQuery::NamespacedEndpoints { .. }
144            | DiscoveryQuery::ComponentEndpoints { .. }
145            | DiscoveryQuery::Endpoint { .. }
146            | DiscoveryQuery::AllModels
147            | DiscoveryQuery::NamespacedModels { .. }
148            | DiscoveryQuery::ComponentModels { .. }
149            | DiscoveryQuery::EndpointModels { .. },
150        ) => false,
151    }
152}
153
154#[async_trait]
155impl Discovery for MockDiscovery {
156    fn instance_id(&self) -> u64 {
157        self.instance_id
158    }
159
160    async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
161        let instance = spec.with_instance_id(self.instance_id);
162
163        self.registry
164            .instances
165            .lock()
166            .unwrap()
167            .push(instance.clone());
168
169        Ok(instance)
170    }
171
172    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
173        let target_id = instance.id();
174
175        self.registry
176            .instances
177            .lock()
178            .unwrap()
179            .retain(|i| i.id() != target_id);
180
181        Ok(())
182    }
183
184    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
185        let instances = self.registry.instances.lock().unwrap();
186        Ok(instances
187            .iter()
188            .filter(|instance| matches_query(instance, &query))
189            .cloned()
190            .collect())
191    }
192
193    async fn list_and_watch(
194        &self,
195        query: DiscoveryQuery,
196        _cancel_token: Option<CancellationToken>,
197    ) -> Result<DiscoveryStream> {
198        use std::collections::HashSet;
199
200        let registry = self.registry.clone();
201
202        let stream = async_stream::stream! {
203            let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
204
205            loop {
206                let current: Vec<_> = {
207                    let instances = registry.instances.lock().unwrap();
208                    instances
209                        .iter()
210                        .filter(|instance| matches_query(instance, &query))
211                        .cloned()
212                        .collect()
213                };
214
215                let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
216
217                // Emit Added events for new instances
218                for instance in current {
219                    let id = instance.id();
220                    if known_instances.insert(id) {
221                        yield Ok(DiscoveryEvent::Added(instance));
222                    }
223                }
224
225                // Emit Removed events for instances that are gone
226                for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
227                    known_instances.remove(&id);
228                    yield Ok(DiscoveryEvent::Removed(id));
229                }
230
231                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
232            }
233        };
234
235        Ok(Box::pin(stream))
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Builds two mock handles (instance ids 1 and 2) that share one registry.
    fn paired_clients() -> (MockDiscovery, MockDiscovery) {
        let registry = SharedMockRegistry::new();
        let first = MockDiscovery::new(Some(1), registry.clone());
        let second = MockDiscovery::new(Some(2), registry);
        (first, second)
    }

    /// Query selecting every model registered on `namespace/component/endpoint`.
    fn endpoint_models(namespace: &str, component: &str, endpoint: &str) -> DiscoveryQuery {
        DiscoveryQuery::EndpointModels {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
        }
    }

    /// Model spec whose card records only a display name.
    fn model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            card_json: serde_json::json!({
                "display_name": model_name,
            }),
            model_suffix: None,
        }
    }

    /// Base-model spec whose card also records the source it was loaded from.
    fn base_model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
        source_path: &str,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            card_json: serde_json::json!({
                "display_name": model_name,
                "source_path": source_path,
            }),
            model_suffix: None,
        }
    }

    /// LoRA adapter spec: card names the base source and the adapter, and the
    /// adapter name is used as the model suffix.
    fn lora_model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
        source_path: &str,
        lora_name: &str,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            card_json: serde_json::json!({
                "display_name": model_name,
                "source_path": source_path,
                "lora": {
                    "name": lora_name,
                },
            }),
            model_suffix: Some(lora_name.to_string()),
        }
    }

    /// Pulls the next event off `stream`, panicking if it ended or yielded an error.
    async fn next_event(stream: &mut DiscoveryStream) -> DiscoveryEvent {
        stream.next().await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let (client1, client2) = paired_clients();

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
            device_type: None,
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        // Watch before anything is registered.
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // First registration surfaces as Added for instance 1.
        let instance1 = client1.register(spec.clone()).await.unwrap();
        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(first)) =
            next_event(&mut stream).await
        else {
            panic!("Expected Added event for instance-1");
        };
        assert_eq!(first.instance_id, 1);

        // Second registration surfaces as Added for instance 2.
        client2.register(spec).await.unwrap();
        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(second)) =
            next_event(&mut stream).await
        else {
            panic!("Expected Added event for instance-2");
        };
        assert_eq!(second.instance_id, 2);

        // Unregistering instance 1 surfaces as Removed.
        client1.unregister(instance1).await.unwrap();
        let DiscoveryEvent::Removed(removed) = next_event(&mut stream).await else {
            panic!("Expected Removed event for instance-1");
        };
        let endpoint_id = removed
            .extract_endpoint_id()
            .expect("Expected endpoint removal");
        assert_eq!(endpoint_id.instance_id, 1);
    }

    #[tokio::test]
    async fn register_allows_same_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = paired_clients();
        let spec = model_spec("ns", "comp", "generate", "model-a");

        discovery1.register(spec.clone()).await.unwrap();
        discovery2.register(spec).await.unwrap();

        let listed = discovery1
            .list(endpoint_models("ns", "comp", "generate"))
            .await
            .unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn register_rejects_different_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = paired_clients();

        discovery1
            .register(model_spec("ns", "comp", "generate", "model-a"))
            .await
            .unwrap();

        let err = discovery2
            .register(model_spec("ns", "comp", "generate", "model-b"))
            .await
            .unwrap_err();
        assert!(err.to_string().contains(
            "Cannot register model 'model-b' on endpoint 'ns/comp/generate': a different model 'model-a' is already registered there"
        ));

        // The rejected registration must not have been stored.
        let listed = discovery1
            .list(endpoint_models("ns", "comp", "generate"))
            .await
            .unwrap();
        assert_eq!(listed.len(), 1);
    }

    #[tokio::test]
    async fn register_allows_different_model_names_on_different_endpoints() {
        let (discovery1, discovery2) = paired_clients();

        discovery1
            .register(model_spec("ns", "comp", "generate-a", "model-a"))
            .await
            .unwrap();
        discovery2
            .register(model_spec("ns", "comp", "generate-b", "model-b"))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn register_allows_lora_adapter_on_same_endpoint() {
        let (discovery1, discovery2) = paired_clients();

        discovery1
            .register(base_model_spec(
                "ns",
                "comp",
                "generate",
                "base-model",
                "base-repo",
            ))
            .await
            .unwrap();

        discovery2
            .register(lora_model_spec(
                "ns",
                "comp",
                "generate",
                "adapter-a",
                "base-repo",
                "adapter-a",
            ))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn register_rejects_lora_adapter_for_different_base_model() {
        let (discovery1, discovery2) = paired_clients();

        discovery1
            .register(base_model_spec(
                "ns",
                "comp",
                "generate",
                "base-model",
                "base-repo",
            ))
            .await
            .unwrap();

        let err = discovery2
            .register(lora_model_spec(
                "ns",
                "comp",
                "generate",
                "adapter-a",
                "other-base-repo",
                "adapter-a",
            ))
            .await
            .unwrap_err();

        assert!(err.to_string().contains(
            "Cannot register model 'adapter-a' on endpoint 'ns/comp/generate': a different model 'base-model' is already registered there"
        ));
    }
}