dynamo_runtime/discovery/
kube.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4mod crd;
5mod daemon;
6mod utils;
7
8pub use crd::{DynamoWorkerMetadata, DynamoWorkerMetadataSpec};
9pub use utils::hash_pod_name;
10
11use crd::{apply_cr, build_cr};
12use daemon::DiscoveryDaemon;
13use utils::PodInfo;
14
15use crate::CancellationToken;
16use crate::discovery::{
17    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryMetadata,
18    DiscoveryQuery, DiscoverySpec, DiscoveryStream, MetadataSnapshot,
19};
20use anyhow::Result;
21use async_trait::async_trait;
22use kube::Client as KubeClient;
23use std::collections::HashSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26
27/// Kubernetes-based discovery client
28#[derive(Clone)]
29pub struct KubeDiscoveryClient {
30    instance_id: u64,
31    metadata: Arc<RwLock<DiscoveryMetadata>>,
32    metadata_watch: tokio::sync::watch::Receiver<Arc<MetadataSnapshot>>,
33    kube_client: KubeClient,
34    pod_info: PodInfo,
35}
36
37impl KubeDiscoveryClient {
38    /// Create a new Kubernetes discovery client
39    ///
40    /// # Arguments
41    /// * `metadata` - Shared metadata store (also used by system server)
42    /// * `cancel_token` - Cancellation token for shutdown
43    pub async fn new(
44        metadata: Arc<RwLock<DiscoveryMetadata>>,
45        cancel_token: CancellationToken,
46    ) -> Result<Self> {
47        let pod_info = PodInfo::from_env()?;
48        let instance_id = hash_pod_name(&pod_info.pod_name);
49
50        tracing::info!(
51            "Initializing KubeDiscoveryClient: pod_name={}, instance_id={:x}, namespace={}, pod_uid={}",
52            pod_info.pod_name,
53            instance_id,
54            pod_info.pod_namespace,
55            pod_info.pod_uid
56        );
57
58        let kube_client = KubeClient::try_default()
59            .await
60            .map_err(|e| anyhow::anyhow!("Failed to create Kubernetes client: {}", e))?;
61
62        // Create watch channel with initial empty snapshot
63        let (watch_tx, watch_rx) = tokio::sync::watch::channel(Arc::new(MetadataSnapshot::empty()));
64
65        // Create and spawn daemon
66        let daemon = DiscoveryDaemon::new(kube_client.clone(), pod_info.clone(), cancel_token)?;
67
68        tokio::spawn(async move {
69            if let Err(e) = daemon.run(watch_tx).await {
70                tracing::error!("Discovery daemon failed: {}", e);
71            }
72        });
73
74        tracing::info!("Discovery daemon started");
75
76        Ok(Self {
77            instance_id,
78            metadata,
79            metadata_watch: watch_rx,
80            kube_client,
81            pod_info,
82        })
83    }
84}
85
86#[async_trait]
87impl Discovery for KubeDiscoveryClient {
88    fn instance_id(&self) -> u64 {
89        self.instance_id
90    }
91
92    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
93        let instance_id = self.instance_id();
94        let instance = spec.with_instance_id(instance_id);
95
96        tracing::debug!(
97            "Registering instance: {:?} with instance_id={:x}",
98            instance,
99            instance_id
100        );
101
102        // Write to local metadata and persist to CR
103        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
104        let mut metadata = self.metadata.write().await;
105
106        // Clone state for rollback in case CR persistence fails
107        let original_state = metadata.clone();
108
109        match &instance {
110            DiscoveryInstance::Endpoint(inst) => {
111                tracing::info!(
112                    "Registering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
113                    inst.namespace,
114                    inst.component,
115                    inst.endpoint,
116                    instance_id
117                );
118                metadata.register_endpoint(instance.clone())?;
119            }
120            DiscoveryInstance::Model {
121                namespace,
122                component,
123                endpoint,
124                ..
125            } => {
126                tracing::info!(
127                    "Registering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
128                    namespace,
129                    component,
130                    endpoint,
131                    instance_id
132                );
133                metadata.register_model_card(instance.clone())?;
134            }
135        }
136
137        // Build and apply the CR with the updated metadata
138        // This persists the metadata to Kubernetes for other pods to discover
139        let cr = build_cr(&self.pod_info.pod_name, &self.pod_info.pod_uid, &metadata)?;
140
141        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
142            // Rollback local state on CR persistence failure
143            tracing::warn!(
144                "Failed to persist metadata to CR, rolling back local state: {}",
145                e
146            );
147            *metadata = original_state;
148            return Err(e);
149        }
150
151        tracing::debug!("Persisted metadata to DynamoWorkerMetadata CR");
152
153        Ok(instance)
154    }
155
156    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
157        let instance_id = self.instance_id();
158
159        // Write to local metadata and persist to CR
160        // IMPORTANT: Hold the write lock across the CR write to prevent race conditions
161        let mut metadata = self.metadata.write().await;
162
163        // Clone state for rollback in case CR persistence fails
164        let original_state = metadata.clone();
165
166        match &instance {
167            DiscoveryInstance::Endpoint(inst) => {
168                tracing::info!(
169                    "Unregistering endpoint: namespace={}, component={}, endpoint={}, instance_id={:x}",
170                    inst.namespace,
171                    inst.component,
172                    inst.endpoint,
173                    instance_id
174                );
175                metadata.unregister_endpoint(&instance)?;
176            }
177            DiscoveryInstance::Model {
178                namespace,
179                component,
180                endpoint,
181                ..
182            } => {
183                tracing::info!(
184                    "Unregistering model card: namespace={}, component={}, endpoint={}, instance_id={:x}",
185                    namespace,
186                    component,
187                    endpoint,
188                    instance_id
189                );
190                metadata.unregister_model_card(&instance)?;
191            }
192        }
193
194        // Build and apply the CR with the updated metadata
195        // This persists the removal to Kubernetes for other pods to see
196        let cr = build_cr(&self.pod_info.pod_name, &self.pod_info.pod_uid, &metadata)?;
197
198        if let Err(e) = apply_cr(&self.kube_client, &self.pod_info.pod_namespace, &cr).await {
199            // Rollback local state on CR persistence failure
200            tracing::warn!(
201                "Failed to persist metadata removal to CR, rolling back local state: {}",
202                e
203            );
204            *metadata = original_state;
205            return Err(e);
206        }
207
208        tracing::debug!("Persisted metadata removal to DynamoWorkerMetadata CR");
209
210        Ok(())
211    }
212
213    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
214        tracing::debug!("KubeDiscoveryClient::list called with query={:?}", query);
215
216        // Get current snapshot (may be empty if daemon hasn't fetched yet)
217        let snapshot = self.metadata_watch.borrow().clone();
218
219        tracing::debug!(
220            "List using snapshot seq={} with {} instances",
221            snapshot.sequence,
222            snapshot.instances.len()
223        );
224
225        // Filter snapshot by query
226        let instances = snapshot.filter(&query);
227
228        tracing::info!(
229            "KubeDiscoveryClient::list returning {} instances for query={:?}",
230            instances.len(),
231            query
232        );
233
234        Ok(instances)
235    }
236
237    async fn list_and_watch(
238        &self,
239        query: DiscoveryQuery,
240        cancel_token: Option<CancellationToken>,
241    ) -> Result<DiscoveryStream> {
242        use tokio::sync::mpsc;
243
244        tracing::info!(
245            "KubeDiscoveryClient::list_and_watch started for query={:?}",
246            query
247        );
248
249        // Clone the watch receiver
250        let mut watch_rx = self.metadata_watch.clone();
251
252        // Create output stream
253        let (event_tx, event_rx) = mpsc::unbounded_channel();
254
255        // Generate unique stream identifier for tracing
256        let stream_id = uuid::Uuid::new_v4();
257
258        // Spawn task to process snapshots
259        tokio::spawn(async move {
260            // Initialize from current snapshot state
261            // This is critical: watch_rx.changed() only fires on FUTURE changes,
262            // so we must capture the current state first to detect removals correctly
263            let initial_snapshot = watch_rx.borrow_and_update().clone();
264
265            // Build initial map: DiscoveryInstanceId -> DiscoveryInstance
266            let initial: std::collections::HashMap<DiscoveryInstanceId, DiscoveryInstance> =
267                initial_snapshot
268                    .instances
269                    .values()
270                    .flat_map(|metadata| metadata.filter(&query))
271                    .map(|instance| (instance.id(), instance))
272                    .collect();
273
274            tracing::debug!(
275                stream_id = %stream_id,
276                initial_count = initial.len(),
277                "Watch started for query={:?}",
278                query
279            );
280
281            // Emit initial Added events (the "list" part of list_and_watch)
282            for instance in initial.values() {
283                tracing::info!(
284                    stream_id = %stream_id,
285                    instance_id = format!("{:x}", instance.instance_id()),
286                    "Emitting initial Added event"
287                );
288                if event_tx
289                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
290                    .is_err()
291                {
292                    tracing::debug!(
293                        stream_id = %stream_id,
294                        "Watch receiver dropped during initial sync"
295                    );
296                    return;
297                }
298            }
299
300            // Track known instances by their unique ID
301            let mut known: HashSet<DiscoveryInstanceId> = initial.into_keys().collect();
302
303            loop {
304                tracing::trace!(
305                    stream_id = %stream_id,
306                    known_count = known.len(),
307                    "Watch loop waiting for changes"
308                );
309
310                // Wait for next snapshot or cancellation
311                let watch_result = if let Some(ref token) = cancel_token {
312                    tokio::select! {
313                        result = watch_rx.changed() => result,
314                        _ = token.cancelled() => {
315                            tracing::info!(
316                                stream_id = %stream_id,
317                                "Watch cancelled via cancel token"
318                            );
319                            break;
320                        }
321                    }
322                } else {
323                    watch_rx.changed().await
324                };
325
326                match watch_result {
327                    Ok(()) => {
328                        // Get latest snapshot
329                        let snapshot = watch_rx.borrow_and_update().clone();
330
331                        // Build current map: DiscoveryInstanceId -> DiscoveryInstance
332                        let current: std::collections::HashMap<
333                            DiscoveryInstanceId,
334                            DiscoveryInstance,
335                        > = snapshot
336                            .instances
337                            .values()
338                            .flat_map(|metadata| metadata.filter(&query))
339                            .map(|instance| (instance.id(), instance))
340                            .collect();
341
342                        tracing::debug!(
343                            stream_id = %stream_id,
344                            seq = snapshot.sequence,
345                            current_count = current.len(),
346                            known_count = known.len(),
347                            "Watch received snapshot update"
348                        );
349
350                        // Compute diff using keys
351                        let current_keys: HashSet<&DiscoveryInstanceId> = current.keys().collect();
352                        let known_keys: HashSet<&DiscoveryInstanceId> = known.iter().collect();
353
354                        let added: Vec<&DiscoveryInstanceId> =
355                            current_keys.difference(&known_keys).copied().collect();
356
357                        let removed: Vec<DiscoveryInstanceId> = known_keys
358                            .difference(&current_keys)
359                            .map(|&id| id.clone())
360                            .collect();
361
362                        // Log diff results (even if empty, for debugging)
363                        if added.is_empty() && removed.is_empty() {
364                            tracing::debug!(
365                                stream_id = %stream_id,
366                                seq = snapshot.sequence,
367                                "Watch snapshot received but no diff detected"
368                            );
369                        } else {
370                            tracing::debug!(
371                                stream_id = %stream_id,
372                                seq = snapshot.sequence,
373                                added = added.len(),
374                                removed = removed.len(),
375                                total = current.len(),
376                                "Watch detected changes"
377                            );
378                        }
379
380                        // Emit Added events
381                        for id in added {
382                            if let Some(instance) = current.get(id) {
383                                tracing::info!(
384                                    stream_id = %stream_id,
385                                    instance_id = format!("{:x}", instance.instance_id()),
386                                    "Emitting Added event"
387                                );
388                                if event_tx
389                                    .send(Ok(DiscoveryEvent::Added(instance.clone())))
390                                    .is_err()
391                                {
392                                    tracing::debug!(
393                                        stream_id = %stream_id,
394                                        "Watch receiver dropped"
395                                    );
396                                    return;
397                                }
398                            }
399                        }
400
401                        // Emit Removed events
402                        for id in removed {
403                            tracing::info!(
404                                stream_id = %stream_id,
405                                id = ?id,
406                                "Emitting Removed event"
407                            );
408                            if event_tx.send(Ok(DiscoveryEvent::Removed(id))).is_err() {
409                                tracing::debug!(stream_id = %stream_id, "Watch receiver dropped");
410                                return;
411                            }
412                        }
413
414                        // Update known set
415                        known = current.into_keys().collect();
416                    }
417                    Err(_) => {
418                        tracing::info!(
419                            stream_id = %stream_id,
420                            "Watch channel closed (daemon stopped)"
421                        );
422                        break;
423                    }
424                }
425            }
426        });
427
428        // Convert receiver to stream
429        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx);
430        Ok(Box::pin(stream))
431    }
432}