Skip to main content

dynamo_runtime/discovery/
kube.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4mod crd;
5mod daemon;
6mod utils;
7
8pub use crd::{DynamoWorkerMetadata, DynamoWorkerMetadataSpec};
9pub use utils::hash_pod_name;
10
11use crd::{apply_cr, build_cr};
12use daemon::DiscoveryDaemon;
13use utils::PodInfo;
14
15use crate::CancellationToken;
16use crate::discovery::{
17    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryMetadata,
18    DiscoveryQuery, DiscoverySpec, DiscoveryStream, MetadataSnapshot,
19};
20use anyhow::Result;
21use async_trait::async_trait;
22use kube::Client as KubeClient;
23use std::collections::HashSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26
27/// Kubernetes-based discovery client
28#[derive(Clone)]
29pub struct KubeDiscoveryClient {
30    instance_id: u64,
31    metadata: Arc<RwLock<DiscoveryMetadata>>,
32    metadata_watch: tokio::sync::watch::Receiver<Arc<MetadataSnapshot>>,
33    kube_client: KubeClient,
34    pod_info: PodInfo,
35}
36
37impl KubeDiscoveryClient {
38    /// Create a new Kubernetes discovery client
39    ///
40    /// # Arguments
41    /// * `metadata` - Shared metadata store (also used by system server)
42    /// * `cancel_token` - Cancellation token for shutdown
43    pub async fn new(
44        metadata: Arc<RwLock<DiscoveryMetadata>>,
45        cancel_token: CancellationToken,
46    ) -> Result<Self> {
47        let pod_info = PodInfo::from_env()?;
48        let instance_id = hash_pod_name(&pod_info.pod_name);
49
50        tracing::info!(
51            "Initializing KubeDiscoveryClient: pod_name={}, instance_id={:x}, namespace={}, pod_uid={}",
52            pod_info.pod_name,
53            instance_id,
54            pod_info.pod_namespace,
55            pod_info.pod_uid
56        );
57
58        let kube_client = KubeClient::try_default()
59            .await
60            .map_err(|e| anyhow::anyhow!("Failed to create Kubernetes client: {}", e))?;
61
62        // Create watch channel with initial empty snapshot
63        let (watch_tx, watch_rx) = tokio::sync::watch::channel(Arc::new(MetadataSnapshot::empty()));
64
65        // Create and spawn daemon
66        let daemon = DiscoveryDaemon::new(kube_client.clone(), pod_info.clone(), cancel_token)?;
67
68        tokio::spawn(async move {
69            if let Err(e) = daemon.run(watch_tx).await {
70                tracing::error!("Discovery daemon failed: {}", e);
71            }
72        });
73
74        tracing::info!("Discovery daemon started");
75
76        Ok(Self {
77            instance_id,
78            metadata,
79            metadata_watch: watch_rx,
80            kube_client,
81            pod_info,
82        })
83    }
84}
85
86#[async_trait]
87impl Discovery for KubeDiscoveryClient {
88    fn instance_id(&self) -> u64 {
89        self.instance_id
90    }
91
92    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
93        let instance_id = self.instance_id();
94        let instance = spec.with_instance_id(instance_id);
95
96        tracing::debug!(
97            "Registering instance: {:?} with instance_id={:x}",
98            instance,
99            instance_id
100        );
101
102        // Write to local metadata and persist to CR
103        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
104        let mut metadata = self.metadata.write().await;
105
106        // Clone state for rollback in case CR persistence fails
107        let original_state = metadata.clone();
108
109        match &instance {
110            DiscoveryInstance::Endpoint(inst) => {
111                tracing::info!(
112                    "Registering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
113                    inst.namespace,
114                    inst.component,
115                    inst.endpoint,
116                    instance_id
117                );
118                metadata.register_endpoint(instance.clone())?;
119            }
120            DiscoveryInstance::Model {
121                namespace,
122                component,
123                endpoint,
124                ..
125            } => {
126                tracing::info!(
127                    "Registering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
128                    namespace,
129                    component,
130                    endpoint,
131                    instance_id
132                );
133                metadata.register_model_card(instance.clone())?;
134            }
135            DiscoveryInstance::EventChannel {
136                namespace,
137                component,
138                topic,
139                ..
140            } => {
141                tracing::info!(
142                    "Registering event channel: namespace={}, component={}, topic={}, instance_id={:x}",
143                    namespace,
144                    component,
145                    topic,
146                    instance_id
147                );
148                metadata.register_event_channel(instance.clone())?;
149            }
150        }
151
152        // Build and apply the CR with the updated metadata
153        // This persists the metadata to Kubernetes for other pods to discover
154        let cr = build_cr(&self.pod_info.pod_name, &self.pod_info.pod_uid, &metadata)?;
155
156        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
157            // Rollback local state on CR persistence failure
158            tracing::warn!(
159                "Failed to persist metadata to CR, rolling back local state: {}",
160                e
161            );
162            *metadata = original_state;
163            return Err(e);
164        }
165
166        tracing::debug!("Persisted metadata to DynamoWorkerMetadata CR");
167
168        Ok(instance)
169    }
170
171    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
172        let instance_id = self.instance_id();
173
174        // Write to local metadata and persist to CR
175        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
176        let mut metadata = self.metadata.write().await;
177
178        // Clone state for rollback in case CR persistence fails
179        let original_state = metadata.clone();
180
181        match &instance {
182            DiscoveryInstance::Endpoint(inst) => {
183                tracing::info!(
184                    "Unregistering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
185                    inst.namespace,
186                    inst.component,
187                    inst.endpoint,
188                    instance_id
189                );
190                metadata.unregister_endpoint(&instance)?;
191            }
192            DiscoveryInstance::Model {
193                namespace,
194                component,
195                endpoint,
196                ..
197            } => {
198                tracing::info!(
199                    "Unregistering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
200                    namespace,
201                    component,
202                    endpoint,
203                    instance_id
204                );
205                metadata.unregister_model_card(&instance)?;
206            }
207            DiscoveryInstance::EventChannel {
208                namespace,
209                component,
210                topic,
211                ..
212            } => {
213                tracing::info!(
214                    "Unregistering event channel: namespace={}, component={}, topic={}, instance_id={:x}",
215                    namespace,
216                    component,
217                    topic,
218                    instance_id
219                );
220                metadata.unregister_event_channel(&instance)?;
221            }
222        }
223
224        // Build and apply the CR with the updated metadata
225        // This persists the removal to Kubernetes for other pods to see
226        let cr = build_cr(&self.pod_info.pod_name, &self.pod_info.pod_uid, &metadata)?;
227
228        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
229            // Rollback local state on CR persistence failure
230            tracing::warn!(
231                "Failed to persist metadata removal to CR, rolling back local state: {}",
232                e
233            );
234            *metadata = original_state;
235            return Err(e);
236        }
237
238        tracing::debug!("Persisted metadata removal to DynamoWorkerMetadata CR");
239
240        Ok(())
241    }
242
243    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
244        tracing::debug!("KubeDiscoveryClient::list called with query={:?}", query);
245
246        // Get current snapshot (may be empty if daemon hasn't fetched yet)
247        let snapshot = self.metadata_watch.borrow().clone();
248
249        tracing::debug!(
250            "List using snapshot seq={} with {} instances",
251            snapshot.sequence,
252            snapshot.instances.len()
253        );
254
255        // Filter snapshot by query
256        let instances = snapshot.filter(&query);
257
258        tracing::info!(
259            "KubeDiscoveryClient::list returning {} instances for query={:?}",
260            instances.len(),
261            query
262        );
263
264        Ok(instances)
265    }
266
267    async fn list_and_watch(
268        &self,
269        query: DiscoveryQuery,
270        cancel_token: Option<CancellationToken>,
271    ) -> Result<DiscoveryStream> {
272        use tokio::sync::mpsc;
273
274        tracing::info!(
275            "KubeDiscoveryClient::list_and_watch started for query={:?}",
276            query
277        );
278
279        // Clone the watch receiver
280        let mut watch_rx = self.metadata_watch.clone();
281
282        // Create output stream
283        let (event_tx, event_rx) = mpsc::unbounded_channel();
284
285        // Generate unique stream identifier for tracing
286        let stream_id = uuid::Uuid::new_v4();
287
288        // Spawn task to process snapshots
289        tokio::spawn(async move {
290            // Initialize from current snapshot state
291            // This is critical: watch_rx.changed() only fires on FUTURE changes,
292            // so we must capture the current state first to detect removals correctly
293            let initial_snapshot = watch_rx.borrow_and_update().clone();
294
295            // Build initial map: DiscoveryInstanceId -> DiscoveryInstance
296            let initial: std::collections::HashMap<DiscoveryInstanceId, DiscoveryInstance> =
297                initial_snapshot
298                    .instances
299                    .values()
300                    .flat_map(|metadata| metadata.filter(&query))
301                    .map(|instance| (instance.id(), instance))
302                    .collect();
303
304            tracing::debug!(
305                stream_id = %stream_id,
306                initial_count = initial.len(),
307                "Watch started for query={:?}",
308                query
309            );
310
311            // Emit initial Added events (the "list" part of list_and_watch)
312            for instance in initial.values() {
313                tracing::info!(
314                    stream_id = %stream_id,
315                    instance_id = format!("{:x}", instance.instance_id()),
316                    "Emitting initial Added event"
317                );
318                if event_tx
319                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
320                    .is_err()
321                {
322                    tracing::debug!(
323                        stream_id = %stream_id,
324                        "Watch receiver dropped during initial sync"
325                    );
326                    return;
327                }
328            }
329
330            // Track known instances by their unique ID
331            let mut known: HashSet<DiscoveryInstanceId> = initial.into_keys().collect();
332
333            loop {
334                tracing::trace!(
335                    stream_id = %stream_id,
336                    known_count = known.len(),
337                    "Watch loop waiting for changes"
338                );
339
340                // Wait for next snapshot or cancellation
341                let watch_result = if let Some(ref token) = cancel_token {
342                    tokio::select! {
343                        result = watch_rx.changed() => result,
344                        _ = token.cancelled() => {
345                            tracing::info!(
346                                stream_id = %stream_id,
347                                "Watch cancelled via cancel token"
348                            );
349                            break;
350                        }
351                    }
352                } else {
353                    watch_rx.changed().await
354                };
355
356                match watch_result {
357                    Ok(()) => {
358                        // Get latest snapshot
359                        let snapshot = watch_rx.borrow_and_update().clone();
360
361                        // Build current map: DiscoveryInstanceId -> DiscoveryInstance
362                        let current: std::collections::HashMap<
363                            DiscoveryInstanceId,
364                            DiscoveryInstance,
365                        > = snapshot
366                            .instances
367                            .values()
368                            .flat_map(|metadata| metadata.filter(&query))
369                            .map(|instance| (instance.id(), instance))
370                            .collect();
371
372                        tracing::debug!(
373                            stream_id = %stream_id,
374                            seq = snapshot.sequence,
375                            current_count = current.len(),
376                            known_count = known.len(),
377                            "Watch received snapshot update"
378                        );
379
380                        // Compute diff using keys
381                        let current_keys: HashSet<&DiscoveryInstanceId> = current.keys().collect();
382                        let known_keys: HashSet<&DiscoveryInstanceId> = known.iter().collect();
383
384                        let added: Vec<&DiscoveryInstanceId> =
385                            current_keys.difference(&known_keys).copied().collect();
386
387                        let removed: Vec<DiscoveryInstanceId> = known_keys
388                            .difference(&current_keys)
389                            .map(|&id| id.clone())
390                            .collect();
391
392                        // Log diff results (even if empty, for debugging)
393                        if added.is_empty() && removed.is_empty() {
394                            tracing::debug!(
395                                stream_id = %stream_id,
396                                seq = snapshot.sequence,
397                                "Watch snapshot received but no diff detected"
398                            );
399                        } else {
400                            tracing::debug!(
401                                stream_id = %stream_id,
402                                seq = snapshot.sequence,
403                                added = added.len(),
404                                removed = removed.len(),
405                                total = current.len(),
406                                "Watch detected changes"
407                            );
408                        }
409
410                        // Emit Added events
411                        for id in added {
412                            if let Some(instance) = current.get(id) {
413                                tracing::info!(
414                                    stream_id = %stream_id,
415                                    instance_id = format!("{:x}", instance.instance_id()),
416                                    "Emitting Added event"
417                                );
418                                if event_tx
419                                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
420                                    .is_err()
421                                {
422                                    tracing::debug!(
423                                        stream_id = %stream_id,
424                                        "Watch receiver dropped"
425                                    );
426                                    return;
427                                }
428                            }
429                        }
430
431                        // Emit Removed events
432                        for id in removed {
433                            tracing::info!(
434                                stream_id = %stream_id,
435                                id = ?id,
436                                "Emitting Removed event"
437                            );
438                            if event_tx.send(Ok(DiscoveryEvent::Removed(id))).is_err() {
439                                tracing::debug!(stream_id = %stream_id, "Watch receiver dropped");
440                                return;
441                            }
442                        }
443
444                        // Update known set
445                        known = current.into_keys().collect();
446                    }
447                    Err(_) => {
448                        tracing::info!(
449                            stream_id = %stream_id,
450                            "Watch channel closed (daemon stopped)"
451                        );
452                        break;
453                    }
454                }
455            }
456        });
457
458        // Convert receiver to stream
459        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx);
460        Ok(Box::pin(stream))
461    }
462}