use serde::Deserialize;
use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryStream};
/// Reduces per-registration entries to a single entry per worker `instance_id`.
///
/// One worker can own several discovery entries, for example a base model plus
/// LoRA adapters registered as `DiscoveryInstanceId::Model` ids whose
/// `model_suffix` is set. Entries without a suffix (including every non-`Model`
/// id) take precedence and overwrite whatever is already stored for that
/// worker. Suffixed entries are only used to fill a worker slot that is still
/// empty.
///
/// NOTE: when several entries of the same kind share an `instance_id`, the one
/// that ends up in the result depends on `HashMap` iteration order.
fn collapse_by_instance_id<V: Clone>(
    state: &std::collections::HashMap<DiscoveryInstanceId, V>,
) -> std::collections::HashMap<u64, V> {
    use std::collections::hash_map::Entry;

    let mut collapsed = std::collections::HashMap::new();
    for (key, value) in state {
        let has_suffix = matches!(
            key,
            DiscoveryInstanceId::Model(mid) if mid.model_suffix.is_some()
        );
        match collapsed.entry(key.instance_id()) {
            Entry::Vacant(slot) => {
                slot.insert(value.clone());
            }
            // An unsuffixed (base) entry always wins over what is already there.
            Entry::Occupied(mut slot) if !has_suffix => {
                slot.insert(value.clone());
            }
            Entry::Occupied(_) => {}
        }
    }
    collapsed
}
/// Watches a discovery stream and publishes one extracted value per worker.
///
/// Each `Added` event is deserialized into `T` via `deserialize_model` and passed
/// through `extractor`. The resulting `V` is stored under the event's
/// `DiscoveryInstanceId`. `Removed` events drop the matching entry.
///
/// After every such change, the state is collapsed to one value per
/// `instance_id` (see `collapse_by_instance_id`). The collapsed map is sent on
/// the returned watch channel only if it differs from the value currently
/// published.
///
/// Instances that fail to deserialize are logged and skipped. Stream errors are
/// logged and the watcher keeps running.
///
/// The background task (spawned with `tokio::spawn`, so this must be called
/// inside a Tokio runtime) stops when the stream ends or as soon as every
/// receiver has been dropped.
pub fn watch_and_extract_field<T, V, F>(
    stream: DiscoveryStream,
    extractor: F,
) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
where
    T: for<'de> Deserialize<'de> + 'static,
    V: Clone + PartialEq + Send + Sync + 'static,
    F: Fn(T) -> V + Send + 'static,
{
    use futures::StreamExt;
    use std::collections::HashMap;

    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
    tokio::spawn(async move {
        let mut state: HashMap<DiscoveryInstanceId, V> = HashMap::new();
        let mut stream = stream;
        loop {
            // Race the next event against receiver shutdown. Without this, the
            // task (and the discovery stream it owns) would only notice dropped
            // receivers when a *changed* value failed to send, and could linger
            // forever if no such event arrived.
            let result = tokio::select! {
                biased;
                _ = tx.closed() => {
                    tracing::debug!("watch_and_extract_field receiver dropped, stopping");
                    break;
                }
                next = stream.next() => match next {
                    Some(result) => result,
                    None => break,
                },
            };

            match result {
                Ok(DiscoveryEvent::Added(instance)) => {
                    let instance_id = instance.instance_id();
                    let key = instance.id();
                    let deserialized: T = match instance.deserialize_model() {
                        Ok(d) => d,
                        Err(e) => {
                            tracing::warn!(
                                instance_id,
                                error = %e,
                                "Failed to deserialize discovery instance, skipping"
                            );
                            continue;
                        }
                    };
                    let value = extractor(deserialized);
                    tracing::debug!(
                        instance_id,
                        ?key,
                        state_len = state.len(),
                        "watch_and_extract_field: inserting instance"
                    );
                    state.insert(key, value);
                }
                Ok(DiscoveryEvent::Removed(id)) => {
                    let had_entry = state.contains_key(&id);
                    tracing::debug!(
                        instance_id = id.instance_id(),
                        ?id,
                        had_entry,
                        state_len = state.len(),
                        "watch_and_extract_field: removing instance"
                    );
                    state.remove(&id);
                }
                Err(e) => {
                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
                    // State is unchanged, so there is nothing to publish.
                    continue;
                }
            }

            if !publish_collapsed(&tx, &state) {
                tracing::debug!("watch_and_extract_field receiver dropped, stopping");
                break;
            }
        }
        tracing::debug!("watch_and_extract_field task stopped");
    });
    rx
}

/// Collapses `state` and sends it on `tx` when it differs from the current value.
///
/// Returns `false` only when a send was attempted and failed because every
/// receiver has been dropped. Returns `true` otherwise, including when nothing
/// changed.
fn publish_collapsed<V: Clone + PartialEq>(
    tx: &tokio::sync::watch::Sender<std::collections::HashMap<u64, V>>,
    state: &std::collections::HashMap<DiscoveryInstanceId, V>,
) -> bool {
    let collapsed = collapse_by_instance_id(state);
    // Compare in its own statement so the read guard returned by `borrow()` is
    // released before `send` acquires the channel's write lock.
    let changed = *tx.borrow() != collapsed;
    if !changed {
        return true;
    }
    tx.send(collapsed).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::mock::{MockDiscovery, SharedMockRegistry};
    use crate::discovery::{Discovery, DiscoveryQuery, DiscoverySpec};
    use std::collections::HashMap;
    use tokio::sync::watch::Receiver;
    use tokio::time::{sleep, Duration};

    /// Minimal model card: only the field the tests extract.
    #[derive(serde::Deserialize, Clone, Debug)]
    struct FakeCard {
        display_name: String,
    }

    /// Builds a `Model` spec in namespace `ns` / component `comp` on `endpoint`.
    fn spec_on(
        endpoint: &str,
        card_json: serde_json::Value,
        model_suffix: Option<String>,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: "ns".to_string(),
            component: "comp".to_string(),
            endpoint: endpoint.to_string(),
            card_json,
            model_suffix,
        }
    }

    /// A base-model registration on the `generate` endpoint.
    fn model_spec(name: &str) -> DiscoverySpec {
        spec_on("generate", serde_json::json!({ "display_name": name }), None)
    }

    /// A LoRA registration on the `generate` endpoint, suffixed with its name.
    fn lora_spec(lora_name: &str) -> DiscoverySpec {
        let card = serde_json::json!({
            "display_name": lora_name,
            "source_path": "base-model",
            "lora": { "name": lora_name },
        });
        spec_on("generate", card, Some(lora_name.to_string()))
    }

    /// The published value for `worker`, as a string slice.
    fn value_of(state: &HashMap<u64, String>, worker: u64) -> Option<&str> {
        state.get(&worker).map(String::as_str)
    }

    /// Polls `rx` every 10ms (at most 100 times) until `pred` holds, else panics.
    async fn poll_until(
        rx: &Receiver<HashMap<u64, String>>,
        pred: impl Fn(&HashMap<u64, String>) -> bool,
        msg: &str,
    ) {
        const MAX_POLLS: usize = 100;
        let interval = Duration::from_millis(10);
        let satisfied = || pred(&rx.borrow());

        for _attempt in 0..MAX_POLLS {
            if satisfied() {
                return;
            }
            sleep(interval).await;
        }
        panic!("{}: state={:?}", msg, *rx.borrow());
    }

    #[tokio::test]
    async fn test_lora_unregister_preserves_worker_runtime_config() {
        const WORKER: u64 = 42;

        let discovery = MockDiscovery::new(Some(42), SharedMockRegistry::new());
        let query = DiscoveryQuery::EndpointModels {
            namespace: "ns".to_string(),
            component: "comp".to_string(),
            endpoint: "generate".to_string(),
        };
        let rx = watch_and_extract_field(
            discovery.list_and_watch(query, None).await.unwrap(),
            |card: FakeCard| card.display_name,
        );

        let base = discovery.register(model_spec("base-model")).await.unwrap();
        let lora_a = discovery.register(lora_spec("lora-a")).await.unwrap();
        discovery.register(lora_spec("lora-b")).await.unwrap();

        poll_until(
            &rx,
            |s| s.contains_key(&WORKER),
            "Worker 42 should be present after registrations",
        )
        .await;

        // Removing one LoRA must not disturb the base model's value.
        discovery.unregister(lora_a).await.unwrap();
        poll_until(
            &rx,
            |s| value_of(s, WORKER) == Some("base-model"),
            "Worker 42 should have base-model after removing lora-a",
        )
        .await;
        assert_eq!(value_of(&rx.borrow(), WORKER), Some("base-model"));

        // With the base gone, the remaining LoRA becomes the worker's value.
        discovery.unregister(base).await.unwrap();
        poll_until(
            &rx,
            |s| value_of(s, WORKER) == Some("lora-b"),
            "Worker 42 should fall back to lora-b after removing base model",
        )
        .await;
        assert_eq!(value_of(&rx.borrow(), WORKER), Some("lora-b"));
    }

    #[tokio::test]
    async fn test_all_models_cross_endpoint_no_alias() {
        const WORKER: u64 = 7;

        let registry = SharedMockRegistry::new();
        let discovery = MockDiscovery::new(Some(7), registry.clone());
        let rx = watch_and_extract_field(
            discovery
                .list_and_watch(DiscoveryQuery::AllModels, None)
                .await
                .unwrap(),
            |card: FakeCard| card.display_name,
        );

        let ep_a = discovery
            .register(spec_on(
                "ep-a",
                serde_json::json!({ "display_name": "model-on-ep-a" }),
                None,
            ))
            .await
            .unwrap();
        discovery
            .register(spec_on(
                "ep-b",
                serde_json::json!({ "display_name": "model-on-ep-b" }),
                None,
            ))
            .await
            .unwrap();

        poll_until(
            &rx,
            |s| s.contains_key(&WORKER),
            "Worker 7 should appear after registrations",
        )
        .await;

        // Dropping one endpoint's registration must leave the other visible.
        discovery.unregister(ep_a).await.unwrap();
        poll_until(
            &rx,
            |s| value_of(s, WORKER) == Some("model-on-ep-b"),
            "Worker 7 should still be present via ep-b after removing ep-a",
        )
        .await;
    }
}