dynamo_runtime/discovery/
mock.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{
5    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
6};
7use anyhow::Result;
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex};
10use tokio_util::sync::CancellationToken;
11
12/// Shared in-memory registry for mock discovery
13#[derive(Clone, Default)]
14pub struct SharedMockRegistry {
15    instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
16}
17
18impl SharedMockRegistry {
19    pub fn new() -> Self {
20        Self::default()
21    }
22}
23
24/// Mock implementation of Discovery for testing
25/// We can potentially remove this once we have KVStoreDiscovery fully tested
26pub struct MockDiscovery {
27    instance_id: u64,
28    registry: SharedMockRegistry,
29}
30
31impl MockDiscovery {
32    pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
33        let instance_id = instance_id.unwrap_or_else(|| {
34            use std::sync::atomic::{AtomicU64, Ordering};
35            static COUNTER: AtomicU64 = AtomicU64::new(1);
36            COUNTER.fetch_add(1, Ordering::SeqCst)
37        });
38
39        Self {
40            instance_id,
41            registry,
42        }
43    }
44}
45
46/// Helper function to check if an instance matches a discovery query
47fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
48    match (instance, query) {
49        // Endpoint matching
50        (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
51        (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
52            &inst.namespace == namespace
53        }
54        (
55            DiscoveryInstance::Endpoint(inst),
56            DiscoveryQuery::ComponentEndpoints {
57                namespace,
58                component,
59            },
60        ) => &inst.namespace == namespace && &inst.component == component,
61        (
62            DiscoveryInstance::Endpoint(inst),
63            DiscoveryQuery::Endpoint {
64                namespace,
65                component,
66                endpoint,
67            },
68        ) => {
69            &inst.namespace == namespace
70                && &inst.component == component
71                && &inst.endpoint == endpoint
72        }
73
74        // Model matching
75        (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
76        (
77            DiscoveryInstance::Model {
78                namespace: inst_ns, ..
79            },
80            DiscoveryQuery::NamespacedModels { namespace },
81        ) => inst_ns == namespace,
82        (
83            DiscoveryInstance::Model {
84                namespace: inst_ns,
85                component: inst_comp,
86                ..
87            },
88            DiscoveryQuery::ComponentModels {
89                namespace,
90                component,
91            },
92        ) => inst_ns == namespace && inst_comp == component,
93        (
94            DiscoveryInstance::Model {
95                namespace: inst_ns,
96                component: inst_comp,
97                endpoint: inst_ep,
98                ..
99            },
100            DiscoveryQuery::EndpointModels {
101                namespace,
102                component,
103                endpoint,
104            },
105        ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
106
107        // Cross-type matches return false
108        (
109            DiscoveryInstance::Endpoint(_),
110            DiscoveryQuery::AllModels
111            | DiscoveryQuery::NamespacedModels { .. }
112            | DiscoveryQuery::ComponentModels { .. }
113            | DiscoveryQuery::EndpointModels { .. },
114        ) => false,
115        (
116            DiscoveryInstance::Model { .. },
117            DiscoveryQuery::AllEndpoints
118            | DiscoveryQuery::NamespacedEndpoints { .. }
119            | DiscoveryQuery::ComponentEndpoints { .. }
120            | DiscoveryQuery::Endpoint { .. },
121        ) => false,
122    }
123}
124
125#[async_trait]
126impl Discovery for MockDiscovery {
127    fn instance_id(&self) -> u64 {
128        self.instance_id
129    }
130
131    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
132        let instance = spec.with_instance_id(self.instance_id);
133
134        self.registry
135            .instances
136            .lock()
137            .unwrap()
138            .push(instance.clone());
139
140        Ok(instance)
141    }
142
143    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
144        let instances = self.registry.instances.lock().unwrap();
145        Ok(instances
146            .iter()
147            .filter(|instance| matches_query(instance, &query))
148            .cloned()
149            .collect())
150    }
151
152    async fn list_and_watch(
153        &self,
154        query: DiscoveryQuery,
155        _cancel_token: Option<CancellationToken>,
156    ) -> Result<DiscoveryStream> {
157        use std::collections::HashSet;
158
159        let registry = self.registry.clone();
160
161        let stream = async_stream::stream! {
162            let mut known_instances = HashSet::new();
163
164            loop {
165                let current: Vec<_> = {
166                    let instances = registry.instances.lock().unwrap();
167                    instances
168                        .iter()
169                        .filter(|instance| matches_query(instance, &query))
170                        .cloned()
171                        .collect()
172                };
173
174                let current_ids: HashSet<_> = current.iter().map(|i| {
175                    match i {
176                        DiscoveryInstance::Endpoint(inst) => inst.instance_id,
177                        DiscoveryInstance::Model { instance_id, .. } => *instance_id,
178                    }
179                }).collect();
180
181                // Emit Added events for new instances
182                for instance in current {
183                    let id = match &instance {
184                        DiscoveryInstance::Endpoint(inst) => inst.instance_id,
185                        DiscoveryInstance::Model { instance_id, .. } => *instance_id,
186                    };
187                    if known_instances.insert(id) {
188                        yield Ok(DiscoveryEvent::Added(instance));
189                    }
190                }
191
192                // Emit Removed events for instances that are gone
193                for id in known_instances.difference(&current_ids).cloned().collect::<Vec<_>>() {
194                    yield Ok(DiscoveryEvent::Removed(id));
195                    known_instances.remove(&id);
196                }
197
198                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
199            }
200        };
201
202        Ok(Box::pin(stream))
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::time::Duration;

    /// Upper bound on how long a test waits for the next watch event. With this
    /// bound, a regression in the watch loop fails the test instead of hanging it.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Awaits the next event from `stream`.
    ///
    /// Panics if the wait times out, the stream ends, or the stream yields an error.
    async fn next_event(stream: &mut DiscoveryStream) -> DiscoveryEvent {
        tokio::time::timeout(EVENT_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for discovery event")
            .expect("discovery stream ended unexpectedly")
            .expect("discovery stream yielded an error")
    }

    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let registry = SharedMockRegistry::new();
        let client1 = MockDiscovery::new(Some(1), registry.clone());
        let client2 = MockDiscovery::new(Some(2), registry.clone());

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
            transport: crate::component::TransportType::Nats("test-subject".to_string()),
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
            endpoint: "test-ep".to_string(),
        };

        // Start watching before anything is registered.
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // Add first instance
        client1.register(spec.clone()).await.unwrap();

        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 1);
            }
            _ => panic!("Expected Added event for instance-1"),
        }

        // Add second instance
        client2.register(spec.clone()).await.unwrap();

        match next_event(&mut stream).await {
            DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)) => {
                assert_eq!(inst.instance_id, 2);
            }
            _ => panic!("Expected Added event for instance-2"),
        }

        // Remove first instance directly from the shared registry.
        registry.instances.lock().unwrap().retain(|i| match i {
            DiscoveryInstance::Endpoint(inst) => inst.instance_id != 1,
            DiscoveryInstance::Model { instance_id, .. } => *instance_id != 1,
        });

        match next_event(&mut stream).await {
            DiscoveryEvent::Removed(instance_id) => {
                assert_eq!(instance_id, 1);
            }
            _ => panic!("Expected Removed event for instance-1"),
        }
    }
}
269}