use std::collections::{HashMap, HashSet};
use tokio::sync::watch;
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::kv_router::protocols::WorkerId;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::ModelDeploymentCard;
/// Watch receiver whose value is a map from worker ID to that worker's
/// [`ModelRuntimeConfig`].
pub type RuntimeConfigWatch = watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>;
/// Builds a watch channel publishing the runtime configs of the workers that are
/// currently available on `endpoint`.
///
/// Two sources are joined:
/// - the instance IDs held by the endpoint client's `instance_avail_watcher()`;
/// - the `runtime_config` of each [`ModelDeploymentCard`] delivered by the discovery
///   `list_and_watch` stream for this endpoint (collected by `watch_and_extract_field`).
///
/// A worker appears in the published map only when its ID is present in both sources.
/// A background task recomputes the join every time either source reports a change and
/// sends a new value only when it differs from the one last published.
///
/// The background task exits when the token obtained from `primary_token()` is
/// cancelled, when either source channel is closed, or when every receiver of the
/// returned watch has been dropped.
///
/// # Errors
///
/// Returns an error if creating the endpoint client fails or if the discovery
/// `list_and_watch` call fails.
pub async fn runtime_config_watch(endpoint: &Endpoint) -> anyhow::Result<RuntimeConfigWatch> {
    let component = endpoint.component();
    let cancel_token = component.drt().primary_token();
    let client = endpoint.client().await?;
    let mut instance_ids_rx = client.instance_avail_watcher();

    let discovery = component.drt().discovery();
    let eid = endpoint.id();
    let stream = discovery
        .list_and_watch(
            DiscoveryQuery::EndpointModels {
                namespace: eid.namespace.clone(),
                component: eid.component.clone(),
                endpoint: eid.name.clone(),
            },
            Some(cancel_token.clone()),
        )
        .await?;
    let mut configs_rx =
        watch_and_extract_field(stream, |card: ModelDeploymentCard| card.runtime_config);

    let (tx, rx) = watch::channel(HashMap::new());

    tokio::spawn(async move {
        loop {
            tokio::select! {
                _ = cancel_token.cancelled() => break,
                // Stop as soon as nobody is listening any more instead of lingering
                // until the next upstream change makes `send` fail.
                _ = tx.closed() => break,
                result = instance_ids_rx.changed() => { if result.is_err() { break; } }
                result = configs_rx.changed() => { if result.is_err() { break; } }
            }

            // Snapshot the available instance IDs; both receivers are marked as seen
            // so a single wake-up consumes pending changes from either source.
            let instances: HashSet<WorkerId> = instance_ids_rx
                .borrow_and_update()
                .iter()
                .copied()
                .collect();

            // Keep the borrow of the configs map confined to this block (no `.await`
            // inside) so only the configs of available instances are cloned, rather
            // than the whole map.
            let ready: HashMap<WorkerId, ModelRuntimeConfig> = {
                let configs = configs_rx.borrow_and_update();
                instances
                    .into_iter()
                    .filter_map(|id| configs.get(&id).map(|cfg| (id, cfg.clone())))
                    .collect()
            };

            // Avoid waking downstream watchers when the joined view did not change.
            if *tx.borrow() == ready {
                continue;
            }
            if tx.send(ready).is_err() {
                break;
            }
        }
    });

    Ok(rx)
}