dynamo_llm/
disagg_router.rs1use serde::{Deserialize, Serialize};
5use std::sync::{Arc, Mutex};
6use tokio::sync::watch;
7use tracing;
8
9use dynamo_runtime::DistributedRuntime;
10use dynamo_runtime::transports::etcd::WatchEvent;
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct DisaggRouterConf {
14 pub max_local_prefill_length: i32,
15}
16
17impl Default for DisaggRouterConf {
18 fn default() -> Self {
19 Self {
20 max_local_prefill_length: 1000,
21 }
22 }
23}
24
25impl DisaggRouterConf {
26 pub async fn from_etcd_with_watcher(
27 drt: Arc<DistributedRuntime>,
28 model_name: &str,
29 ) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
30 let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
31
32 let Some(etcd_client) = drt.etcd_client() else {
34 anyhow::bail!("Static components don't have an etcd client");
35 };
36 let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
37 Ok(kvs) => {
38 if let Some(kv) = kvs.first() {
39 match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
40 Ok(config) => {
41 tracing::debug!(
42 "Found initial config for key {}: {:?}",
43 etcd_key,
44 config
45 );
46 config
47 }
48 Err(e) => {
49 tracing::warn!(
50 "Failed to parse initial config for key {}: {}",
51 etcd_key,
52 e
53 );
54 DisaggRouterConf::default()
55 }
56 }
57 } else {
58 tracing::debug!(
59 "No initial config found for key {}, using default",
60 etcd_key
61 );
62 DisaggRouterConf::default()
63 }
64 }
65 Err(e) => {
66 tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
67 DisaggRouterConf::default()
68 }
69 };
70
71 let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
73
74 let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
76 let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
77
78 drt.runtime().secondary().spawn(async move {
80 tracing::info!("Starting config watcher for disagg router key: {}", key);
81
82 loop {
83 let kv_event = tokio::select! {
84 _ = watch_tx.closed() => {
85 tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
86 break;
87 }
88 kv_event = kv_event_rx.recv() => {
89 match kv_event {
90 Some(kv_event) => kv_event,
91 None => {
92 tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
93 break;
94 }
95 }
96 }
97 };
98
99 tracing::debug!("Received watch event for key {}", key);
100
101 match kv_event {
102 WatchEvent::Put(kv) => {
103 let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
104 if let Ok(config) = val {
105 tracing::info!("Config updated for key {}: {:?}", key, config);
106 if watch_tx.send(config).is_err() {
108 tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
109 break;
110 }
111 } else {
112 tracing::error!("Unable to parse router config for key {}", key);
113 break;
114 }
115 }
116 WatchEvent::Delete(_) => {
117 tracing::warn!("Config key was deleted: {}", key);
118 if watch_tx.send(DisaggRouterConf::default()).is_err() {
120 tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
121 break;
122 }
123 }
124 }
125 }
126
127 tracing::debug!("Completed config watcher for key: {}", key);
128 });
129
130 Ok((initial_config, watch_rx))
131 }
132}
133
134#[derive(Clone)]
135pub struct DisaggregatedRouter {
136 max_local_prefill_length: Arc<Mutex<i32>>,
137 model_name: String,
138 config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
139}
140
141impl DisaggregatedRouter {
142 pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
143 DisaggregatedRouter {
144 max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
145 model_name,
146 config_watcher: None,
147 }
148 }
149
150 pub async fn new_with_etcd_and_default(
151 drt: Arc<DistributedRuntime>,
152 model_name: String,
153 default_max_local_prefill_length: i32,
154 ) -> anyhow::Result<Self> {
155 let (mut config, watcher) =
156 DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
157
158 if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
160 config.max_local_prefill_length = default_max_local_prefill_length;
161 }
162
163 let router = Self {
164 max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
165 model_name: model_name.clone(),
166 config_watcher: Some(watcher),
167 };
168
169 router.start_config_watcher();
171
172 Ok(router)
173 }
174
175 fn start_config_watcher(&self) {
176 if let Some(watcher) = self.config_watcher.clone() {
177 let mut watcher = watcher;
178 let model_name = self.model_name.clone();
180 let max_local_prefill_length = self.max_local_prefill_length.clone();
181
182 tokio::spawn(async move {
183 tracing::info!("Starting config update watcher for model: {}", model_name);
184
185 while watcher.changed().await.is_ok() {
186 let config = watcher.borrow().clone();
187 let new_value = config.max_local_prefill_length;
188
189 let mut current_value = max_local_prefill_length.lock().unwrap();
191 let old_value = *current_value;
192 if old_value != new_value {
193 *current_value = new_value;
194 tracing::info!(
195 "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
196 model_name,
197 old_value,
198 new_value
199 );
200 }
201 }
202
203 tracing::debug!("Config watcher closed for model: {}", model_name);
204 });
205 }
206 }
207
208 pub fn check_for_updates(&self) {
209 if let Some(watcher) = &self.config_watcher
210 && watcher.has_changed().unwrap_or(false)
211 {
212 let config = watcher.borrow().clone();
213 let new_value = config.max_local_prefill_length;
214
215 let mut current_value = self.max_local_prefill_length.lock().unwrap();
217 let old_value = *current_value;
218 if old_value != new_value {
219 *current_value = new_value;
220 tracing::info!(
221 "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
222 self.model_name,
223 old_value,
224 new_value
225 );
226 }
227 }
228 }
229
230 pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
231 self.check_for_updates();
233
234 let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
236
237 prefill_length - prefix_hit_length > max_local_prefill_length
240 }
241
242 pub fn update_value(&self, max_local_prefill_length: i32) {
243 let mut current = self.max_local_prefill_length.lock().unwrap();
244 *current = max_local_prefill_length;
245 }
246
247 pub fn get_model_name(&self) -> &str {
248 &self.model_name
249 }
250}