Skip to main content

dynamo_runtime/discovery/
metadata.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::Result;
5use serde::Deserialize as _;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
10
11/// Deserializes a JSON `null` or missing field as `T::default()`.
12///
13/// Kubernetes Server-Side Apply with `schema = "disabled"` can write an empty
14/// object `{}` as `null` for nested free-form fields. Without this helper, the
15/// daemon fails to deserialize the `DynamoWorkerMetadata` CR, and the worker is
16/// excluded from the `MetadataSnapshot` (i.e. invisible to service discovery),
17/// causing `KubeDiscoveryClient::list` to return 0 instances and all inference
18/// requests to 404. One concrete example is vLLM elastic EP scaling:
19/// `scale_elastic_ep` reinitializes event plane sockets, which triggers
20/// `unregister_event_channel()`, leaving `event_channels` as an empty map `{}`.
21/// SSA then writes it back as `null`, breaking deserialization until this helper
22/// treats `null` as an empty map. The issue applies to any event plane
23/// implementation, not only a specific transport.
24fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
25where
26    D: serde::Deserializer<'de>,
27    T: Default + serde::Deserialize<'de>,
28{
29    Ok(Option::<T>::deserialize(deserializer)?.unwrap_or_default())
30}
31
32/// Metadata stored on each pod and exposed via HTTP endpoint
33/// This struct holds all discovery registrations for this pod instance
34#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
35pub struct DiscoveryMetadata {
36    /// Registered endpoint instances (key: path string from EndpointInstanceId::to_path())
37    #[serde(default, deserialize_with = "deserialize_null_default")]
38    endpoints: HashMap<String, DiscoveryInstance>,
39    /// Registered model card instances (key: path string from ModelCardInstanceId::to_path())
40    #[serde(default, deserialize_with = "deserialize_null_default")]
41    model_cards: HashMap<String, DiscoveryInstance>,
42    /// Registered event channel instances (key: path string from EventChannelInstanceId::to_path())
43    #[serde(default, deserialize_with = "deserialize_null_default")]
44    event_channels: HashMap<String, DiscoveryInstance>,
45}
46
47impl DiscoveryMetadata {
48    /// Create a new empty metadata store
49    pub fn new() -> Self {
50        Self {
51            endpoints: HashMap::new(),
52            model_cards: HashMap::new(),
53            event_channels: HashMap::new(),
54        }
55    }
56
57    /// Register an endpoint instance
58    pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
59        match instance.id() {
60            DiscoveryInstanceId::Endpoint(key) => {
61                self.endpoints.insert(key.to_path(), instance);
62                Ok(())
63            }
64            DiscoveryInstanceId::Model(_) => {
65                anyhow::bail!("Cannot register non-endpoint instance as endpoint")
66            }
67            DiscoveryInstanceId::EventChannel(_) => {
68                anyhow::bail!("Cannot register EventChannel instance as endpoint")
69            }
70        }
71    }
72
73    /// Register a model card instance
74    pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
75        match instance.id() {
76            DiscoveryInstanceId::Model(key) => {
77                self.model_cards.insert(key.to_path(), instance);
78                Ok(())
79            }
80            DiscoveryInstanceId::Endpoint(_) => {
81                anyhow::bail!("Cannot register non-model-card instance as model card")
82            }
83            DiscoveryInstanceId::EventChannel(_) => {
84                anyhow::bail!("Cannot register EventChannel instance as model card")
85            }
86        }
87    }
88
89    /// Unregister an endpoint instance
90    pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
91        match instance.id() {
92            DiscoveryInstanceId::Endpoint(key) => {
93                self.endpoints.remove(&key.to_path());
94                Ok(())
95            }
96            DiscoveryInstanceId::Model(_) => {
97                anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
98            }
99            DiscoveryInstanceId::EventChannel(_) => {
100                anyhow::bail!("Cannot unregister EventChannel instance as endpoint")
101            }
102        }
103    }
104
105    /// Unregister a model card instance
106    pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
107        match instance.id() {
108            DiscoveryInstanceId::Model(key) => {
109                self.model_cards.remove(&key.to_path());
110                Ok(())
111            }
112            DiscoveryInstanceId::Endpoint(_) => {
113                anyhow::bail!("Cannot unregister non-model-card instance as model card")
114            }
115            DiscoveryInstanceId::EventChannel(_) => {
116                anyhow::bail!("Cannot unregister EventChannel instance as model card")
117            }
118        }
119    }
120
121    /// Register an event channel instance
122    pub fn register_event_channel(&mut self, instance: DiscoveryInstance) -> Result<()> {
123        match instance.id() {
124            DiscoveryInstanceId::EventChannel(key) => {
125                self.event_channels.insert(key.to_path(), instance);
126                Ok(())
127            }
128            DiscoveryInstanceId::Endpoint(_) => {
129                anyhow::bail!("Cannot register Endpoint instance as event channel")
130            }
131            DiscoveryInstanceId::Model(_) => {
132                anyhow::bail!("Cannot register Model instance as event channel")
133            }
134        }
135    }
136
137    /// Unregister an event channel instance
138    pub fn unregister_event_channel(&mut self, instance: &DiscoveryInstance) -> Result<()> {
139        match instance.id() {
140            DiscoveryInstanceId::EventChannel(key) => {
141                self.event_channels.remove(&key.to_path());
142                Ok(())
143            }
144            DiscoveryInstanceId::Endpoint(_) => {
145                anyhow::bail!("Cannot unregister Endpoint instance as event channel")
146            }
147            DiscoveryInstanceId::Model(_) => {
148                anyhow::bail!("Cannot unregister Model instance as event channel")
149            }
150        }
151    }
152
153    /// Get all registered endpoints
154    pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
155        self.endpoints.values().cloned().collect()
156    }
157
158    /// Get all registered model cards
159    pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
160        self.model_cards.values().cloned().collect()
161    }
162
163    /// Get all registered event channels
164    pub fn get_all_event_channels(&self) -> Vec<DiscoveryInstance> {
165        self.event_channels.values().cloned().collect()
166    }
167
168    /// Get all registered instances (endpoints, model cards, and event channels)
169    pub fn get_all(&self) -> Vec<DiscoveryInstance> {
170        self.endpoints
171            .values()
172            .chain(self.model_cards.values())
173            .chain(self.event_channels.values())
174            .cloned()
175            .collect()
176    }
177
178    /// Filter this metadata by query
179    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
180        let all_instances = match query {
181            DiscoveryQuery::AllEndpoints
182            | DiscoveryQuery::NamespacedEndpoints { .. }
183            | DiscoveryQuery::ComponentEndpoints { .. }
184            | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
185
186            DiscoveryQuery::AllModels
187            | DiscoveryQuery::NamespacedModels { .. }
188            | DiscoveryQuery::ComponentModels { .. }
189            | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
190
191            // EventChannel queries now return actual event channels
192            DiscoveryQuery::EventChannels(_) => self.get_all_event_channels(),
193        };
194
195        filter_instances(all_instances, query)
196    }
197}
198
199impl Default for DiscoveryMetadata {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205/// Filter instances by query predicate
206fn filter_instances(
207    instances: Vec<DiscoveryInstance>,
208    query: &DiscoveryQuery,
209) -> Vec<DiscoveryInstance> {
210    match query {
211        DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
212
213        DiscoveryQuery::NamespacedEndpoints { namespace } => instances
214            .into_iter()
215            .filter(|inst| match inst {
216                DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
217                _ => false,
218            })
219            .collect(),
220
221        DiscoveryQuery::ComponentEndpoints {
222            namespace,
223            component,
224        } => instances
225            .into_iter()
226            .filter(|inst| match inst {
227                DiscoveryInstance::Endpoint(i) => {
228                    &i.namespace == namespace && &i.component == component
229                }
230                _ => false,
231            })
232            .collect(),
233
234        DiscoveryQuery::Endpoint {
235            namespace,
236            component,
237            endpoint,
238        } => instances
239            .into_iter()
240            .filter(|inst| match inst {
241                DiscoveryInstance::Endpoint(i) => {
242                    &i.namespace == namespace
243                        && &i.component == component
244                        && &i.endpoint == endpoint
245                }
246                _ => false,
247            })
248            .collect(),
249
250        DiscoveryQuery::NamespacedModels { namespace } => instances
251            .into_iter()
252            .filter(|inst| match inst {
253                DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
254                _ => false,
255            })
256            .collect(),
257
258        DiscoveryQuery::ComponentModels {
259            namespace,
260            component,
261        } => instances
262            .into_iter()
263            .filter(|inst| match inst {
264                DiscoveryInstance::Model {
265                    namespace: ns,
266                    component: comp,
267                    ..
268                } => ns == namespace && comp == component,
269                _ => false,
270            })
271            .collect(),
272
273        DiscoveryQuery::EndpointModels {
274            namespace,
275            component,
276            endpoint,
277        } => instances
278            .into_iter()
279            .filter(|inst| match inst {
280                DiscoveryInstance::Model {
281                    namespace: ns,
282                    component: comp,
283                    endpoint: ep,
284                    ..
285                } => ns == namespace && comp == component && ep == endpoint,
286                _ => false,
287            })
288            .collect(),
289
290        // EventChannel queries - unified filtering with optional scope filters
291        DiscoveryQuery::EventChannels(query) => instances
292            .into_iter()
293            .filter(|inst| match inst {
294                DiscoveryInstance::EventChannel {
295                    namespace: ns,
296                    component: comp,
297                    topic: t,
298                    ..
299                } => {
300                    // Filter by namespace if specified
301                    query.namespace.as_ref().is_none_or(|qns| qns == ns)
302                        // Filter by component if specified
303                        && query.component.as_ref().is_none_or(|qc| qc == comp)
304                        // Filter by topic if specified
305                        && query.topic.as_ref().is_none_or(|qt| qt == t)
306                }
307                _ => false,
308            })
309            .collect(),
310    }
311}
312
313/// Snapshot of all discovered instances and their metadata
314#[derive(Clone, Debug)]
315pub struct MetadataSnapshot {
316    /// Map of instance_id -> metadata
317    pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
318    /// Map of instance_id -> CR generation for change detection
319    /// Keys match `instances` keys exactly - only ready pods with CRs are included
320    pub generations: HashMap<u64, i64>,
321    /// Sequence number for debugging
322    pub sequence: u64,
323    /// Timestamp for observability
324    pub timestamp: std::time::Instant,
325}
326
327impl MetadataSnapshot {
328    pub fn empty() -> Self {
329        Self {
330            instances: HashMap::new(),
331            generations: HashMap::new(),
332            sequence: 0,
333            timestamp: std::time::Instant::now(),
334        }
335    }
336
337    /// Compare with previous snapshot and return true if changed.
338    /// Logs diagnostic info about what changed.
339    /// This is done on the basis of the generation of the DynamoWorkerMetadata CRs that are owned by ready workers
340    pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
341        if self.generations == prev.generations {
342            tracing::trace!(
343                "Snapshot (seq={}): no changes, {} instances",
344                self.sequence,
345                self.instances.len()
346            );
347            return false;
348        }
349
350        // Compute diff for logging
351        let curr_ids: HashSet<u64> = self.generations.keys().copied().collect();
352        let prev_ids: HashSet<u64> = prev.generations.keys().copied().collect();
353
354        let added: Vec<_> = curr_ids
355            .difference(&prev_ids)
356            .map(|id| format!("{:x}", id))
357            .collect();
358        let removed: Vec<_> = prev_ids
359            .difference(&curr_ids)
360            .map(|id| format!("{:x}", id))
361            .collect();
362        let updated: Vec<_> = self
363            .generations
364            .iter()
365            .filter(|(k, v)| prev.generations.get(*k).is_some_and(|pv| pv != *v))
366            .map(|(k, _)| format!("{:x}", k))
367            .collect();
368
369        tracing::info!(
370            "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
371            self.sequence,
372            self.instances.len(),
373            added,
374            removed,
375            updated
376        );
377
378        true
379    }
380
381    /// Filter all instances in the snapshot by query
382    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
383        self.instances
384            .values()
385            .flat_map(|metadata| metadata.filter(query))
386            .collect()
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::discovery::EventChannelQuery;

    /// Round-trips a store holding one endpoint through JSON.
    #[test]
    fn test_metadata_serde() {
        let mut metadata = DiscoveryMetadata::new();

        // Add an endpoint
        let instance = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 123,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
            device_type: None,
        });

        metadata.register_endpoint(instance).unwrap();

        // Serialize
        let json = serde_json::to_string(&metadata).unwrap();

        // Deserialize
        let deserialized: DiscoveryMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.endpoints.len(), 1);
        assert_eq!(deserialized.model_cards.len(), 0);
        assert_eq!(deserialized.event_channels.len(), 0);
    }

    /// Regression test for the `deserialize_null_default` path: fields that are
    /// explicitly `null` or entirely absent must deserialize to empty maps
    /// instead of failing.
    #[test]
    fn test_metadata_deserialize_null_and_missing_fields() {
        // Explicit nulls for every map.
        let from_null: DiscoveryMetadata = serde_json::from_str(
            r#"{"endpoints": null, "model_cards": null, "event_channels": null}"#,
        )
        .unwrap();
        assert!(from_null.endpoints.is_empty());
        assert!(from_null.model_cards.is_empty());
        assert!(from_null.event_channels.is_empty());

        // All fields missing.
        let from_empty: DiscoveryMetadata = serde_json::from_str("{}").unwrap();
        assert!(from_empty.endpoints.is_empty());
        assert!(from_empty.model_cards.is_empty());
        assert!(from_empty.event_channels.is_empty());

        // A mix of an empty map, a null and a missing field.
        let mixed: DiscoveryMetadata =
            serde_json::from_str(r#"{"endpoints": {}, "event_channels": null}"#).unwrap();
        assert!(mixed.get_all().is_empty());
    }

    /// Registers endpoints from several tasks sharing one store behind an `RwLock`.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let metadata = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        // Spawn multiple tasks registering concurrently
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let metadata = metadata.clone();
                tokio::spawn(async move {
                    let mut meta = metadata.write().await;
                    let instance = DiscoveryInstance::Endpoint(Instance {
                        namespace: "test".to_string(),
                        component: "comp1".to_string(),
                        endpoint: format!("ep{}", i),
                        instance_id: i,
                        transport: TransportType::Nats("nats://localhost:4222".to_string()),
                        device_type: None,
                    });
                    meta.register_endpoint(instance).unwrap();
                })
            })
            .collect();

        // Wait for all to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all registrations succeeded
        let meta = metadata.read().await;
        assert_eq!(meta.endpoints.len(), 10);
    }

    /// Checks the per-kind getters and `get_all` against known registration counts.
    #[tokio::test]
    async fn test_metadata_accessors() {
        let mut metadata = DiscoveryMetadata::new();

        // Register endpoints
        for i in 0..3 {
            let instance = DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
                device_type: None,
            });
            metadata.register_endpoint(instance).unwrap();
        }

        // Register model cards
        for i in 0..2 {
            let instance = DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            };
            metadata.register_model_card(instance).unwrap();
        }

        assert_eq!(metadata.get_all_endpoints().len(), 3);
        assert_eq!(metadata.get_all_model_cards().len(), 2);
        assert_eq!(metadata.get_all().len(), 5);
    }

    /// Covers event channel registration, query filtering and unregistration.
    #[tokio::test]
    async fn test_event_channel_registration() {
        use crate::discovery::EventTransport;

        let mut metadata = DiscoveryMetadata::new();

        // Register event channels
        for i in 0..3 {
            let instance = DiscoveryInstance::EventChannel {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                topic: "test-topic".to_string(),
                instance_id: i,
                transport: EventTransport::zmq(format!("tcp://localhost:{}", 5000 + i)),
            };
            metadata.register_event_channel(instance).unwrap();
        }

        // Test get_all_event_channels
        assert_eq!(metadata.get_all_event_channels().len(), 3);

        // Test get_all includes event channels
        assert_eq!(metadata.get_all().len(), 3);

        // Test filter by all event channels
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(EventChannelQuery::all()));
        assert_eq!(filtered.len(), 3);

        // Test filter by component
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("test", "comp1"),
        ));
        assert_eq!(filtered.len(), 3);

        // Test filter with non-matching query
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("other", "comp1"),
        ));
        assert_eq!(filtered.len(), 0);

        // Test unregister
        let instance = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 0,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        metadata.unregister_event_channel(&instance).unwrap();
        assert_eq!(metadata.get_all_event_channels().len(), 2);
    }

    /// Registers one instance of each kind and checks every getter sees it.
    #[tokio::test]
    async fn test_mixed_instances() {
        use crate::discovery::EventTransport;

        let mut metadata = DiscoveryMetadata::new();

        // Register one of each type
        let endpoint = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
            device_type: None,
        });
        metadata.register_endpoint(endpoint).unwrap();

        let model = DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 2,
            card_json: serde_json::json!({"model": "test"}),
            model_suffix: None,
        };
        metadata.register_model_card(model).unwrap();

        let event_channel = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 3,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        metadata.register_event_channel(event_channel).unwrap();

        // Verify get_all returns all three
        assert_eq!(metadata.get_all().len(), 3);
        assert_eq!(metadata.get_all_endpoints().len(), 1);
        assert_eq!(metadata.get_all_model_cards().len(), 1);
        assert_eq!(metadata.get_all_event_channels().len(), 1);
    }
}