dynamo_runtime/discovery/
utils.rs1use serde::Deserialize;
7
8use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryStream};
9
10fn collapse_by_instance_id<V: Clone>(
16 state: &std::collections::HashMap<DiscoveryInstanceId, V>,
17) -> std::collections::HashMap<u64, V> {
18 let mut result = std::collections::HashMap::new();
19 for (id, val) in state {
20 let instance_id = id.instance_id();
21 let model_suffix = match id {
22 DiscoveryInstanceId::Model(mid) => mid.model_suffix.as_ref(),
23 _ => None,
24 };
25 if model_suffix.is_none() || !result.contains_key(&instance_id) {
26 result.insert(instance_id, val.clone());
27 }
28 }
29 result
30}
31
32pub fn watch_and_extract_field<T, V, F>(
64 stream: DiscoveryStream,
65 extractor: F,
66) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
67where
68 T: for<'de> Deserialize<'de> + 'static,
69 V: Clone + PartialEq + Send + Sync + 'static,
70 F: Fn(T) -> V + Send + 'static,
71{
72 use futures::StreamExt;
73 use std::collections::HashMap;
74
75 let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
76
77 tokio::spawn(async move {
78 let mut state: HashMap<DiscoveryInstanceId, V> = HashMap::new();
84 let mut stream = stream;
85
86 while let Some(result) = stream.next().await {
87 match result {
88 Ok(DiscoveryEvent::Added(instance)) => {
89 let instance_id = instance.instance_id();
90 let key = instance.id();
91
92 let deserialized: T = match instance.deserialize_model() {
94 Ok(d) => d,
95 Err(e) => {
96 tracing::warn!(
97 instance_id,
98 error = %e,
99 "Failed to deserialize discovery instance, skipping"
100 );
101 continue;
102 }
103 };
104
105 let value = extractor(deserialized);
107
108 tracing::debug!(
109 instance_id,
110 ?key,
111 state_len = state.len(),
112 "watch_and_extract_field: inserting instance"
113 );
114
115 state.insert(key, value);
116
117 let collapsed = collapse_by_instance_id(&state);
121 if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
122 tracing::debug!("watch_and_extract_field receiver dropped, stopping");
123 break;
124 }
125 }
126 Ok(DiscoveryEvent::Removed(id)) => {
127 let had_entry = state.contains_key(&id);
128
129 tracing::debug!(
130 instance_id = id.instance_id(),
131 ?id,
132 had_entry,
133 state_len = state.len(),
134 "watch_and_extract_field: removing instance"
135 );
136
137 state.remove(&id);
138
139 let collapsed = collapse_by_instance_id(&state);
143 if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
144 tracing::debug!("watch_and_extract_field receiver dropped, stopping");
145 break;
146 }
147 }
148 Err(e) => {
149 tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
150 }
152 }
153 }
154
155 tracing::debug!("watch_and_extract_field task stopped");
156 });
157
158 rx
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::mock::{MockDiscovery, SharedMockRegistry};
    use crate::discovery::{Discovery, DiscoveryQuery, DiscoverySpec};

    /// Shape of the snapshots published by the watch channel in these tests.
    type Snapshot = std::collections::HashMap<u64, String>;

    /// Smallest card the extractor needs: only `display_name` is read.
    #[derive(serde::Deserialize, Clone, Debug)]
    struct FakeCard {
        display_name: String,
    }

    /// Extractor handed to `watch_and_extract_field`: keeps just the display name.
    fn display_name_of(card: FakeCard) -> String {
        card.display_name
    }

    /// Returns the value currently published for `worker`, if any.
    fn published(snapshot: &Snapshot, worker: u64) -> Option<&str> {
        snapshot.get(&worker).map(String::as_str)
    }

    /// Builds a suffix-less model spec under `ns`/`comp` on the given endpoint.
    fn base_spec_on(endpoint: &str, name: &str) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: String::from("ns"),
            component: String::from("comp"),
            endpoint: String::from(endpoint),
            card_json: serde_json::json!({ "display_name": name }),
            model_suffix: None,
        }
    }

    /// Suffix-less model spec on the `generate` endpoint.
    fn model_spec(name: &str) -> DiscoverySpec {
        base_spec_on("generate", name)
    }

    /// Polls `rx` up to 100 times, 10 ms apart, until `pred` accepts the snapshot.
    ///
    /// Panics with `msg` and the last observed state when the budget runs out.
    async fn poll_until(
        rx: &tokio::sync::watch::Receiver<Snapshot>,
        pred: impl Fn(&Snapshot) -> bool,
        msg: &str,
    ) {
        let pause = tokio::time::Duration::from_millis(10);
        let mut attempts_left = 100;
        while attempts_left > 0 {
            if pred(&rx.borrow()) {
                return;
            }
            tokio::time::sleep(pause).await;
            attempts_left -= 1;
        }
        panic!("{}: state={:?}", msg, *rx.borrow());
    }

    /// Model spec on the `generate` endpoint whose `model_suffix` is the LoRA name.
    fn lora_spec(lora_name: &str) -> DiscoverySpec {
        let card_json = serde_json::json!({
            "display_name": lora_name,
            "source_path": "base-model",
            "lora": { "name": lora_name },
        });

        DiscoverySpec::Model {
            namespace: String::from("ns"),
            component: String::from("comp"),
            endpoint: String::from("generate"),
            card_json,
            model_suffix: Some(String::from(lora_name)),
        }
    }

    /// Removing one LoRA registration must leave the base model's value published;
    /// removing the base model afterwards falls back to the remaining LoRA.
    #[tokio::test]
    async fn test_lora_unregister_preserves_worker_runtime_config() {
        let discovery = MockDiscovery::new(Some(42), SharedMockRegistry::new());

        let stream = discovery
            .list_and_watch(
                DiscoveryQuery::EndpointModels {
                    namespace: String::from("ns"),
                    component: String::from("comp"),
                    endpoint: String::from("generate"),
                },
                None,
            )
            .await
            .unwrap();

        let rx = watch_and_extract_field(stream, display_name_of);

        // One base model plus two suffixed (LoRA) registrations on the same worker.
        let base_handle = discovery.register(model_spec("base-model")).await.unwrap();
        let first_lora = discovery.register(lora_spec("lora-a")).await.unwrap();
        discovery.register(lora_spec("lora-b")).await.unwrap();

        poll_until(
            &rx,
            |snapshot| snapshot.contains_key(&42),
            "Worker 42 should be present after registrations",
        )
        .await;

        discovery.unregister(first_lora).await.unwrap();

        poll_until(
            &rx,
            |snapshot| published(snapshot, 42) == Some("base-model"),
            "Worker 42 should have base-model after removing lora-a",
        )
        .await;

        assert_eq!(published(&rx.borrow(), 42), Some("base-model"));

        discovery.unregister(base_handle).await.unwrap();

        poll_until(
            &rx,
            |snapshot| published(snapshot, 42) == Some("lora-b"),
            "Worker 42 should fall back to lora-b after removing base model",
        )
        .await;

        assert_eq!(published(&rx.borrow(), 42), Some("lora-b"));
    }

    /// With an `AllModels` watch, two suffix-less registrations of the same worker on
    /// different endpoints are tracked separately: removing one keeps the other.
    #[tokio::test]
    async fn test_all_models_cross_endpoint_no_alias() {
        let registry = SharedMockRegistry::new();
        let discovery = MockDiscovery::new(Some(7), registry.clone());

        let stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, None)
            .await
            .unwrap();
        let rx = watch_and_extract_field(stream, display_name_of);

        let first_endpoint = discovery
            .register(base_spec_on("ep-a", "model-on-ep-a"))
            .await
            .unwrap();

        discovery
            .register(base_spec_on("ep-b", "model-on-ep-b"))
            .await
            .unwrap();

        poll_until(
            &rx,
            |snapshot| snapshot.contains_key(&7),
            "Worker 7 should appear after registrations",
        )
        .await;

        discovery.unregister(first_endpoint).await.unwrap();

        poll_until(
            &rx,
            |snapshot| published(snapshot, 7) == Some("model-on-ep-b"),
            "Worker 7 should still be present via ep-b after removing ep-a",
        )
        .await;
    }
}