dynamo_runtime/discovery/
utils.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Utility functions for working with discovery streams
5
6use serde::Deserialize;
7
8use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream};
9
10/// Helper to watch a discovery stream and extract a specific field into a HashMap
11///
12/// This helper spawns a background task that:
13/// - Deserializes ModelCards from discovery events
14/// - Extracts a specific field using the provided extractor function
15/// - Maintains a HashMap<instance_id, Field> that auto-updates on Add/Remove events
16/// - Returns a watch::Receiver that consumers can use to read the current state
17///
18/// # Type Parameters
19/// - `T`: The type to deserialize from DiscoveryInstance (e.g., ModelDeploymentCard)
20/// - `V`: The extracted field type (e.g., ModelRuntimeConfig)
21/// - `F`: The extractor function type
22///
23/// # Arguments
24/// - `stream`: The discovery event stream to watch
25/// - `extractor`: Function that extracts the desired field from the deserialized type
26///
27/// # Example
28/// ```ignore
29/// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
30/// let runtime_configs_rx = watch_and_extract_field(
31///     stream,
32///     |card: ModelDeploymentCard| card.runtime_config,
33/// );
34///
35/// // Use it:
36/// let configs = runtime_configs_rx.borrow();
37/// if let Some(config) = configs.get(&worker_id) {
38///     // Use config...
39/// }
40/// ```
41pub fn watch_and_extract_field<T, V, F>(
42    stream: DiscoveryStream,
43    extractor: F,
44) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
45where
46    T: for<'de> Deserialize<'de> + 'static,
47    V: Clone + Send + Sync + 'static,
48    F: Fn(T) -> V + Send + 'static,
49{
50    use futures::StreamExt;
51    use std::collections::HashMap;
52
53    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
54
55    tokio::spawn(async move {
56        let mut state: HashMap<u64, V> = HashMap::new();
57        let mut stream = stream;
58
59        while let Some(result) = stream.next().await {
60            match result {
61                Ok(DiscoveryEvent::Added(instance)) => {
62                    let instance_id = instance.instance_id();
63
64                    // Deserialize the full instance into type T
65                    let deserialized: T = match instance.deserialize_model() {
66                        Ok(d) => d,
67                        Err(e) => {
68                            tracing::warn!(
69                                instance_id,
70                                error = %e,
71                                "Failed to deserialize discovery instance, skipping"
72                            );
73                            continue;
74                        }
75                    };
76
77                    // Extract the field we care about
78                    let value = extractor(deserialized);
79
80                    // Update state and send
81                    state.insert(instance_id, value);
82                    if tx.send(state.clone()).is_err() {
83                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
84                        break;
85                    }
86                }
87                Ok(DiscoveryEvent::Removed(id)) => {
88                    // Remove from state and send update
89                    state.remove(&id.instance_id());
90                    if tx.send(state.clone()).is_err() {
91                        tracing::debug!("watch_and_extract_field receiver dropped, stopping");
92                        break;
93                    }
94                }
95                Err(e) => {
96                    tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
97                    // Continue processing other events
98                }
99            }
100        }
101
102        tracing::debug!("watch_and_extract_field task stopped");
103    });
104
105    rx
106}