dynamo_llm/
disagg_router.rs1use serde::{Deserialize, Serialize};
17use std::sync::{Arc, Mutex};
18use tokio::sync::watch;
19use tracing;
20
21use dynamo_runtime::transports::etcd::WatchEvent;
22use dynamo_runtime::DistributedRuntime;
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub struct DisaggRouterConf {
26 pub max_local_prefill_length: i32,
27}
28
29impl Default for DisaggRouterConf {
30 fn default() -> Self {
31 Self {
32 max_local_prefill_length: 1000,
33 }
34 }
35}
36
37impl DisaggRouterConf {
38 pub async fn from_etcd_with_watcher(
39 drt: Arc<DistributedRuntime>,
40 model_name: &str,
41 ) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
42 let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
43
44 let Some(etcd_client) = drt.etcd_client() else {
46 anyhow::bail!("Static components don't have an etcd client");
47 };
48 let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
49 Ok(kvs) => {
50 if let Some(kv) = kvs.first() {
51 match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
52 Ok(config) => {
53 tracing::debug!(
54 "Found initial config for key {}: {:?}",
55 etcd_key,
56 config
57 );
58 config
59 }
60 Err(e) => {
61 tracing::warn!(
62 "Failed to parse initial config for key {}: {}",
63 etcd_key,
64 e
65 );
66 DisaggRouterConf::default()
67 }
68 }
69 } else {
70 tracing::debug!(
71 "No initial config found for key {}, using default",
72 etcd_key
73 );
74 DisaggRouterConf::default()
75 }
76 }
77 Err(e) => {
78 tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
79 DisaggRouterConf::default()
80 }
81 };
82
83 let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
85
86 let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
88 let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
89
90 drt.runtime().secondary().spawn(async move {
92 tracing::info!("Starting config watcher for disagg router key: {}", key);
93
94 loop {
95 let kv_event = tokio::select! {
96 _ = watch_tx.closed() => {
97 tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
98 break;
99 }
100 kv_event = kv_event_rx.recv() => {
101 match kv_event {
102 Some(kv_event) => kv_event,
103 None => {
104 tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
105 break;
106 }
107 }
108 }
109 };
110
111 tracing::debug!("Received watch event for key {}", key);
112
113 match kv_event {
114 WatchEvent::Put(kv) => {
115 let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
116 if let Ok(config) = val {
117 tracing::info!("Config updated for key {}: {:?}", key, config);
118 if watch_tx.send(config).is_err() {
120 tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
121 break;
122 }
123 } else {
124 tracing::error!("Unable to parse router config for key {}", key);
125 break;
126 }
127 }
128 WatchEvent::Delete(_) => {
129 tracing::warn!("Config key was deleted: {}", key);
130 if watch_tx.send(DisaggRouterConf::default()).is_err() {
132 tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
133 break;
134 }
135 }
136 }
137 }
138
139 tracing::debug!("Completed config watcher for key: {}", key);
140 });
141
142 Ok((initial_config, watch_rx))
143 }
144}
145
146#[derive(Clone)]
147pub struct DisaggregatedRouter {
148 max_local_prefill_length: Arc<Mutex<i32>>,
149 model_name: String,
150 config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
151}
152
153impl DisaggregatedRouter {
154 pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
155 DisaggregatedRouter {
156 max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
157 model_name,
158 config_watcher: None,
159 }
160 }
161
162 pub async fn new_with_etcd_and_default(
163 drt: Arc<DistributedRuntime>,
164 model_name: String,
165 default_max_local_prefill_length: i32,
166 ) -> anyhow::Result<Self> {
167 let (mut config, watcher) =
168 DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
169
170 if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
172 config.max_local_prefill_length = default_max_local_prefill_length;
173 }
174
175 let router = Self {
176 max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
177 model_name: model_name.clone(),
178 config_watcher: Some(watcher),
179 };
180
181 router.start_config_watcher();
183
184 Ok(router)
185 }
186
187 fn start_config_watcher(&self) {
188 if let Some(watcher) = self.config_watcher.clone() {
189 let mut watcher = watcher;
190 let model_name = self.model_name.clone();
192 let max_local_prefill_length = self.max_local_prefill_length.clone();
193
194 tokio::spawn(async move {
195 tracing::info!("Starting config update watcher for model: {}", model_name);
196
197 while watcher.changed().await.is_ok() {
198 let config = watcher.borrow().clone();
199 let new_value = config.max_local_prefill_length;
200
201 let mut current_value = max_local_prefill_length.lock().unwrap();
203 let old_value = *current_value;
204 if old_value != new_value {
205 *current_value = new_value;
206 tracing::info!(
207 "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
208 model_name,
209 old_value,
210 new_value
211 );
212 }
213 }
214
215 tracing::debug!("Config watcher closed for model: {}", model_name);
216 });
217 }
218 }
219
220 pub fn check_for_updates(&self) {
221 if let Some(watcher) = &self.config_watcher {
222 if watcher.has_changed().unwrap_or(false) {
223 let config = watcher.borrow().clone();
224 let new_value = config.max_local_prefill_length;
225
226 let mut current_value = self.max_local_prefill_length.lock().unwrap();
228 let old_value = *current_value;
229 if old_value != new_value {
230 *current_value = new_value;
231 tracing::info!(
232 "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
233 self.model_name,
234 old_value,
235 new_value
236 );
237 }
238 }
239 }
240 }
241
242 pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
243 self.check_for_updates();
245
246 let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
248
249 prefill_length - prefix_hit_length > max_local_prefill_length
252 }
253
254 pub fn update_value(&self, max_local_prefill_length: i32) {
255 let mut current = self.max_local_prefill_length.lock().unwrap();
256 *current = max_local_prefill_length;
257 }
258
259 pub fn get_model_name(&self) -> &str {
260 &self.model_name
261 }
262}