dynamo_llm/
disagg_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use serde::{Deserialize, Serialize};
17use std::sync::{Arc, Mutex};
18use tokio::sync::watch;
19use tracing;
20
21use dynamo_runtime::transports::etcd::WatchEvent;
22use dynamo_runtime::DistributedRuntime;
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub struct DisaggRouterConf {
26    pub max_local_prefill_length: i32,
27}
28
29impl Default for DisaggRouterConf {
30    fn default() -> Self {
31        Self {
32            max_local_prefill_length: 1000,
33        }
34    }
35}
36
37impl DisaggRouterConf {
38    pub async fn from_etcd_with_watcher(
39        drt: Arc<DistributedRuntime>,
40        model_name: &str,
41    ) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
42        let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
43
44        // Get the initial value if it exists
45        let Some(etcd_client) = drt.etcd_client() else {
46            anyhow::bail!("Static components don't have an etcd client");
47        };
48        let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
49            Ok(kvs) => {
50                if let Some(kv) = kvs.first() {
51                    match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
52                        Ok(config) => {
53                            tracing::debug!(
54                                "Found initial config for key {}: {:?}",
55                                etcd_key,
56                                config
57                            );
58                            config
59                        }
60                        Err(e) => {
61                            tracing::warn!(
62                                "Failed to parse initial config for key {}: {}",
63                                etcd_key,
64                                e
65                            );
66                            DisaggRouterConf::default()
67                        }
68                    }
69                } else {
70                    tracing::debug!(
71                        "No initial config found for key {}, using default",
72                        etcd_key
73                    );
74                    DisaggRouterConf::default()
75                }
76            }
77            Err(e) => {
78                tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
79                DisaggRouterConf::default()
80            }
81        };
82
83        // Create watch channel for config updates
84        let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
85
86        // Set up the watcher after getting the initial value
87        let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
88        let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
89
90        // Spawn background task to watch for config changes
91        drt.runtime().secondary().spawn(async move {
92            tracing::info!("Starting config watcher for disagg router key: {}", key);
93
94            loop {
95                let kv_event = tokio::select! {
96                    _ = watch_tx.closed() => {
97                        tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
98                        break;
99                    }
100                    kv_event = kv_event_rx.recv() => {
101                        match kv_event {
102                            Some(kv_event) => kv_event,
103                            None => {
104                                tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
105                                break;
106                            }
107                        }
108                    }
109                };
110
111                tracing::debug!("Received watch event for key {}", key);
112
113                match kv_event {
114                    WatchEvent::Put(kv) => {
115                        let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
116                        if let Ok(config) = val {
117                            tracing::info!("Config updated for key {}: {:?}", key, config);
118                            // Broadcast the update
119                            if watch_tx.send(config).is_err() {
120                                tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
121                                break;
122                            }
123                        } else {
124                            tracing::error!("Unable to parse router config for key {}", key);
125                            break;
126                        }
127                    }
128                    WatchEvent::Delete(_) => {
129                        tracing::warn!("Config key was deleted: {}", key);
130                        // Reset to default values
131                        if watch_tx.send(DisaggRouterConf::default()).is_err() {
132                            tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
133                            break;
134                        }
135                    }
136                }
137            }
138
139            tracing::debug!("Completed config watcher for key: {}", key);
140        });
141
142        Ok((initial_config, watch_rx))
143    }
144}
145
146#[derive(Clone)]
147pub struct DisaggregatedRouter {
148    max_local_prefill_length: Arc<Mutex<i32>>,
149    model_name: String,
150    config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
151}
152
153impl DisaggregatedRouter {
154    pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
155        DisaggregatedRouter {
156            max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
157            model_name,
158            config_watcher: None,
159        }
160    }
161
162    pub async fn new_with_etcd_and_default(
163        drt: Arc<DistributedRuntime>,
164        model_name: String,
165        default_max_local_prefill_length: i32,
166    ) -> anyhow::Result<Self> {
167        let (mut config, watcher) =
168            DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
169
170        // Use the provided default if no etcd value was found (when config is the default value)
171        if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
172            config.max_local_prefill_length = default_max_local_prefill_length;
173        }
174
175        let router = Self {
176            max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
177            model_name: model_name.clone(),
178            config_watcher: Some(watcher),
179        };
180
181        // Start background task to watch for config updates
182        router.start_config_watcher();
183
184        Ok(router)
185    }
186
187    fn start_config_watcher(&self) {
188        if let Some(watcher) = self.config_watcher.clone() {
189            let mut watcher = watcher;
190            // Create a clone for the task
191            let model_name = self.model_name.clone();
192            let max_local_prefill_length = self.max_local_prefill_length.clone();
193
194            tokio::spawn(async move {
195                tracing::info!("Starting config update watcher for model: {}", model_name);
196
197                while watcher.changed().await.is_ok() {
198                    let config = watcher.borrow().clone();
199                    let new_value = config.max_local_prefill_length;
200
201                    // Update the value using the mutex
202                    let mut current_value = max_local_prefill_length.lock().unwrap();
203                    let old_value = *current_value;
204                    if old_value != new_value {
205                        *current_value = new_value;
206                        tracing::info!(
207                            "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
208                            model_name,
209                            old_value,
210                            new_value
211                        );
212                    }
213                }
214
215                tracing::debug!("Config watcher closed for model: {}", model_name);
216            });
217        }
218    }
219
220    pub fn check_for_updates(&self) {
221        if let Some(watcher) = &self.config_watcher {
222            if watcher.has_changed().unwrap_or(false) {
223                let config = watcher.borrow().clone();
224                let new_value = config.max_local_prefill_length;
225
226                // Update the value using the mutex
227                let mut current_value = self.max_local_prefill_length.lock().unwrap();
228                let old_value = *current_value;
229                if old_value != new_value {
230                    *current_value = new_value;
231                    tracing::info!(
232                        "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
233                        self.model_name,
234                        old_value,
235                        new_value
236                    );
237                }
238            }
239        }
240    }
241
242    pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
243        // Check for updates before making the decision
244        self.check_for_updates();
245
246        // Get the current value from the mutex
247        let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
248
249        // schedule the request purely based on the prefill length
250        // TODO: apply math models and compare local vs remote prefill TTFT
251        prefill_length - prefix_hit_length > max_local_prefill_length
252    }
253
254    pub fn update_value(&self, max_local_prefill_length: i32) {
255        let mut current = self.max_local_prefill_length.lock().unwrap();
256        *current = max_local_prefill_length;
257    }
258
259    pub fn get_model_name(&self) -> &str {
260        &self.model_name
261    }
262}