// dynamo_llm/disagg_router.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::sync::{Arc, Mutex};
6use tokio::sync::watch;
7use tracing;
8
9use dynamo_runtime::DistributedRuntime;
10use dynamo_runtime::transports::etcd::WatchEvent;
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct DisaggRouterConf {
14    pub max_local_prefill_length: i32,
15}
16
17impl Default for DisaggRouterConf {
18    fn default() -> Self {
19        Self {
20            max_local_prefill_length: 1000,
21        }
22    }
23}
24
25impl DisaggRouterConf {
26    pub async fn from_etcd_with_watcher(
27        drt: Arc<DistributedRuntime>,
28        model_name: &str,
29    ) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
30        let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
31
32        // Get the initial value if it exists
33        let Some(etcd_client) = drt.etcd_client() else {
34            anyhow::bail!("Static components don't have an etcd client");
35        };
36        let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
37            Ok(kvs) => {
38                if let Some(kv) = kvs.first() {
39                    match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
40                        Ok(config) => {
41                            tracing::debug!(
42                                "Found initial config for key {}: {:?}",
43                                etcd_key,
44                                config
45                            );
46                            config
47                        }
48                        Err(e) => {
49                            tracing::warn!(
50                                "Failed to parse initial config for key {}: {}",
51                                etcd_key,
52                                e
53                            );
54                            DisaggRouterConf::default()
55                        }
56                    }
57                } else {
58                    tracing::debug!(
59                        "No initial config found for key {}, using default",
60                        etcd_key
61                    );
62                    DisaggRouterConf::default()
63                }
64            }
65            Err(e) => {
66                tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
67                DisaggRouterConf::default()
68            }
69        };
70
71        // Create watch channel for config updates
72        let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
73
74        // Set up the watcher after getting the initial value
75        let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
76        let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
77
78        // Spawn background task to watch for config changes
79        drt.runtime().secondary().spawn(async move {
80            tracing::info!("Starting config watcher for disagg router key: {}", key);
81
82            loop {
83                let kv_event = tokio::select! {
84                    _ = watch_tx.closed() => {
85                        tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
86                        break;
87                    }
88                    kv_event = kv_event_rx.recv() => {
89                        match kv_event {
90                            Some(kv_event) => kv_event,
91                            None => {
92                                tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
93                                break;
94                            }
95                        }
96                    }
97                };
98
99                tracing::debug!("Received watch event for key {}", key);
100
101                match kv_event {
102                    WatchEvent::Put(kv) => {
103                        let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
104                        if let Ok(config) = val {
105                            tracing::info!("Config updated for key {}: {:?}", key, config);
106                            // Broadcast the update
107                            if watch_tx.send(config).is_err() {
108                                tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
109                                break;
110                            }
111                        } else {
112                            tracing::error!("Unable to parse router config for key {}", key);
113                            break;
114                        }
115                    }
116                    WatchEvent::Delete(_) => {
117                        tracing::warn!("Config key was deleted: {}", key);
118                        // Reset to default values
119                        if watch_tx.send(DisaggRouterConf::default()).is_err() {
120                            tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
121                            break;
122                        }
123                    }
124                }
125            }
126
127            tracing::debug!("Completed config watcher for key: {}", key);
128        });
129
130        Ok((initial_config, watch_rx))
131    }
132}
133
134#[derive(Clone)]
135pub struct DisaggregatedRouter {
136    max_local_prefill_length: Arc<Mutex<i32>>,
137    model_name: String,
138    config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
139}
140
141impl DisaggregatedRouter {
142    pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
143        DisaggregatedRouter {
144            max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
145            model_name,
146            config_watcher: None,
147        }
148    }
149
150    pub async fn new_with_etcd_and_default(
151        drt: Arc<DistributedRuntime>,
152        model_name: String,
153        default_max_local_prefill_length: i32,
154    ) -> anyhow::Result<Self> {
155        let (mut config, watcher) =
156            DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
157
158        // Use the provided default if no etcd value was found (when config is the default value)
159        if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
160            config.max_local_prefill_length = default_max_local_prefill_length;
161        }
162
163        let router = Self {
164            max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
165            model_name: model_name.clone(),
166            config_watcher: Some(watcher),
167        };
168
169        // Start background task to watch for config updates
170        router.start_config_watcher();
171
172        Ok(router)
173    }
174
175    fn start_config_watcher(&self) {
176        if let Some(watcher) = self.config_watcher.clone() {
177            let mut watcher = watcher;
178            // Create a clone for the task
179            let model_name = self.model_name.clone();
180            let max_local_prefill_length = self.max_local_prefill_length.clone();
181
182            tokio::spawn(async move {
183                tracing::info!("Starting config update watcher for model: {}", model_name);
184
185                while watcher.changed().await.is_ok() {
186                    let config = watcher.borrow().clone();
187                    let new_value = config.max_local_prefill_length;
188
189                    // Update the value using the mutex
190                    let mut current_value = max_local_prefill_length.lock().unwrap();
191                    let old_value = *current_value;
192                    if old_value != new_value {
193                        *current_value = new_value;
194                        tracing::info!(
195                            "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
196                            model_name,
197                            old_value,
198                            new_value
199                        );
200                    }
201                }
202
203                tracing::debug!("Config watcher closed for model: {}", model_name);
204            });
205        }
206    }
207
208    pub fn check_for_updates(&self) {
209        if let Some(watcher) = &self.config_watcher
210            && watcher.has_changed().unwrap_or(false)
211        {
212            let config = watcher.borrow().clone();
213            let new_value = config.max_local_prefill_length;
214
215            // Update the value using the mutex
216            let mut current_value = self.max_local_prefill_length.lock().unwrap();
217            let old_value = *current_value;
218            if old_value != new_value {
219                *current_value = new_value;
220                tracing::info!(
221                    "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
222                    self.model_name,
223                    old_value,
224                    new_value
225                );
226            }
227        }
228    }
229
230    pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
231        // Check for updates before making the decision
232        self.check_for_updates();
233
234        // Get the current value from the mutex
235        let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
236
237        // schedule the request purely based on the prefill length
238        // TODO: apply math models and compare local vs remote prefill TTFT
239        prefill_length - prefix_hit_length > max_local_prefill_length
240    }
241
242    pub fn update_value(&self, max_local_prefill_length: i32) {
243        let mut current = self.max_local_prefill_length.lock().unwrap();
244        *current = max_local_prefill_length;
245    }
246
247    pub fn get_model_name(&self) -> &str {
248        &self.model_name
249    }
250}