dynamo_runtime/discovery/utils.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Utility functions for working with discovery streams
5
6use serde::Deserialize;
7
8use super::{DiscoveryEvent, DiscoveryInstance, DiscoveryStream};
9
10/// Helper to watch a discovery stream and extract a specific field into a HashMap
11///
12/// This helper spawns a background task that:
13/// - Deserializes ModelCards from discovery events
14/// - Extracts a specific field using the provided extractor function
15/// - Maintains a HashMap<instance_id, Field> that auto-updates on Add/Remove events
16/// - Returns a watch::Receiver that consumers can use to read the current state
17///
18/// # Type Parameters
19/// - `T`: The type to deserialize from DiscoveryInstance (e.g., ModelDeploymentCard)
20/// - `V`: The extracted field type (e.g., ModelRuntimeConfig)
21/// - `F`: The extractor function type
22///
23/// # Arguments
24/// - `stream`: The discovery event stream to watch
25/// - `extractor`: Function that extracts the desired field from the deserialized type
26///
27/// # Example
28/// ```ignore
29/// let stream = discovery.list_and_watch(DiscoveryQuery::ComponentModels { ... }, None).await?;
30/// let runtime_configs_rx = watch_and_extract_field(
31/// stream,
32/// |card: ModelDeploymentCard| card.runtime_config,
33/// );
34///
35/// // Use it:
36/// let configs = runtime_configs_rx.borrow();
37/// if let Some(config) = configs.get(&worker_id) {
38/// // Use config...
39/// }
40/// ```
41pub fn watch_and_extract_field<T, V, F>(
42 stream: DiscoveryStream,
43 extractor: F,
44) -> tokio::sync::watch::Receiver<std::collections::HashMap<u64, V>>
45where
46 T: for<'de> Deserialize<'de> + 'static,
47 V: Clone + Send + Sync + 'static,
48 F: Fn(T) -> V + Send + 'static,
49{
50 use futures::StreamExt;
51 use std::collections::HashMap;
52
53 let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
54
55 tokio::spawn(async move {
56 let mut state: HashMap<u64, V> = HashMap::new();
57 let mut stream = stream;
58
59 while let Some(result) = stream.next().await {
60 match result {
61 Ok(DiscoveryEvent::Added(instance)) => {
62 let instance_id = instance.instance_id();
63
64 // Deserialize the full instance into type T
65 let deserialized: T = match instance.deserialize_model() {
66 Ok(d) => d,
67 Err(e) => {
68 tracing::warn!(
69 instance_id,
70 error = %e,
71 "Failed to deserialize discovery instance, skipping"
72 );
73 continue;
74 }
75 };
76
77 // Extract the field we care about
78 let value = extractor(deserialized);
79
80 // Update state and send
81 state.insert(instance_id, value);
82 if tx.send(state.clone()).is_err() {
83 tracing::debug!("watch_and_extract_field receiver dropped, stopping");
84 break;
85 }
86 }
87 Ok(DiscoveryEvent::Removed(instance_id)) => {
88 // Remove from state and send update
89 state.remove(&instance_id);
90 if tx.send(state.clone()).is_err() {
91 tracing::debug!("watch_and_extract_field receiver dropped, stopping");
92 break;
93 }
94 }
95 Err(e) => {
96 tracing::error!(error = %e, "Discovery event stream error in watch_and_extract_field");
97 // Continue processing other events
98 }
99 }
100 }
101
102 tracing::debug!("watch_and_extract_field task stopped");
103 });
104
105 rx
106}