1use std::{
5 collections::{HashMap, HashSet},
6 sync::Arc,
7};
8
9use parking_lot::{Mutex, RwLock};
10
11use dynamo_runtime::component::Component;
12use dynamo_runtime::prelude::DistributedRuntimeProvider;
13
14use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
15use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
16use crate::{
17 kv_router::KvRouter,
18 types::generic::tensor::TensorStreamingEngine,
19 types::openai::{
20 chat_completions::OpenAIChatCompletionsStreamingEngine,
21 completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
22 },
23};
24
/// Errors returned by [`ModelManager`] registration and lookup operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// Lookup or removal failed: nothing is registered under this model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// Registration failed: an engine already exists under this model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}
33
/// Registry of streaming engines (completions, chat completions, embeddings,
/// tensor), their discovery entries, and per-model KV routers.
pub struct ModelManager {
    // One registry per API flavor. RwLock suits the read-heavy lookup path:
    // many concurrent `get_*`/`list_*` reads, rare add/remove writes.
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,

    // Discovery entries keyed by a caller-supplied key (presumably the etcd
    // key from the discovery module -- TODO confirm against callers).
    entries: Mutex<HashMap<String, ModelEntry>>,
    // KV routers keyed by model name; shared with request handlers via Arc.
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
46
47impl Default for ModelManager {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl ModelManager {
54 pub fn new() -> Self {
55 Self {
56 completion_engines: RwLock::new(ModelEngines::default()),
57 chat_completion_engines: RwLock::new(ModelEngines::default()),
58 embeddings_engines: RwLock::new(ModelEngines::default()),
59 tensor_engines: RwLock::new(ModelEngines::default()),
60 entries: Mutex::new(HashMap::new()),
61 kv_choosers: Mutex::new(HashMap::new()),
62 }
63 }
64
65 pub fn get_model_entries(&self) -> Vec<ModelEntry> {
66 self.entries.lock().values().cloned().collect()
67 }
68
69 pub fn has_model_any(&self, model: &str) -> bool {
70 self.chat_completion_engines.read().contains(model)
71 || self.completion_engines.read().contains(model)
72 }
73
74 pub fn model_display_names(&self) -> HashSet<String> {
75 self.list_chat_completions_models()
76 .into_iter()
77 .chain(self.list_completions_models())
78 .chain(self.list_embeddings_models())
79 .chain(self.list_tensor_models())
80 .collect()
81 }
82
83 pub fn list_chat_completions_models(&self) -> Vec<String> {
84 self.chat_completion_engines.read().list()
85 }
86
87 pub fn list_completions_models(&self) -> Vec<String> {
88 self.completion_engines.read().list()
89 }
90
91 pub fn list_embeddings_models(&self) -> Vec<String> {
92 self.embeddings_engines.read().list()
93 }
94
95 pub fn list_tensor_models(&self) -> Vec<String> {
96 self.tensor_engines.read().list()
97 }
98
99 pub fn add_completions_model(
100 &self,
101 model: &str,
102 engine: OpenAICompletionsStreamingEngine,
103 ) -> Result<(), ModelManagerError> {
104 let mut clients = self.completion_engines.write();
105 clients.add(model, engine)
106 }
107
108 pub fn add_chat_completions_model(
109 &self,
110 model: &str,
111 engine: OpenAIChatCompletionsStreamingEngine,
112 ) -> Result<(), ModelManagerError> {
113 let mut clients = self.chat_completion_engines.write();
114 clients.add(model, engine)
115 }
116
117 pub fn add_embeddings_model(
118 &self,
119 model: &str,
120 engine: OpenAIEmbeddingsStreamingEngine,
121 ) -> Result<(), ModelManagerError> {
122 let mut clients = self.embeddings_engines.write();
123 clients.add(model, engine)
124 }
125
126 pub fn add_tensor_model(
127 &self,
128 model: &str,
129 engine: TensorStreamingEngine,
130 ) -> Result<(), ModelManagerError> {
131 let mut clients = self.tensor_engines.write();
132 clients.add(model, engine)
133 }
134
135 pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
136 let mut clients = self.completion_engines.write();
137 clients.remove(model)
138 }
139
140 pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
141 let mut clients = self.chat_completion_engines.write();
142 clients.remove(model)
143 }
144
145 pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
146 let mut clients = self.embeddings_engines.write();
147 clients.remove(model)
148 }
149
150 pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
151 let mut clients = self.tensor_engines.write();
152 clients.remove(model)
153 }
154
155 pub fn get_embeddings_engine(
156 &self,
157 model: &str,
158 ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
159 self.embeddings_engines
160 .read()
161 .get(model)
162 .cloned()
163 .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
164 }
165
166 pub fn get_completions_engine(
167 &self,
168 model: &str,
169 ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
170 self.completion_engines
171 .read()
172 .get(model)
173 .cloned()
174 .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
175 }
176
177 pub fn get_chat_completions_engine(
178 &self,
179 model: &str,
180 ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
181 self.chat_completion_engines
182 .read()
183 .get(model)
184 .cloned()
185 .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
186 }
187
188 pub fn get_tensor_engine(
189 &self,
190 model: &str,
191 ) -> Result<TensorStreamingEngine, ModelManagerError> {
192 self.tensor_engines
193 .read()
194 .get(model)
195 .cloned()
196 .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
197 }
198
199 pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
202 self.entries.lock().insert(key.to_string(), entry);
203 }
204
205 pub fn remove_model_entry(&self, key: &str) -> Option<ModelEntry> {
207 self.entries.lock().remove(key)
208 }
209
210 pub async fn kv_chooser_for(
211 &self,
212 model_name: &str,
213 component: &Component,
214 kv_cache_block_size: u32,
215 kv_router_config: Option<KvRouterConfig>,
216 ) -> anyhow::Result<Arc<KvRouter>> {
217 if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
218 if kv_chooser.block_size() != kv_cache_block_size {
220 tracing::warn!(
221 model_name = %model_name,
222 existing_block_size = %kv_chooser.block_size(),
223 requested_block_size = %kv_cache_block_size,
224 "KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
225 This will cause routing to fail silently. Consider using the same block size or restarting the router."
226 );
227 }
228 return Ok(kv_chooser);
229 }
230 self.create_kv_chooser(model_name, component, kv_cache_block_size, kv_router_config)
231 .await
232 }
233
234 fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
235 self.kv_choosers.lock().get(model_name).cloned()
236 }
237
238 async fn create_kv_chooser(
240 &self,
241 model_name: &str,
242 component: &Component,
243 kv_cache_block_size: u32,
244 kv_router_config: Option<KvRouterConfig>,
245 ) -> anyhow::Result<Arc<KvRouter>> {
246 let etcd_client = component
247 .drt()
248 .etcd_client()
249 .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
250 let router_uuid = uuid::Uuid::new_v4();
251 let router_key = format!(
252 "{}/{}/{}",
253 KV_ROUTERS_ROOT_PATH,
254 component.path(),
255 router_uuid
256 );
257 etcd_client
258 .kv_create(
259 &router_key,
260 serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?,
261 None, )
263 .await?;
264
265 let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
266 let chooser = KvRouter::new(
267 component.clone(),
268 kv_cache_block_size,
269 Some(selector),
270 kv_router_config,
271 router_uuid.to_string(),
272 )
273 .await?;
274 let new_kv_chooser = Arc::new(chooser);
275 self.kv_choosers
276 .lock()
277 .insert(model_name.to_string(), new_kv_chooser.clone());
278 Ok(new_kv_chooser)
279 }
280
281 pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
282 self.entries
283 .lock()
284 .values()
285 .find(|entry| entry.name == model)
286 .and_then(|entry| entry.runtime_config.as_ref())
287 .and_then(|config| config.tool_call_parser.clone())
288 .map(|parser| parser.to_string())
289 }
290
291 pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
294 let tool_call_parser = self.get_model_tool_call_parser(model);
295 let reasoning_parser = None; crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
298 }
299}
300
/// Name -> engine map for one API flavor, with an optional default model.
pub struct ModelEngines<E> {
    // Optional default model name; only touched by the (currently dead-code)
    // set_default/clear_default helpers.
    default: Option<String>,
    engines: HashMap<String, E>,
}
306
307impl<E> Default for ModelEngines<E> {
308 fn default() -> Self {
309 Self {
310 default: None,
311 engines: HashMap::new(),
312 }
313 }
314}
315
316impl<E> ModelEngines<E> {
317 #[allow(dead_code)]
318 fn set_default(&mut self, model: &str) {
319 self.default = Some(model.to_string());
320 }
321
322 #[allow(dead_code)]
323 fn clear_default(&mut self) {
324 self.default = None;
325 }
326
327 fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
328 if self.engines.contains_key(model) {
329 return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
330 }
331 self.engines.insert(model.to_string(), engine);
332 Ok(())
333 }
334
335 fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
336 if self.engines.remove(model).is_none() {
337 return Err(ModelManagerError::ModelNotFound(model.to_string()));
338 }
339 Ok(())
340 }
341
342 fn get(&self, model: &str) -> Option<&E> {
343 self.engines.get(model)
344 }
345
346 fn contains(&self, model: &str) -> bool {
347 self.engines.contains_key(model)
348 }
349
350 pub fn list(&self) -> Vec<String> {
351 self.engines.keys().map(|k| k.to_owned()).collect()
352 }
353}