Skip to main content

dynamo_runtime/discovery/
utils.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Utility functions for working with discovery streams
5
6use serde::Deserialize;
7
8use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryStream};
9
10/// Collapse state keyed by full `DiscoveryInstanceId` into a flat HashMap<u64, V>.
11/// When multiple entries share the same instance_id (e.g., base model +
12/// LoRA adapters on the same worker, or the same worker on different endpoints),
13/// the base model (suffix=None) is preferred. If no base model exists, an
14/// arbitrary LoRA entry is used.
15fn collapse_by_instance_id<V: Clone>(
16    state: &std::collections::HashMap<DiscoveryInstanceId, V>,
17) -> std::collections::HashMap<u64, V> {
18    let mut result = std::collections::HashMap::new();
19    for (id, val) in state {
20        let instance_id = id.instance_id();
21        let model_suffix = match id {
22            DiscoveryInstanceId::Model(mid) => mid.model_suffix.as_ref(),
23            _ => None,
24        };
25        if model_suffix.is_none() || !result.contains_key(&instance_id) {
26            result.insert(instance_id, val.clone());
27        }
28    }
29    result
30}
31
32/// Helper to watch a discovery stream and extract a specific field into a HashMap
33///
34/// This helper spawns a background task that:
35/// - Deserializes ModelCards from discovery events
36/// - Extracts a specific field using the provided extractor function
37/// - Maintains a HashMap<instance_id, Field> that auto-updates on Add/Remove events
38/// - Returns a watch::Receiver that consumers can use to read the current state
39///
40/// # Type Parameters
41/// - `T`: The type to deserialize from DiscoveryInstance (e.g., ModelDeploymentCard)
42/// - `V`: The extracted field type (e.g., ModelRuntimeConfig)
43/// - `F`: The extractor function type
44///
45/// # Arguments
46/// - `stream`: The discovery event stream to watch
47/// - `extractor`: Function that extracts the desired field from the deserialized type
48///
49/// # Example
50/// ```ignore
51/// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
52/// let runtime_configs_rx = watch_and_extract_field(
53///     stream,
54///     |card: ModelDeploymentCard| card.runtime_config,
55/// );
56///
57/// // Use it:
58/// let configs = runtime_configs_rx.borrow();
59/// if let Some(config) = configs.get(&worker_id) {
60///     // Use config...
61/// }
62/// ```
63pub fn watch_and_extract_field<T, V, F>(
64    stream: DiscoveryStream,
65    extractor: F,
66) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
67where
68    T: for<'de> Deserialize<'de> + 'static,
69    V: Clone + PartialEq + Send + Sync + 'static,
70    F: Fn(T) -> V + Send + 'static,
71{
72    use futures::StreamExt;
73    use std::collections::HashMap;
74
75    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
76
77    tokio::spawn(async move {
78        // Internal state keyed by full DiscoveryInstanceId to correctly
79        // distinguish entries across namespaces, components, endpoints, and
80        // model suffixes — even when they share the same raw instance_id.
81        // Collapsed to HashMap<u64, V> for consumers, preferring suffix=None
82        // (base model) when multiple entries exist for the same instance_id.
83        let mut state: HashMap<DiscoveryInstanceId, V> = HashMap::new();
84        let mut stream = stream;
85
86        while let Some(result) = stream.next().await {
87            match result {
88                Ok(DiscoveryEvent::Added(instance)) => {
89                    let instance_id = instance.instance_id();
90                    let key = instance.id();
91
92                    // Deserialize the full instance into type T
93                    let deserialized: T = match instance.deserialize_model() {
94                        Ok(d) => d,
95                        Err(e) => {
96                            tracing::warn!(
97                                instance_id,
98                                error = %e,
99                                "Failed to deserialize discovery instance, skipping"
100                            );
101                            continue;
102                        }
103                    };
104
105                    // Extract the field we care about
106                    let value = extractor(deserialized);
107
108                    tracing::debug!(
109                        instance_id,
110                        ?key,
111                        state_len = state.len(),
112                        "watch_and_extract_field: inserting instance"
113                    );
114
115                    state.insert(key, value);
116
117                    // Only publish if the collapsed worker view actually changed,
118                    // to avoid waking downstream watchers on no-op events
119                    // (e.g., adding a LoRA when base model already represents the worker).
120                    let collapsed = collapse_by_instance_id(&state);
121                    if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
122                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
123                        break;
124                    }
125                }
126                Ok(DiscoveryEvent::Removed(id)) => {
127                    let had_entry = state.contains_key(&id);
128
129                    tracing::debug!(
130                        instance_id = id.instance_id(),
131                        ?id,
132                        had_entry,
133                        state_len = state.len(),
134                        "watch_and_extract_field: removing instance"
135                    );
136
137                    state.remove(&id);
138
139                    // Only publish if the collapsed worker view actually changed,
140                    // to avoid waking downstream watchers on no-op events
141                    // (e.g., adding a LoRA when base model already represents the worker).
142                    let collapsed = collapse_by_instance_id(&state);
143                    if *tx.borrow() != collapsed && tx.send(collapsed).is_err() {
144                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
145                        break;
146                    }
147                }
148                Err(e) => {
149                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
150                    // Continue processing other events
151                }
152            }
153        }
154
155        tracing::debug!("watch_and_extract_field task stopped");
156    });
157
158    rx
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::mock::{MockDiscovery, SharedMockRegistry};
    use crate::discovery::{Discovery, DiscoveryQuery, DiscoverySpec};

    /// Collapsed worker view published by `watch_and_extract_field` in these tests.
    type WorkerView = std::collections::HashMap<u64, String>;

    /// Stand-in for a model card: only the field the tests extract.
    #[derive(serde::Deserialize, Clone, Debug)]
    struct FakeCard {
        display_name: String,
    }

    /// Build a `Model` registration under namespace `ns`, component `comp`.
    fn spec_on(
        endpoint: &str,
        card_json: serde_json::Value,
        model_suffix: Option<String>,
    ) -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: "ns".to_string(),
            component: "comp".to_string(),
            endpoint: endpoint.to_string(),
            card_json,
            model_suffix,
        }
    }

    /// Base-model registration (no suffix) on the `generate` endpoint.
    fn model_spec(name: &str) -> DiscoverySpec {
        spec_on("generate", serde_json::json!({ "display_name": name }), None)
    }

    /// LoRA-adapter registration on the `generate` endpoint, suffixed by the adapter name.
    fn lora_spec(lora_name: &str) -> DiscoverySpec {
        let card = serde_json::json!({
            "display_name": lora_name,
            "source_path": "base-model",
            "lora": { "name": lora_name },
        });
        spec_on("generate", card, Some(lora_name.to_string()))
    }

    /// Predicate that holds when worker `id` is currently represented by `name`.
    fn worker_is(id: u64, name: &'static str) -> impl Fn(&WorkerView) -> bool {
        move |view: &WorkerView| view.get(&id).map(String::as_str) == Some(name)
    }

    /// Re-check `pred` against the latest published view every 10ms. Panics with
    /// `msg` and the current view if it is still false after 100 attempts (~1s).
    async fn poll_until(
        rx: &tokio::sync::watch::Receiver<WorkerView>,
        pred: impl Fn(&WorkerView) -> bool,
        msg: &str,
    ) {
        let interval = tokio::time::Duration::from_millis(10);
        let mut attempts_left = 100;
        while attempts_left > 0 {
            if pred(&rx.borrow()) {
                return;
            }
            tokio::time::sleep(interval).await;
            attempts_left -= 1;
        }
        panic!("{}: state={:?}", msg, *rx.borrow());
    }

    /// Removing one LoRA adapter must not drop the worker's runtime config.
    /// The base model and other adapters share the worker's instance_id, so
    /// removing one registration must leave the rest represented.
    #[tokio::test]
    async fn test_lora_unregister_preserves_worker_runtime_config() {
        // Every registration below shares instance_id 42, i.e. a single worker.
        let discovery = MockDiscovery::new(Some(42), SharedMockRegistry::new());

        let query = DiscoveryQuery::EndpointModels {
            namespace: "ns".to_string(),
            component: "comp".to_string(),
            endpoint: "generate".to_string(),
        };
        let events = discovery.list_and_watch(query, None).await.unwrap();

        // display_name stands in for the real runtime_config.
        let rx = watch_and_extract_field(events, |card: FakeCard| card.display_name);

        let base_handle = discovery.register(model_spec("base-model")).await.unwrap();
        let lora_a_handle = discovery.register(lora_spec("lora-a")).await.unwrap();
        discovery.register(lora_spec("lora-b")).await.unwrap();

        poll_until(
            &rx,
            |view: &WorkerView| view.contains_key(&42),
            "Worker 42 should be present after registrations",
        )
        .await;

        // Drop LoRA-A only; the base model keeps representing the worker.
        discovery.unregister(lora_a_handle).await.unwrap();
        poll_until(
            &rx,
            worker_is(42, "base-model"),
            "Worker 42 should have base-model after removing lora-a",
        )
        .await;
        assert_eq!(rx.borrow().get(&42).map(String::as_str), Some("base-model"));

        // Drop the base model; lora-b is now the only remaining candidate.
        discovery.unregister(base_handle).await.unwrap();
        poll_until(
            &rx,
            worker_is(42, "lora-b"),
            "Worker 42 should fall back to lora-b after removing base model",
        )
        .await;
        assert_eq!(rx.borrow().get(&42).map(String::as_str), Some("lora-b"));
    }

    /// One worker (same instance_id) registered on two endpoints must not alias
    /// under `AllModels`. Removing one endpoint's registration must leave the
    /// other represented in the collapsed view.
    #[tokio::test]
    async fn test_all_models_cross_endpoint_no_alias() {
        let registry = SharedMockRegistry::new();
        // A single worker (instance_id 7) serving two endpoints.
        let discovery = MockDiscovery::new(Some(7), registry.clone());

        let events = discovery
            .list_and_watch(DiscoveryQuery::AllModels, None)
            .await
            .unwrap();
        let rx = watch_and_extract_field(events, |card: FakeCard| card.display_name);

        let ep_a_handle = discovery
            .register(spec_on(
                "ep-a",
                serde_json::json!({ "display_name": "model-on-ep-a" }),
                None,
            ))
            .await
            .unwrap();
        discovery
            .register(spec_on(
                "ep-b",
                serde_json::json!({ "display_name": "model-on-ep-b" }),
                None,
            ))
            .await
            .unwrap();

        poll_until(
            &rx,
            |view: &WorkerView| view.contains_key(&7),
            "Worker 7 should appear after registrations",
        )
        .await;

        // Dropping the ep-a registration must leave worker 7 alive via ep-b.
        discovery.unregister(ep_a_handle).await.unwrap();
        poll_until(
            &rx,
            worker_is(7, "model-on-ep-b"),
            "Worker 7 should still be present via ep-b after removing ep-a",
        )
        .await;
    }
}