Skip to main content

dynamo_runtime/discovery/
metadata.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::Result;
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
9
10/// Metadata stored on each pod and exposed via HTTP endpoint
11/// This struct holds all discovery registrations for this pod instance
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct DiscoveryMetadata {
14    /// Registered endpoint instances (key: path string from EndpointInstanceId::to_path())
15    endpoints: HashMap<String, DiscoveryInstance>,
16    /// Registered model card instances (key: path string from ModelCardInstanceId::to_path())
17    model_cards: HashMap<String, DiscoveryInstance>,
18    /// Registered event channel instances (key: path string from EventChannelInstanceId::to_path())
19    event_channels: HashMap<String, DiscoveryInstance>,
20}
21
22impl DiscoveryMetadata {
23    /// Create a new empty metadata store
24    pub fn new() -> Self {
25        Self {
26            endpoints: HashMap::new(),
27            model_cards: HashMap::new(),
28            event_channels: HashMap::new(),
29        }
30    }
31
32    /// Register an endpoint instance
33    pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
34        match instance.id() {
35            DiscoveryInstanceId::Endpoint(key) => {
36                self.endpoints.insert(key.to_path(), instance);
37                Ok(())
38            }
39            DiscoveryInstanceId::Model(_) => {
40                anyhow::bail!("Cannot register non-endpoint instance as endpoint")
41            }
42            DiscoveryInstanceId::EventChannel(_) => {
43                anyhow::bail!("Cannot register EventChannel instance as endpoint")
44            }
45        }
46    }
47
48    /// Register a model card instance
49    pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
50        match instance.id() {
51            DiscoveryInstanceId::Model(key) => {
52                self.model_cards.insert(key.to_path(), instance);
53                Ok(())
54            }
55            DiscoveryInstanceId::Endpoint(_) => {
56                anyhow::bail!("Cannot register non-model-card instance as model card")
57            }
58            DiscoveryInstanceId::EventChannel(_) => {
59                anyhow::bail!("Cannot register EventChannel instance as model card")
60            }
61        }
62    }
63
64    /// Unregister an endpoint instance
65    pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
66        match instance.id() {
67            DiscoveryInstanceId::Endpoint(key) => {
68                self.endpoints.remove(&key.to_path());
69                Ok(())
70            }
71            DiscoveryInstanceId::Model(_) => {
72                anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
73            }
74            DiscoveryInstanceId::EventChannel(_) => {
75                anyhow::bail!("Cannot unregister EventChannel instance as endpoint")
76            }
77        }
78    }
79
80    /// Unregister a model card instance
81    pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
82        match instance.id() {
83            DiscoveryInstanceId::Model(key) => {
84                self.model_cards.remove(&key.to_path());
85                Ok(())
86            }
87            DiscoveryInstanceId::Endpoint(_) => {
88                anyhow::bail!("Cannot unregister non-model-card instance as model card")
89            }
90            DiscoveryInstanceId::EventChannel(_) => {
91                anyhow::bail!("Cannot unregister EventChannel instance as model card")
92            }
93        }
94    }
95
96    /// Register an event channel instance
97    pub fn register_event_channel(&mut self, instance: DiscoveryInstance) -> Result<()> {
98        match instance.id() {
99            DiscoveryInstanceId::EventChannel(key) => {
100                self.event_channels.insert(key.to_path(), instance);
101                Ok(())
102            }
103            DiscoveryInstanceId::Endpoint(_) => {
104                anyhow::bail!("Cannot register Endpoint instance as event channel")
105            }
106            DiscoveryInstanceId::Model(_) => {
107                anyhow::bail!("Cannot register Model instance as event channel")
108            }
109        }
110    }
111
112    /// Unregister an event channel instance
113    pub fn unregister_event_channel(&mut self, instance: &DiscoveryInstance) -> Result<()> {
114        match instance.id() {
115            DiscoveryInstanceId::EventChannel(key) => {
116                self.event_channels.remove(&key.to_path());
117                Ok(())
118            }
119            DiscoveryInstanceId::Endpoint(_) => {
120                anyhow::bail!("Cannot unregister Endpoint instance as event channel")
121            }
122            DiscoveryInstanceId::Model(_) => {
123                anyhow::bail!("Cannot unregister Model instance as event channel")
124            }
125        }
126    }
127
128    /// Get all registered endpoints
129    pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
130        self.endpoints.values().cloned().collect()
131    }
132
133    /// Get all registered model cards
134    pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
135        self.model_cards.values().cloned().collect()
136    }
137
138    /// Get all registered event channels
139    pub fn get_all_event_channels(&self) -> Vec<DiscoveryInstance> {
140        self.event_channels.values().cloned().collect()
141    }
142
143    /// Get all registered instances (endpoints, model cards, and event channels)
144    pub fn get_all(&self) -> Vec<DiscoveryInstance> {
145        self.endpoints
146            .values()
147            .chain(self.model_cards.values())
148            .chain(self.event_channels.values())
149            .cloned()
150            .collect()
151    }
152
153    /// Filter this metadata by query
154    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
155        let all_instances = match query {
156            DiscoveryQuery::AllEndpoints
157            | DiscoveryQuery::NamespacedEndpoints { .. }
158            | DiscoveryQuery::ComponentEndpoints { .. }
159            | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
160
161            DiscoveryQuery::AllModels
162            | DiscoveryQuery::NamespacedModels { .. }
163            | DiscoveryQuery::ComponentModels { .. }
164            | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
165
166            // EventChannel queries now return actual event channels
167            DiscoveryQuery::EventChannels(_) => self.get_all_event_channels(),
168        };
169
170        filter_instances(all_instances, query)
171    }
172}
173
174impl Default for DiscoveryMetadata {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180/// Filter instances by query predicate
181fn filter_instances(
182    instances: Vec<DiscoveryInstance>,
183    query: &DiscoveryQuery,
184) -> Vec<DiscoveryInstance> {
185    match query {
186        DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
187
188        DiscoveryQuery::NamespacedEndpoints { namespace } => instances
189            .into_iter()
190            .filter(|inst| match inst {
191                DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
192                _ => false,
193            })
194            .collect(),
195
196        DiscoveryQuery::ComponentEndpoints {
197            namespace,
198            component,
199        } => instances
200            .into_iter()
201            .filter(|inst| match inst {
202                DiscoveryInstance::Endpoint(i) => {
203                    &i.namespace == namespace && &i.component == component
204                }
205                _ => false,
206            })
207            .collect(),
208
209        DiscoveryQuery::Endpoint {
210            namespace,
211            component,
212            endpoint,
213        } => instances
214            .into_iter()
215            .filter(|inst| match inst {
216                DiscoveryInstance::Endpoint(i) => {
217                    &i.namespace == namespace
218                        && &i.component == component
219                        && &i.endpoint == endpoint
220                }
221                _ => false,
222            })
223            .collect(),
224
225        DiscoveryQuery::NamespacedModels { namespace } => instances
226            .into_iter()
227            .filter(|inst| match inst {
228                DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
229                _ => false,
230            })
231            .collect(),
232
233        DiscoveryQuery::ComponentModels {
234            namespace,
235            component,
236        } => instances
237            .into_iter()
238            .filter(|inst| match inst {
239                DiscoveryInstance::Model {
240                    namespace: ns,
241                    component: comp,
242                    ..
243                } => ns == namespace && comp == component,
244                _ => false,
245            })
246            .collect(),
247
248        DiscoveryQuery::EndpointModels {
249            namespace,
250            component,
251            endpoint,
252        } => instances
253            .into_iter()
254            .filter(|inst| match inst {
255                DiscoveryInstance::Model {
256                    namespace: ns,
257                    component: comp,
258                    endpoint: ep,
259                    ..
260                } => ns == namespace && comp == component && ep == endpoint,
261                _ => false,
262            })
263            .collect(),
264
265        // EventChannel queries - unified filtering with optional scope filters
266        DiscoveryQuery::EventChannels(query) => instances
267            .into_iter()
268            .filter(|inst| match inst {
269                DiscoveryInstance::EventChannel {
270                    namespace: ns,
271                    component: comp,
272                    topic: t,
273                    ..
274                } => {
275                    // Filter by namespace if specified
276                    query.namespace.as_ref().is_none_or(|qns| qns == ns)
277                        // Filter by component if specified
278                        && query.component.as_ref().is_none_or(|qc| qc == comp)
279                        // Filter by topic if specified
280                        && query.topic.as_ref().is_none_or(|qt| qt == t)
281                }
282                _ => false,
283            })
284            .collect(),
285    }
286}
287
288/// Snapshot of all discovered instances and their metadata
289#[derive(Clone, Debug)]
290pub struct MetadataSnapshot {
291    /// Map of instance_id -> metadata
292    pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
293    /// Map of instance_id -> CR generation for change detection
294    /// Keys match `instances` keys exactly - only ready pods with CRs are included
295    pub generations: HashMap<u64, i64>,
296    /// Sequence number for debugging
297    pub sequence: u64,
298    /// Timestamp for observability
299    pub timestamp: std::time::Instant,
300}
301
302impl MetadataSnapshot {
303    pub fn empty() -> Self {
304        Self {
305            instances: HashMap::new(),
306            generations: HashMap::new(),
307            sequence: 0,
308            timestamp: std::time::Instant::now(),
309        }
310    }
311
312    /// Compare with previous snapshot and return true if changed.
313    /// Logs diagnostic info about what changed.
314    /// This is done on the basis of the generation of the DynamoWorkerMetadata CRs that are owned by ready workers
315    pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
316        if self.generations == prev.generations {
317            tracing::trace!(
318                "Snapshot (seq={}): no changes, {} instances",
319                self.sequence,
320                self.instances.len()
321            );
322            return false;
323        }
324
325        // Compute diff for logging
326        let curr_ids: HashSet<u64> = self.generations.keys().copied().collect();
327        let prev_ids: HashSet<u64> = prev.generations.keys().copied().collect();
328
329        let added: Vec<_> = curr_ids
330            .difference(&prev_ids)
331            .map(|id| format!("{:x}", id))
332            .collect();
333        let removed: Vec<_> = prev_ids
334            .difference(&curr_ids)
335            .map(|id| format!("{:x}", id))
336            .collect();
337        let updated: Vec<_> = self
338            .generations
339            .iter()
340            .filter(|(k, v)| prev.generations.get(*k).is_some_and(|pv| pv != *v))
341            .map(|(k, _)| format!("{:x}", k))
342            .collect();
343
344        tracing::info!(
345            "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
346            self.sequence,
347            self.instances.len(),
348            added,
349            removed,
350            updated
351        );
352
353        true
354    }
355
356    /// Filter all instances in the snapshot by query
357    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
358        self.instances
359            .values()
360            .flat_map(|metadata| metadata.filter(query))
361            .collect()
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::discovery::{EventChannelQuery, EventTransport};

    /// Builds an endpoint fixture in namespace `test`, component `comp1`.
    fn endpoint_instance(endpoint: &str, instance_id: u64) -> DiscoveryInstance {
        DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: endpoint.to_string(),
            instance_id,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        })
    }

    /// Builds a model card fixture in namespace `test`, component `comp1`.
    fn model_instance(endpoint: &str, instance_id: u64) -> DiscoveryInstance {
        DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: endpoint.to_string(),
            instance_id,
            card_json: serde_json::json!({"model": "test"}),
            model_suffix: None,
        }
    }

    /// Builds an event channel fixture on topic `test-topic`.
    fn event_channel_instance(instance_id: u64, address: String) -> DiscoveryInstance {
        DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id,
            transport: EventTransport::zmq(address),
        }
    }

    #[test]
    fn test_metadata_serde() {
        let mut metadata = DiscoveryMetadata::new();
        metadata
            .register_endpoint(endpoint_instance("ep1", 123))
            .unwrap();

        // Round-trip through JSON.
        let encoded = serde_json::to_string(&metadata).unwrap();
        let decoded: DiscoveryMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.endpoints.len(), 1);
        assert_eq!(decoded.model_cards.len(), 0);
    }

    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let metadata = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        // Ten tasks, each registering a distinct endpoint through the shared lock.
        let mut handles = Vec::new();
        for i in 0..10 {
            let shared = metadata.clone();
            handles.push(tokio::spawn(async move {
                let mut guard = shared.write().await;
                guard
                    .register_endpoint(endpoint_instance(&format!("ep{}", i), i))
                    .unwrap();
            }));
        }

        // Every task must finish without panicking.
        for handle in handles {
            handle.await.unwrap();
        }

        // All ten registrations must be present.
        let guard = metadata.read().await;
        assert_eq!(guard.endpoints.len(), 10);
    }

    #[tokio::test]
    async fn test_metadata_accessors() {
        let mut metadata = DiscoveryMetadata::new();

        // Three endpoints.
        for i in 0..3 {
            metadata
                .register_endpoint(endpoint_instance(&format!("ep{}", i), i))
                .unwrap();
        }

        // Two model cards.
        for i in 0..2 {
            metadata
                .register_model_card(model_instance(&format!("ep{}", i), i))
                .unwrap();
        }

        assert_eq!(metadata.get_all_endpoints().len(), 3);
        assert_eq!(metadata.get_all_model_cards().len(), 2);
        assert_eq!(metadata.get_all().len(), 5);
    }

    #[tokio::test]
    async fn test_event_channel_registration() {
        let mut metadata = DiscoveryMetadata::new();

        // Three event channels, one port each.
        for i in 0..3 {
            let address = format!("tcp://localhost:{}", 5000 + i);
            metadata
                .register_event_channel(event_channel_instance(i, address))
                .unwrap();
        }

        // Dedicated accessor sees all of them.
        assert_eq!(metadata.get_all_event_channels().len(), 3);

        // The combined accessor includes event channels too.
        assert_eq!(metadata.get_all().len(), 3);

        // Unscoped event channel query.
        let everything = DiscoveryQuery::EventChannels(EventChannelQuery::all());
        assert_eq!(metadata.filter(&everything).len(), 3);

        // Query scoped to the matching component.
        let matching = DiscoveryQuery::EventChannels(EventChannelQuery::component("test", "comp1"));
        assert_eq!(metadata.filter(&matching).len(), 3);

        // Query scoped to a namespace nothing was registered in.
        let non_matching =
            DiscoveryQuery::EventChannels(EventChannelQuery::component("other", "comp1"));
        assert_eq!(metadata.filter(&non_matching).len(), 0);

        // Unregistering the first channel leaves two behind.
        let first = event_channel_instance(0, "tcp://localhost:5000".to_string());
        metadata.unregister_event_channel(&first).unwrap();
        assert_eq!(metadata.get_all_event_channels().len(), 2);
    }

    #[tokio::test]
    async fn test_mixed_instances() {
        let mut metadata = DiscoveryMetadata::new();

        // One registration of each kind.
        metadata
            .register_endpoint(endpoint_instance("ep1", 1))
            .unwrap();
        metadata
            .register_model_card(model_instance("ep1", 2))
            .unwrap();
        metadata
            .register_event_channel(event_channel_instance(
                3,
                "tcp://localhost:5000".to_string(),
            ))
            .unwrap();

        // The combined accessor returns all three, each typed accessor exactly one.
        assert_eq!(metadata.get_all().len(), 3);
        assert_eq!(metadata.get_all_endpoints().len(), 1);
        assert_eq!(metadata.get_all_model_cards().len(), 1);
        assert_eq!(metadata.get_all_event_channels().len(), 1);
    }
}