use crate::CancellationToken;
use crate::discovery::{DiscoveryMetadata, MetadataSnapshot};
use anyhow::Result;
use futures::StreamExt;
use k8s_openapi::api::discovery::v1::EndpointSlice;
use kube::{
Api, Client as KubeClient,
runtime::{WatchStreamExt, reflector, watcher, watcher::Config},
};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use super::utils::{PodInfo, extract_endpoint_info, hash_pod_name};
/// Period, in milliseconds, of the ticker that drives snapshot aggregation in
/// `DiscoveryDaemon::run`.
const SNAPSHOT_POLL_INTERVAL_MS: u64 = 5000;
/// Maximum number of per-endpoint metadata fetches polled concurrently while
/// building a single snapshot.
const MAX_CONCURRENT_FETCHES: usize = 20;
/// Request timeout, in seconds, configured on the HTTP client used to fetch
/// pod metadata.
const METADATA_FETCH_TIMEOUT_SECS: u64 = 5;
/// Background worker that watches Kubernetes `EndpointSlice`s, fetches each
/// ready endpoint's metadata over HTTP and publishes aggregated snapshots.
///
/// The type is `Clone`; the metadata cache sits behind an `Arc`, so clones
/// share one cache.
#[derive(Clone)]
pub(super) struct DiscoveryDaemon {
/// Kubernetes client used to build the namespaced `EndpointSlice` API.
kube_client: KubeClient,
/// HTTP client used to query each pod's `/metadata` endpoint.
http_client: reqwest::Client,
/// Fetched metadata keyed by instance id (the hash of the pod name).
cache: Arc<RwLock<HashMap<u64, Arc<DiscoveryMetadata>>>>,
/// Identity of the local pod; its `pod_namespace` selects the namespace to watch.
pod_info: PodInfo,
/// Token whose cancellation stops the `run` loop.
cancel_token: CancellationToken,
}
impl DiscoveryDaemon {
/// Builds a daemon with an empty metadata cache.
///
/// The internal HTTP client is configured with a request timeout of
/// `METADATA_FETCH_TIMEOUT_SECS` seconds.
///
/// # Errors
///
/// Returns an error if the HTTP client cannot be constructed.
pub fn new(
    kube_client: KubeClient,
    pod_info: PodInfo,
    cancel_token: CancellationToken,
) -> Result<Self> {
    let fetch_timeout = std::time::Duration::from_secs(METADATA_FETCH_TIMEOUT_SECS);
    let http_client = match reqwest::Client::builder().timeout(fetch_timeout).build() {
        Ok(client) => client,
        Err(e) => return Err(anyhow::anyhow!("Failed to create HTTP client: {}", e)),
    };
    let cache = Arc::new(RwLock::new(HashMap::new()));

    Ok(Self {
        kube_client,
        http_client,
        cache,
        pod_info,
        cancel_token,
    })
}
/// Runs the discovery loop until cancellation or until no subscriber remains.
///
/// A reflector task keeps a local store of the `EndpointSlice`s in the pod's
/// namespace that carry both discovery labels. Every
/// `SNAPSHOT_POLL_INTERVAL_MS` milliseconds the store is turned into a
/// [`MetadataSnapshot`]; when the set of instance ids differs from the last
/// published one, cache entries of removed instances are pruned and the
/// snapshot is sent on `watch_tx`.
///
/// The loop ends when the cancellation token fires or when sending fails
/// because every receiver has been dropped. The reflector task is aborted
/// before returning so the watch does not outlive the daemon.
///
/// # Errors
///
/// Currently always returns `Ok(())`; aggregation failures are logged and the
/// loop continues.
pub async fn run(
    self,
    watch_tx: tokio::sync::watch::Sender<Arc<MetadataSnapshot>>,
) -> Result<()> {
    tracing::info!("Discovery daemon starting");

    let endpoint_slices: Api<EndpointSlice> =
        Api::namespaced(self.kube_client.clone(), &self.pod_info.pod_namespace);
    let (reader, writer) = reflector::store();

    // `Config::labels` replaces the selector on every call, so both
    // requirements must be given as one comma-separated selector to be ANDed.
    let watch_config = Config::default().labels(
        "nvidia.com/dynamo-discovery-backend=kubernetes,nvidia.com/dynamo-discovery-enabled=true",
    );
    tracing::info!(
        "Daemon watching EndpointSlices with labels: nvidia.com/dynamo-discovery-backend=kubernetes, nvidia.com/dynamo-discovery-enabled=true"
    );

    let reflector_stream = reflector(writer, watcher(endpoint_slices, watch_config))
        .default_backoff()
        .touched_objects()
        .for_each(|res| {
            match res {
                Ok(obj) => {
                    tracing::debug!(
                        slice_name = obj.metadata.name.as_deref().unwrap_or("unknown"),
                        "Daemon reflector updated EndpointSlice"
                    );
                }
                Err(e) => {
                    tracing::warn!("Daemon reflector error: {}", e);
                }
            }
            futures::future::ready(())
        });
    // Keep the handle so the watch can be torn down when the loop exits.
    let reflector_task = tokio::spawn(reflector_stream);

    // Renders instance ids as a comma-separated list of lowercase hex values.
    let format_ids = |ids: &[u64]| -> String {
        ids.iter()
            .map(|id| format!("{:x}", id))
            .collect::<Vec<_>>()
            .join(", ")
    };

    let mut sequence = 0u64;
    let mut prev_instance_ids: HashSet<u64> = HashSet::new();
    let mut interval =
        tokio::time::interval(std::time::Duration::from_millis(SNAPSHOT_POLL_INTERVAL_MS));
    // A slow aggregation pass should not be followed by a burst of catch-up ticks.
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    loop {
        tokio::select! {
            _ = interval.tick() => {
                match self.aggregate_snapshot(&reader, sequence).await {
                    Ok(snapshot) => {
                        let current_instance_ids: HashSet<u64> =
                            snapshot.instances.keys().copied().collect();

                        if current_instance_ids != prev_instance_ids {
                            let added: Vec<u64> = current_instance_ids
                                .difference(&prev_instance_ids)
                                .copied()
                                .collect();
                            let removed: Vec<u64> = prev_instance_ids
                                .difference(&current_instance_ids)
                                .copied()
                                .collect();

                            tracing::info!(
                                "Daemon snapshot (seq={}): instances changed, total={}, added=[{}], removed=[{}]",
                                sequence,
                                current_instance_ids.len(),
                                format_ids(&added),
                                format_ids(&removed)
                            );

                            // Drop cached metadata of vanished instances so a
                            // returning pod is fetched afresh.
                            if !removed.is_empty() {
                                self.prune_cache(&removed).await;
                            }

                            if watch_tx.send(Arc::new(snapshot)).is_err() {
                                tracing::debug!("No watch subscribers, daemon stopping");
                                break;
                            }
                            prev_instance_ids = current_instance_ids;
                        } else {
                            tracing::trace!(
                                "Daemon snapshot (seq={}): no changes, {} instances",
                                sequence,
                                current_instance_ids.len()
                            );
                        }

                        // The sequence advances on every successful pass, changed or not.
                        sequence += 1;
                    }
                    Err(e) => {
                        tracing::error!("Failed to aggregate snapshot: {}", e);
                    }
                }
            }
            _ = self.cancel_token.cancelled() => {
                tracing::info!("Discovery daemon received cancellation");
                break;
            }
        }
    }

    // Stop the background watch; nothing reads the store after this point.
    reflector_task.abort();

    tracing::info!("Discovery daemon stopped");
    Ok(())
}
/// Builds one [`MetadataSnapshot`] from the current contents of the reflector store.
///
/// Endpoint tuples are extracted from every stored `EndpointSlice`, then their
/// metadata is fetched with at most `MAX_CONCURRENT_FETCHES` requests in
/// flight. Endpoints whose fetch fails are logged and omitted from the snapshot.
///
/// # Errors
///
/// Currently always returns `Ok`; individual fetch failures are not propagated.
async fn aggregate_snapshot(
    &self,
    reader: &reflector::Store<EndpointSlice>,
    sequence: u64,
) -> Result<MetadataSnapshot> {
    let started_at = std::time::Instant::now();

    // Each tuple is (instance_id, pod_name, pod_ip, system_port).
    let mut targets: Vec<(u64, String, String, u16)> = Vec::new();
    for slice in reader.state().iter() {
        targets.extend(extract_endpoint_info(slice.as_ref()));
    }
    tracing::trace!("Daemon found {} ready endpoints to fetch", targets.len());

    let pending = targets
        .into_iter()
        .map(|(instance_id, pod_name, pod_ip, system_port)| {
            // Each future owns its own handle; clones share the same cache.
            let daemon = self.clone();
            async move {
                let outcome = daemon.fetch_metadata(&pod_name, &pod_ip, system_port).await;
                if let Err(e) = &outcome {
                    tracing::warn!(
                        "Failed to fetch metadata for pod {} (instance_id={:x}): {}",
                        pod_name,
                        instance_id,
                        e
                    );
                }
                outcome.ok().map(|metadata| (instance_id, metadata))
            }
        });

    let fetched: Vec<Option<(u64, Arc<DiscoveryMetadata>)>> = futures::stream::iter(pending)
        .buffer_unordered(MAX_CONCURRENT_FETCHES)
        .collect()
        .await;

    let mut instances: HashMap<u64, Arc<DiscoveryMetadata>> = HashMap::new();
    instances.extend(fetched.into_iter().flatten());

    let elapsed = started_at.elapsed();
    tracing::trace!(
        "Daemon snapshot complete (seq={}): {} instances in {:?}",
        sequence,
        instances.len(),
        elapsed
    );

    let snapshot = MetadataSnapshot {
        instances,
        sequence,
        timestamp: std::time::Instant::now(),
    };
    Ok(snapshot)
}
async fn fetch_metadata(
&self,
pod_name: &str,
pod_ip: &str,
system_port: u16,
) -> Result<Arc<DiscoveryMetadata>> {
let instance_id = hash_pod_name(pod_name);
{
let cache = self.cache.read().await;
if let Some(cached) = cache.get(&instance_id) {
tracing::trace!(
"Cache hit for pod_name={}, instance_id={:x}",
pod_name,
instance_id
);
return Ok(cached.clone());
}
}
let url = format!("http://{}:{}/metadata", pod_ip, system_port);
tracing::debug!("Fetching metadata from {url}");
let response = self
.http_client
.get(&url)
.send()
.await
.map_err(|e| anyhow::anyhow!("Failed to fetch metadata from {}: {}", url, e))?;
let metadata: DiscoveryMetadata = response
.json()
.await
.map_err(|e| anyhow::anyhow!("Failed to parse metadata from {}: {}", url, e))?;
let metadata = Arc::new(metadata);
{
let mut cache = self.cache.write().await;
if let Some(existing) = cache.get(&instance_id) {
tracing::debug!(
"Another task cached metadata for instance_id={:x} while we were fetching",
instance_id
);
return Ok(existing.clone());
}
cache.insert(instance_id, metadata.clone());
tracing::debug!(
"Cached metadata for pod_name={}, instance_id={:x}",
pod_name,
instance_id
);
}
Ok(metadata)
}
/// Evicts the cache entries for the given instance ids.
///
/// Takes the cache write lock once for the whole batch and logs each id that
/// was actually present; ids without an entry are skipped silently.
async fn prune_cache(&self, removed_ids: &[u64]) {
    let mut guard = self.cache.write().await;
    for &instance_id in removed_ids {
        let evicted = guard.remove(&instance_id);
        if evicted.is_none() {
            continue;
        }
        tracing::debug!("Pruned cache for removed instance_id={:x}", instance_id);
    }
}
}