dynamo_runtime/utils/
worker_monitor.rs1use crate::component::{Client, InstanceSource};
9use crate::traits::DistributedRuntimeProvider;
10use crate::traits::events::EventSubscriber;
11use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use tokio::sync::watch;
15use tokio_stream::StreamExt;
16
/// Subject on the component's namespace that is subscribed to for worker KV load reports.
const KV_METRICS_SUBJECT: &str = "kv_metrics";
19
20#[derive(serde::Deserialize)]
22struct LoadEvent {
23 worker_id: i64,
24 data: ForwardPassMetrics,
25}
26
27#[derive(serde::Deserialize)]
28struct ForwardPassMetrics {
29 kv_stats: KvStats,
30}
31
32#[derive(serde::Deserialize)]
33struct KvStats {
34 kv_active_blocks: u64,
35}
36
37#[derive(serde::Deserialize)]
38struct RuntimeConfig {
39 total_kv_blocks: Option<u64>,
40}
41
/// Last known KV-cache load figures for a single worker.
///
/// Either figure is `None` until the corresponding source has reported a value.
#[derive(Clone, Debug)]
pub struct WorkerLoadState {
    /// Most recently reported number of active KV blocks.
    pub kv_active_blocks: Option<u64>,
    /// Total number of KV blocks available to the worker.
    pub kv_total_blocks: Option<u64>,
}

impl WorkerLoadState {
    /// Reports whether the worker's active blocks exceed `threshold` times its total blocks.
    ///
    /// Returns `false` when either figure is unknown or when the total is zero, so a
    /// worker is never treated as busy on incomplete data.
    pub fn is_busy(&self, threshold: f64) -> bool {
        let (Some(active), Some(total)) = (self.kv_active_blocks, self.kv_total_blocks) else {
            return false;
        };
        if total == 0 {
            return false;
        }

        // Strictly greater: a worker sitting exactly at the limit is not busy.
        let limit = threshold * total as f64;
        active as f64 > limit
    }
}
59
60pub struct WorkerMonitor {
62 client: Arc<Client>,
63 worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
64 busy_threshold: f64,
65}
66
67impl WorkerMonitor {
68 pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
70 Self {
71 client,
72 worker_load_states: Arc::new(RwLock::new(HashMap::new())),
73 busy_threshold,
74 }
75 }
76
77 pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
79 self.worker_load_states.clone()
80 }
81
82 pub async fn start_monitoring(&self) -> anyhow::Result<()> {
84 let endpoint = &self.client.endpoint;
85 let component = endpoint.component();
86
87 let Some(etcd_client) = component.drt().etcd_client() else {
88 return Ok(());
90 };
91
92 let runtime_configs_watcher = watch_prefix_with_extraction(
96 etcd_client,
97 "v1/mdc/", key_extractors::lease_id,
99 |card: serde_json::Value| {
100 card.get("runtime_config")
101 .and_then(|rc| rc.get("total_kv_blocks"))
102 .and_then(|t_kv| t_kv.as_u64())
103 },
104 component.drt().child_token(),
105 )
106 .await?;
107 let mut config_events_rx = runtime_configs_watcher.receiver();
108
109 let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
111
112 let worker_load_states = self.worker_load_states.clone();
113 let client = self.client.clone();
114 let cancellation_token = component.drt().child_token();
115 let busy_threshold = self.busy_threshold; tokio::spawn(async move {
119 let mut previous_busy_instances = Vec::new(); loop {
122 tokio::select! {
123 _ = cancellation_token.cancelled() => {
124 tracing::debug!("Worker monitoring cancelled");
125 break;
126 }
127
128 _ = config_events_rx.changed() => {
130 let runtime_configs = config_events_rx.borrow().clone();
131
132 let mut states = worker_load_states.write().unwrap();
133 states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
134
135 for (lease_id, total_blocks) in runtime_configs.iter() {
137 let state = states.entry(*lease_id).or_insert(WorkerLoadState {
138 kv_active_blocks: None,
139 kv_total_blocks: None,
140 });
141 state.kv_total_blocks = Some(*total_blocks);
142 }
143 }
144
145 kv_event = kv_metrics_rx.next() => {
147 let Some(event) = kv_event else {
148 tracing::debug!("KV metrics stream closed");
149 break;
150 };
151
152 if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
153 let worker_id = load_event.worker_id;
154 let active_blocks = load_event.data.kv_stats.kv_active_blocks;
155
156 let mut states = worker_load_states.write().unwrap();
158 let state = states.entry(worker_id).or_insert(WorkerLoadState {
159 kv_active_blocks: None,
160 kv_total_blocks: None,
161 });
162 state.kv_active_blocks = Some(active_blocks);
163 drop(states);
164
165 let states = worker_load_states.read().unwrap();
167 let busy_instances: Vec<i64> = states
168 .iter()
169 .filter_map(|(&id, state)| {
170 state.is_busy(busy_threshold).then_some(id)
171 })
172 .collect();
173 drop(states);
174
175 if busy_instances != previous_busy_instances {
177 tracing::debug!("Busy instances changed: {:?}", busy_instances);
178 client.update_free_instances(&busy_instances);
179 previous_busy_instances = busy_instances;
180 }
181 }
182 }
183 }
184 }
185
186 tracing::info!("Worker monitoring task exiting");
187 });
188
189 Ok(())
190 }
191}