Skip to main content

dynamo_runtime/discovery/
kube.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4mod crd;
5mod daemon;
6mod utils;
7
8pub use crd::{DynamoWorkerMetadata, DynamoWorkerMetadataSpec};
9// hash_pod_name is used by C bindings (EPP) for pod-level worker ID mapping.
10pub use utils::hash_pod_name;
11
12use crd::{apply_cr, build_cr};
13use daemon::DiscoveryDaemon;
14use utils::{KubeDiscoveryMode, PodInfo};
15
16use crate::CancellationToken;
17use crate::discovery::{
18    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryMetadata,
19    DiscoveryQuery, DiscoverySpec, DiscoveryStream, MetadataSnapshot,
20};
21use anyhow::Result;
22use async_trait::async_trait;
23use kube::{Api, Client as KubeClient, api::DeleteParams};
24use std::collections::HashSet;
25use std::sync::Arc;
26use tokio::sync::RwLock;
27
/// Kubernetes-based discovery client
///
/// Records this worker's registrations in a shared local metadata store,
/// persists them to a `DynamoWorkerMetadata` custom resource, and answers
/// list/watch queries from snapshots published by a background
/// `DiscoveryDaemon` over a `tokio::sync::watch` channel.
#[derive(Clone)]
pub struct KubeDiscoveryClient {
    /// Instance ID derived from the pod's discovery target at construction.
    instance_id: u64,
    /// Shared metadata store (also used by the system server); mutated by
    /// register/unregister while holding its write lock.
    metadata: Arc<RwLock<DiscoveryMetadata>>,
    /// Receiving side of the snapshot channel fed by the discovery daemon.
    metadata_watch: tokio::sync::watch::Receiver<Arc<MetadataSnapshot>>,
    /// Kubernetes API client used to apply and delete the metadata CR.
    kube_client: KubeClient,
    /// Pod identity and discovery mode, read via `PodInfo::from_env`.
    pod_info: PodInfo,
}
37
38impl KubeDiscoveryClient {
39    /// Create a new Kubernetes discovery client
40    ///
41    /// # Arguments
42    /// * `metadata` - Shared metadata store (also used by system server)
43    /// * `cancel_token` - Cancellation token for shutdown
44    pub async fn new(
45        metadata: Arc<RwLock<DiscoveryMetadata>>,
46        cancel_token: CancellationToken,
47    ) -> Result<Self> {
48        let pod_info = PodInfo::from_env()?;
49        let instance_id = pod_info.target.instance_id();
50        let cr_name = pod_info.target.cr_name();
51
52        tracing::info!(
53            "Initializing KubeDiscoveryClient: mode={:?}, target={:?}, cr_name={}, instance_id={:x}, namespace={}, pod_uid={}",
54            pod_info.mode,
55            pod_info.target,
56            cr_name,
57            instance_id,
58            pod_info.pod_namespace,
59            pod_info.pod_uid
60        );
61
62        let kube_client = KubeClient::try_default()
63            .await
64            .map_err(|e| anyhow::anyhow!("Failed to create Kubernetes client: {}", e))?;
65
66        // In container mode, delete any stale CR from a previous incarnation of this container.
67        // In failover pods, the pod stays alive when a container crashes and restarts,
68        // so the old CR persists. Deleting it ensures the daemon doesn't see stale data.
69        // In pod mode this is unnecessary — pod restart creates a new pod (and new CR name).
70        if pod_info.mode == KubeDiscoveryMode::Container {
71            let cr_api: Api<DynamoWorkerMetadata> =
72                Api::namespaced(kube_client.clone(), &pod_info.pod_namespace);
73            match cr_api.delete(&cr_name, &DeleteParams::default()).await {
74                Ok(_) => tracing::info!("Deleted stale CR: {}", cr_name),
75                Err(kube::Error::Api(err_resp)) if err_resp.code == 404 => {
76                    tracing::debug!("No stale CR to delete: {}", cr_name);
77                }
78                Err(e) => {
79                    panic!(
80                        "Failed to clear stale CR '{}': {} — cannot start with stale discovery state",
81                        cr_name, e
82                    );
83                }
84            }
85        }
86
87        // Create watch channel with initial empty snapshot
88        let (watch_tx, watch_rx) = tokio::sync::watch::channel(Arc::new(MetadataSnapshot::empty()));
89
90        // Create and spawn daemon
91        let daemon = DiscoveryDaemon::new(kube_client.clone(), pod_info.clone(), cancel_token)?;
92
93        tokio::spawn(async move {
94            if let Err(e) = daemon.run(watch_tx).await {
95                tracing::error!("Discovery daemon failed: {e}");
96            }
97        });
98
99        tracing::info!("Discovery daemon started");
100
101        Ok(Self {
102            instance_id,
103            metadata,
104            metadata_watch: watch_rx,
105            kube_client,
106            pod_info,
107        })
108    }
109}
110
111#[async_trait]
112impl Discovery for KubeDiscoveryClient {
    /// Returns the instance ID this client was constructed with.
    fn instance_id(&self) -> u64 {
        self.instance_id
    }
116
117    async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
118        let instance_id = self.instance_id();
119        let instance = spec.with_instance_id(instance_id);
120
121        tracing::debug!(
122            "Registering instance: {:?} with instance_id={:x}",
123            instance,
124            instance_id
125        );
126
127        // Write to local metadata and persist to CR
128        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
129        let mut metadata = self.metadata.write().await;
130
131        // Clone state for rollback in case CR persistence fails
132        let original_state = metadata.clone();
133
134        match &instance {
135            DiscoveryInstance::Endpoint(inst) => {
136                tracing::info!(
137                    "Registering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
138                    inst.namespace,
139                    inst.component,
140                    inst.endpoint,
141                    instance_id
142                );
143                metadata.register_endpoint(instance.clone())?;
144            }
145            DiscoveryInstance::Model {
146                namespace,
147                component,
148                endpoint,
149                ..
150            } => {
151                tracing::info!(
152                    "Registering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
153                    namespace,
154                    component,
155                    endpoint,
156                    instance_id
157                );
158                metadata.register_model_card(instance.clone())?;
159            }
160            DiscoveryInstance::EventChannel {
161                namespace,
162                component,
163                topic,
164                ..
165            } => {
166                tracing::info!(
167                    "Registering event channel: namespace={}, component={}, topic={}, instance_id={:x}",
168                    namespace,
169                    component,
170                    topic,
171                    instance_id
172                );
173                metadata.register_event_channel(instance.clone())?;
174            }
175        }
176
177        // Build and apply the CR with the updated metadata
178        // This persists the metadata to Kubernetes for other pods to discover
179        let cr_name = self.pod_info.target.cr_name();
180        let cr = build_cr(
181            &cr_name,
182            &self.pod_info.pod_name,
183            &self.pod_info.pod_uid,
184            &metadata,
185        )?;
186
187        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
188            // Rollback local state on CR persistence failure
189            tracing::warn!(
190                "Failed to persist metadata to CR, rolling back local state: {}",
191                e
192            );
193            *metadata = original_state;
194            return Err(e);
195        }
196
197        tracing::debug!("Persisted metadata to DynamoWorkerMetadata CR");
198
199        Ok(instance)
200    }
201
202    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
203        let instance_id = self.instance_id();
204
205        // Write to local metadata and persist to CR
206        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
207        let mut metadata = self.metadata.write().await;
208
209        // Clone state for rollback in case CR persistence fails
210        let original_state = metadata.clone();
211
212        match &instance {
213            DiscoveryInstance::Endpoint(inst) => {
214                tracing::info!(
215                    "Unregistering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
216                    inst.namespace,
217                    inst.component,
218                    inst.endpoint,
219                    instance_id
220                );
221                metadata.unregister_endpoint(&instance)?;
222            }
223            DiscoveryInstance::Model {
224                namespace,
225                component,
226                endpoint,
227                ..
228            } => {
229                tracing::info!(
230                    "Unregistering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
231                    namespace,
232                    component,
233                    endpoint,
234                    instance_id
235                );
236                metadata.unregister_model_card(&instance)?;
237            }
238            DiscoveryInstance::EventChannel {
239                namespace,
240                component,
241                topic,
242                ..
243            } => {
244                tracing::info!(
245                    "Unregistering event channel: namespace={}, component={}, topic={}, instance_id={:x}",
246                    namespace,
247                    component,
248                    topic,
249                    instance_id
250                );
251                metadata.unregister_event_channel(&instance)?;
252            }
253        }
254
255        // Build and apply the CR with the updated metadata
256        // This persists the removal to Kubernetes for other pods to see
257        let cr_name = self.pod_info.target.cr_name();
258        let cr = build_cr(
259            &cr_name,
260            &self.pod_info.pod_name,
261            &self.pod_info.pod_uid,
262            &metadata,
263        )?;
264
265        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
266            // Rollback local state on CR persistence failure
267            tracing::warn!(
268                "Failed to persist metadata removal to CR, rolling back local state: {}",
269                e
270            );
271            *metadata = original_state;
272            return Err(e);
273        }
274
275        tracing::debug!("Persisted metadata removal to DynamoWorkerMetadata CR");
276
277        Ok(())
278    }
279
280    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
281        tracing::debug!("KubeDiscoveryClient::list called with query={:?}", query);
282
283        // Get current snapshot (may be empty if daemon hasn't fetched yet)
284        let snapshot = self.metadata_watch.borrow().clone();
285
286        tracing::debug!(
287            "List using snapshot seq={} with {} instances",
288            snapshot.sequence,
289            snapshot.instances.len()
290        );
291
292        // Filter snapshot by query
293        let instances = snapshot.filter(&query);
294
295        tracing::info!(
296            "KubeDiscoveryClient::list returning {} instances for query={:?}",
297            instances.len(),
298            query
299        );
300
301        Ok(instances)
302    }
303
304    async fn list_and_watch(
305        &self,
306        query: DiscoveryQuery,
307        cancel_token: Option<CancellationToken>,
308    ) -> Result<DiscoveryStream> {
309        use tokio::sync::mpsc;
310
311        tracing::info!(
312            "KubeDiscoveryClient::list_and_watch started for query={:?}",
313            query
314        );
315
316        // Clone the watch receiver
317        let mut watch_rx = self.metadata_watch.clone();
318
319        // Create output stream
320        let (event_tx, event_rx) = mpsc::unbounded_channel();
321
322        // Generate unique stream identifier for tracing
323        let stream_id = uuid::Uuid::new_v4();
324
325        // Spawn task to process snapshots
326        tokio::spawn(async move {
327            // Initialize from current snapshot state
328            // This is critical: watch_rx.changed() only fires on FUTURE changes,
329            // so we must capture the current state first to detect removals correctly
330            let initial_snapshot = watch_rx.borrow_and_update().clone();
331
332            // Build initial map: DiscoveryInstanceId -> DiscoveryInstance
333            let initial: std::collections::HashMap<DiscoveryInstanceId, DiscoveryInstance> =
334                initial_snapshot
335                    .instances
336                    .values()
337                    .flat_map(|metadata| metadata.filter(&query))
338                    .map(|instance| (instance.id(), instance))
339                    .collect();
340
341            tracing::debug!(
342                stream_id = %stream_id,
343                initial_count = initial.len(),
344                "Watch started for query={:?}",
345                query
346            );
347
348            // Emit initial Added events (the "list" part of list_and_watch)
349            for instance in initial.values() {
350                tracing::info!(
351                    stream_id = %stream_id,
352                    instance_id = format!("{:x}", instance.instance_id()),
353                    "Emitting initial Added event"
354                );
355                if event_tx
356                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
357                    .is_err()
358                {
359                    tracing::debug!(
360                        stream_id = %stream_id,
361                        "Watch receiver dropped during initial sync"
362                    );
363                    return;
364                }
365            }
366
367            // Track known instances by their unique ID
368            let mut known: HashSet<DiscoveryInstanceId> = initial.into_keys().collect();
369
370            loop {
371                tracing::trace!(
372                    stream_id = %stream_id,
373                    known_count = known.len(),
374                    "Watch loop waiting for changes"
375                );
376
377                // Wait for next snapshot or cancellation
378                let watch_result = if let Some(ref token) = cancel_token {
379                    tokio::select! {
380                        result = watch_rx.changed() => result,
381                        _ = token.cancelled() => {
382                            tracing::info!(
383                                stream_id = %stream_id,
384                                "Watch cancelled via cancel token"
385                            );
386                            break;
387                        }
388                    }
389                } else {
390                    watch_rx.changed().await
391                };
392
393                match watch_result {
394                    Ok(()) => {
395                        // Get latest snapshot
396                        let snapshot = watch_rx.borrow_and_update().clone();
397
398                        // Build current map: DiscoveryInstanceId -> DiscoveryInstance
399                        let current: std::collections::HashMap<
400                            DiscoveryInstanceId,
401                            DiscoveryInstance,
402                        > = snapshot
403                            .instances
404                            .values()
405                            .flat_map(|metadata| metadata.filter(&query))
406                            .map(|instance| (instance.id(), instance))
407                            .collect();
408
409                        tracing::debug!(
410                            stream_id = %stream_id,
411                            seq = snapshot.sequence,
412                            current_count = current.len(),
413                            known_count = known.len(),
414                            "Watch received snapshot update"
415                        );
416
417                        // Compute diff using keys
418                        let current_keys: HashSet<&DiscoveryInstanceId> = current.keys().collect();
419                        let known_keys: HashSet<&DiscoveryInstanceId> = known.iter().collect();
420
421                        let added: Vec<&DiscoveryInstanceId> =
422                            current_keys.difference(&known_keys).copied().collect();
423
424                        let removed: Vec<DiscoveryInstanceId> = known_keys
425                            .difference(&current_keys)
426                            .map(|&id| id.clone())
427                            .collect();
428
429                        // Log diff results (even if empty, for debugging)
430                        if added.is_empty() && removed.is_empty() {
431                            tracing::debug!(
432                                stream_id = %stream_id,
433                                seq = snapshot.sequence,
434                                "Watch snapshot received but no diff detected"
435                            );
436                        } else {
437                            tracing::debug!(
438                                stream_id = %stream_id,
439                                seq = snapshot.sequence,
440                                added = added.len(),
441                                removed = removed.len(),
442                                total = current.len(),
443                                "Watch detected changes"
444                            );
445                        }
446
447                        // Emit Added events
448                        for id in added {
449                            if let Some(instance) = current.get(id) {
450                                tracing::info!(
451                                    stream_id = %stream_id,
452                                    instance_id = format!("{:x}", instance.instance_id()),
453                                    "Emitting Added event"
454                                );
455                                if event_tx
456                                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
457                                    .is_err()
458                                {
459                                    tracing::debug!(
460                                        stream_id = %stream_id,
461                                        "Watch receiver dropped"
462                                    );
463                                    return;
464                                }
465                            }
466                        }
467
468                        // Emit Removed events
469                        for id in removed {
470                            tracing::info!(
471                                stream_id = %stream_id,
472                                id = ?id,
473                                "Emitting Removed event"
474                            );
475                            if event_tx.send(Ok(DiscoveryEvent::Removed(id))).is_err() {
476                                tracing::debug!(stream_id = %stream_id, "Watch receiver dropped");
477                                return;
478                            }
479                        }
480
481                        // Update known set
482                        known = current.into_keys().collect();
483                    }
484                    Err(_) => {
485                        tracing::info!(
486                            stream_id = %stream_id,
487                            "Watch channel closed (daemon stopped)"
488                        );
489                        break;
490                    }
491                }
492            }
493        });
494
495        // Convert receiver to stream
496        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx);
497        Ok(Box::pin(stream))
498    }
499}