dynamo_runtime/utils/
worker_monitor.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring
5// system is not tied to KV cache concepts, which are LLM-specific. This would allow
6// different types of workers to define their own load metrics and busy thresholds.
7
8use crate::component::{Client, InstanceSource};
9use crate::traits::DistributedRuntimeProvider;
10use crate::traits::events::EventSubscriber;
11use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use tokio::sync::watch;
15use tokio_stream::StreamExt;
16
17// Constants for monitoring configuration
18const KV_METRICS_SUBJECT: &str = "kv_metrics";
19
20// Internal structs for deserializing metrics events
/// One `kv_metrics` event payload: the reporting worker plus its metrics.
#[derive(serde::Deserialize)]
struct LoadEvent {
    /// Id of the reporting worker; used as the key into the shared load-state map.
    worker_id: i64,
    /// Forward-pass metrics; only the KV stats are deserialized here.
    data: ForwardPassMetrics,
}
26
/// Subset of a worker's forward-pass metrics; only the KV-cache stats are needed.
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
    kv_stats: KvStats,
}
31
/// KV-cache statistics reported by a worker.
#[derive(serde::Deserialize)]
struct KvStats {
    /// Number of KV blocks currently in use on the worker.
    kv_active_blocks: u64,
}
36
/// Typed view of a model card's `runtime_config` section.
/// NOTE(review): appears unused in this file — the etcd watcher in
/// `start_monitoring` parses the card as `serde_json::Value` instead.
/// Confirm no other use before removing.
#[derive(serde::Deserialize)]
struct RuntimeConfig {
    total_kv_blocks: Option<u64>,
}
41
/// Worker load monitoring state
///
/// Tracks one worker's KV-cache occupancy. Both fields start as `None` and
/// are filled in independently: `kv_total_blocks` from the etcd model-card
/// watcher, `kv_active_blocks` from `kv_metrics` events.
#[derive(Clone, Debug)]
pub struct WorkerLoadState {
    /// KV blocks currently in use, from the latest metrics event (if any).
    pub kv_active_blocks: Option<u64>,
    /// Total KV blocks available on the worker, from its runtime config (if known).
    pub kv_total_blocks: Option<u64>,
}
48
49impl WorkerLoadState {
50    pub fn is_busy(&self, threshold: f64) -> bool {
51        match (self.kv_active_blocks, self.kv_total_blocks) {
52            (Some(active), Some(total)) if total > 0 => {
53                (active as f64) > (threshold * total as f64)
54            }
55            _ => false,
56        }
57    }
58}
59
60/// Worker monitor for tracking KV cache usage and busy states
61pub struct WorkerMonitor {
62    client: Arc<Client>,
63    worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
64    busy_threshold: f64,
65}
66
67impl WorkerMonitor {
68    /// Create a new worker monitor with custom threshold
69    pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
70        Self {
71            client,
72            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
73            busy_threshold,
74        }
75    }
76
77    /// Get the worker load states for external access
78    pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
79        self.worker_load_states.clone()
80    }
81
82    /// Start background monitoring of worker KV cache usage
83    pub async fn start_monitoring(&self) -> anyhow::Result<()> {
84        let endpoint = &self.client.endpoint;
85        let component = endpoint.component();
86
87        let Some(etcd_client) = component.drt().etcd_client() else {
88            // Static mode, no monitoring needed
89            return Ok(());
90        };
91
92        // WorkerMonitor is in the wrong crate. It deals with LLM things (KV) so it should be in
93        // dynamo-llm not dynamo-runtime.
94        // That means we cannot use ModelDeploymentCard, so use serde_json::Value for now .
95        let runtime_configs_watcher = watch_prefix_with_extraction(
96            etcd_client,
97            "v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate
98            key_extractors::lease_id,
99            |card: serde_json::Value| {
100                card.get("runtime_config")
101                    .and_then(|rc| rc.get("total_kv_blocks"))
102                    .and_then(|t_kv| t_kv.as_u64())
103            },
104            component.drt().child_token(),
105        )
106        .await?;
107        let mut config_events_rx = runtime_configs_watcher.receiver();
108
109        // Subscribe to KV metrics events
110        let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
111
112        let worker_load_states = self.worker_load_states.clone();
113        let client = self.client.clone();
114        let cancellation_token = component.drt().child_token();
115        let busy_threshold = self.busy_threshold; // Capture threshold for the closure
116
117        // Spawn background monitoring task
118        tokio::spawn(async move {
119            let mut previous_busy_instances = Vec::new(); // Track previous state
120
121            loop {
122                tokio::select! {
123                    _ = cancellation_token.cancelled() => {
124                        tracing::debug!("Worker monitoring cancelled");
125                        break;
126                    }
127
128                    // Handle runtime config updates - now receives full HashMap
129                    _ = config_events_rx.changed() => {
130                        let runtime_configs = config_events_rx.borrow().clone();
131
132                        let mut states = worker_load_states.write().unwrap();
133                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
134
135                        // Update worker load states with total blocks
136                        for (lease_id, total_blocks) in runtime_configs.iter() {
137                            let state = states.entry(*lease_id).or_insert(WorkerLoadState {
138                                kv_active_blocks: None,
139                                kv_total_blocks: None,
140                            });
141                            state.kv_total_blocks = Some(*total_blocks);
142                        }
143                    }
144
145                    // Handle KV metrics updates
146                    kv_event = kv_metrics_rx.next() => {
147                        let Some(event) = kv_event else {
148                            tracing::debug!("KV metrics stream closed");
149                            break;
150                        };
151
152                        if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
153                            let worker_id = load_event.worker_id;
154                            let active_blocks = load_event.data.kv_stats.kv_active_blocks;
155
156                            // Update worker load state
157                            let mut states = worker_load_states.write().unwrap();
158                            let state = states.entry(worker_id).or_insert(WorkerLoadState {
159                                kv_active_blocks: None,
160                                kv_total_blocks: None,
161                            });
162                            state.kv_active_blocks = Some(active_blocks);
163                            drop(states);
164
165                            // Recalculate all busy instances and update
166                            let states = worker_load_states.read().unwrap();
167                            let busy_instances: Vec<i64> = states
168                                .iter()
169                                .filter_map(|(&id, state)| {
170                                    state.is_busy(busy_threshold).then_some(id)
171                                })
172                                .collect();
173                            drop(states);
174
175                            // Only update if busy_instances has changed
176                            if busy_instances != previous_busy_instances {
177                                tracing::debug!("Busy instances changed: {:?}", busy_instances);
178                                client.update_free_instances(&busy_instances);
179                                previous_busy_instances = busy_instances;
180                            }
181                        }
182                    }
183                }
184            }
185
186            tracing::info!("Worker monitoring task exiting");
187        });
188
189        Ok(())
190    }
191}