dynamo_runtime/discovery/
metadata.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::Result;
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
9
10/// Metadata stored on each pod and exposed via HTTP endpoint
11/// This struct holds all discovery registrations for this pod instance
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct DiscoveryMetadata {
14    /// Registered endpoint instances (key: path string from EndpointInstanceId::to_path())
15    endpoints: HashMap<String, DiscoveryInstance>,
16    /// Registered model card instances (key: path string from ModelCardInstanceId::to_path())
17    model_cards: HashMap<String, DiscoveryInstance>,
18}
19
20impl DiscoveryMetadata {
21    /// Create a new empty metadata store
22    pub fn new() -> Self {
23        Self {
24            endpoints: HashMap::new(),
25            model_cards: HashMap::new(),
26        }
27    }
28
29    /// Register an endpoint instance
30    pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
31        match instance.id() {
32            DiscoveryInstanceId::Endpoint(key) => {
33                self.endpoints.insert(key.to_path(), instance);
34                Ok(())
35            }
36            DiscoveryInstanceId::Model(_) => {
37                anyhow::bail!("Cannot register non-endpoint instance as endpoint")
38            }
39        }
40    }
41
42    /// Register a model card instance
43    pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
44        match instance.id() {
45            DiscoveryInstanceId::Model(key) => {
46                self.model_cards.insert(key.to_path(), instance);
47                Ok(())
48            }
49            DiscoveryInstanceId::Endpoint(_) => {
50                anyhow::bail!("Cannot register non-model-card instance as model card")
51            }
52        }
53    }
54
55    /// Unregister an endpoint instance
56    pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
57        match instance.id() {
58            DiscoveryInstanceId::Endpoint(key) => {
59                self.endpoints.remove(&key.to_path());
60                Ok(())
61            }
62            DiscoveryInstanceId::Model(_) => {
63                anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
64            }
65        }
66    }
67
68    /// Unregister a model card instance
69    pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
70        match instance.id() {
71            DiscoveryInstanceId::Model(key) => {
72                self.model_cards.remove(&key.to_path());
73                Ok(())
74            }
75            DiscoveryInstanceId::Endpoint(_) => {
76                anyhow::bail!("Cannot unregister non-model-card instance as model card")
77            }
78        }
79    }
80
81    /// Get all registered endpoints
82    pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
83        self.endpoints.values().cloned().collect()
84    }
85
86    /// Get all registered model cards
87    pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
88        self.model_cards.values().cloned().collect()
89    }
90
91    /// Get all registered instances (endpoints and model cards)
92    pub fn get_all(&self) -> Vec<DiscoveryInstance> {
93        self.endpoints
94            .values()
95            .chain(self.model_cards.values())
96            .cloned()
97            .collect()
98    }
99
100    /// Filter this metadata by query
101    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
102        let all_instances = match query {
103            DiscoveryQuery::AllEndpoints
104            | DiscoveryQuery::NamespacedEndpoints { .. }
105            | DiscoveryQuery::ComponentEndpoints { .. }
106            | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
107
108            DiscoveryQuery::AllModels
109            | DiscoveryQuery::NamespacedModels { .. }
110            | DiscoveryQuery::ComponentModels { .. }
111            | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
112        };
113
114        filter_instances(all_instances, query)
115    }
116}
117
118impl Default for DiscoveryMetadata {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124/// Filter instances by query predicate
125fn filter_instances(
126    instances: Vec<DiscoveryInstance>,
127    query: &DiscoveryQuery,
128) -> Vec<DiscoveryInstance> {
129    match query {
130        DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
131
132        DiscoveryQuery::NamespacedEndpoints { namespace } => instances
133            .into_iter()
134            .filter(|inst| match inst {
135                DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
136                _ => false,
137            })
138            .collect(),
139
140        DiscoveryQuery::ComponentEndpoints {
141            namespace,
142            component,
143        } => instances
144            .into_iter()
145            .filter(|inst| match inst {
146                DiscoveryInstance::Endpoint(i) => {
147                    &i.namespace == namespace && &i.component == component
148                }
149                _ => false,
150            })
151            .collect(),
152
153        DiscoveryQuery::Endpoint {
154            namespace,
155            component,
156            endpoint,
157        } => instances
158            .into_iter()
159            .filter(|inst| match inst {
160                DiscoveryInstance::Endpoint(i) => {
161                    &i.namespace == namespace
162                        && &i.component == component
163                        && &i.endpoint == endpoint
164                }
165                _ => false,
166            })
167            .collect(),
168
169        DiscoveryQuery::NamespacedModels { namespace } => instances
170            .into_iter()
171            .filter(|inst| match inst {
172                DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
173                _ => false,
174            })
175            .collect(),
176
177        DiscoveryQuery::ComponentModels {
178            namespace,
179            component,
180        } => instances
181            .into_iter()
182            .filter(|inst| match inst {
183                DiscoveryInstance::Model {
184                    namespace: ns,
185                    component: comp,
186                    ..
187                } => ns == namespace && comp == component,
188                _ => false,
189            })
190            .collect(),
191
192        DiscoveryQuery::EndpointModels {
193            namespace,
194            component,
195            endpoint,
196        } => instances
197            .into_iter()
198            .filter(|inst| match inst {
199                DiscoveryInstance::Model {
200                    namespace: ns,
201                    component: comp,
202                    endpoint: ep,
203                    ..
204                } => ns == namespace && comp == component && ep == endpoint,
205                _ => false,
206            })
207            .collect(),
208    }
209}
210
211/// Snapshot of all discovered instances and their metadata
212#[derive(Clone, Debug)]
213pub struct MetadataSnapshot {
214    /// Map of instance_id -> metadata
215    pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
216    /// Map of instance_id -> CR generation for change detection
217    /// Keys match `instances` keys exactly - only ready pods with CRs are included
218    pub generations: HashMap<u64, i64>,
219    /// Sequence number for debugging
220    pub sequence: u64,
221    /// Timestamp for observability
222    pub timestamp: std::time::Instant,
223}
224
225impl MetadataSnapshot {
226    pub fn empty() -> Self {
227        Self {
228            instances: HashMap::new(),
229            generations: HashMap::new(),
230            sequence: 0,
231            timestamp: std::time::Instant::now(),
232        }
233    }
234
235    /// Compare with previous snapshot and return true if changed.
236    /// Logs diagnostic info about what changed.
237    /// This is done on the basis of the generation of the DynamoWorkerMetadata CRs that are owned by ready workers
238    pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
239        if self.generations == prev.generations {
240            tracing::trace!(
241                "Snapshot (seq={}): no changes, {} instances",
242                self.sequence,
243                self.instances.len()
244            );
245            return false;
246        }
247
248        // Compute diff for logging
249        let curr_ids: HashSet<u64> = self.generations.keys().copied().collect();
250        let prev_ids: HashSet<u64> = prev.generations.keys().copied().collect();
251
252        let added: Vec<_> = curr_ids
253            .difference(&prev_ids)
254            .map(|id| format!("{:x}", id))
255            .collect();
256        let removed: Vec<_> = prev_ids
257            .difference(&curr_ids)
258            .map(|id| format!("{:x}", id))
259            .collect();
260        let updated: Vec<_> = self
261            .generations
262            .iter()
263            .filter(|(k, v)| prev.generations.get(*k).is_some_and(|pv| pv != *v))
264            .map(|(k, _)| format!("{:x}", k))
265            .collect();
266
267        tracing::info!(
268            "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
269            self.sequence,
270            self.instances.len(),
271            added,
272            removed,
273            updated
274        );
275
276        true
277    }
278
279    /// Filter all instances in the snapshot by query
280    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
281        self.instances
282            .values()
283            .flat_map(|metadata| metadata.filter(query))
284            .collect()
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};

    /// A store with one endpoint survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_metadata_serde() {
        let mut metadata = DiscoveryMetadata::new();

        // Add an endpoint
        let instance = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 123,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });

        metadata.register_endpoint(instance).unwrap();

        // Serialize
        let json = serde_json::to_string(&metadata).unwrap();

        // Deserialize
        let deserialized: DiscoveryMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.endpoints.len(), 1);
        assert_eq!(deserialized.model_cards.len(), 0);
    }

    /// Ten tasks registering distinct endpoints through a shared lock all succeed.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let metadata = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        // Spawn multiple tasks registering concurrently
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let metadata = metadata.clone();
                tokio::spawn(async move {
                    let mut meta = metadata.write().await;
                    let instance = DiscoveryInstance::Endpoint(Instance {
                        namespace: "test".to_string(),
                        component: "comp1".to_string(),
                        endpoint: format!("ep{}", i),
                        instance_id: i,
                        transport: TransportType::Nats("nats://localhost:4222".to_string()),
                    });
                    meta.register_endpoint(instance).unwrap();
                })
            })
            .collect();

        // Wait for all to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all registrations succeeded
        let meta = metadata.read().await;
        assert_eq!(meta.endpoints.len(), 10);
    }

    /// The accessors report exactly what was registered.
    ///
    /// Nothing here awaits, so this is a plain synchronous test rather than a
    /// `#[tokio::test]`.
    #[test]
    fn test_metadata_accessors() {
        let mut metadata = DiscoveryMetadata::new();

        // Register endpoints
        for i in 0..3 {
            let instance = DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            });
            metadata.register_endpoint(instance).unwrap();
        }

        // Register model cards
        for i in 0..2 {
            let instance = DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            };
            metadata.register_model_card(instance).unwrap();
        }

        assert_eq!(metadata.get_all_endpoints().len(), 3);
        assert_eq!(metadata.get_all_model_cards().len(), 2);
        assert_eq!(metadata.get_all().len(), 5);
    }

    /// Registering or unregistering an instance through the method for the
    /// other kind fails and leaves the store untouched; a matching
    /// register/unregister pair leaves it empty again.
    #[test]
    fn test_kind_mismatch_and_unregister() {
        let mut metadata = DiscoveryMetadata::new();

        let endpoint = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });
        let model = DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            card_json: serde_json::json!({"model": "test"}),
            model_suffix: None,
        };

        // Wrong-kind calls are rejected.
        assert!(metadata.register_model_card(endpoint.clone()).is_err());
        assert!(metadata.register_endpoint(model.clone()).is_err());
        assert!(metadata.unregister_model_card(&endpoint).is_err());
        assert!(metadata.unregister_endpoint(&model).is_err());
        assert!(metadata.get_all().is_empty());

        // Matching register/unregister pairs round-trip to an empty store.
        metadata.register_endpoint(endpoint.clone()).unwrap();
        metadata.register_model_card(model.clone()).unwrap();
        assert_eq!(metadata.get_all().len(), 2);

        metadata.unregister_endpoint(&endpoint).unwrap();
        metadata.unregister_model_card(&model).unwrap();
        assert!(metadata.get_all().is_empty());
    }
}