dynamo_llm/discovery/
model_manager.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    collections::{HashMap, HashSet},
6    sync::Arc,
7};
8
9use parking_lot::{Mutex, RwLock};
10
11use dynamo_runtime::component::Component;
12use dynamo_runtime::prelude::DistributedRuntimeProvider;
13
14use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
15use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
16use crate::{
17    kv_router::KvRouter,
18    types::generic::tensor::TensorStreamingEngine,
19    types::openai::{
20        chat_completions::OpenAIChatCompletionsStreamingEngine,
21        completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
22    },
23};
24
/// Errors returned by [`ModelManager`] registration and lookup operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the given model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine is already registered under the given model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}
33
/// Registry of per-model streaming engines, saved model entries, and KV routers.
///
/// Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
    // We read a lot and write rarely, so these four are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,

    // These two are Mutex because we read and write rarely and equally
    // entries: keyed by the instance's etcd `models/` key (see save_model_entry).
    entries: Mutex<HashMap<String, ModelEntry>>,
    // kv_choosers: one shared KvRouter per model name.
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
46
47impl Default for ModelManager {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
impl ModelManager {
    /// Create an empty manager: no engines, no saved entries, no KV routers.
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
            tensor_engines: RwLock::new(ModelEngines::default()),
            entries: Mutex::new(HashMap::new()),
            kv_choosers: Mutex::new(HashMap::new()),
        }
    }

    /// Clone out all currently saved [`ModelEntry`] values.
    pub fn get_model_entries(&self) -> Vec<ModelEntry> {
        self.entries.lock().values().cloned().collect()
    }

    /// True if a chat-completions or completions engine is registered for `model`.
    ///
    /// NOTE(review): despite the `_any` suffix, embeddings and tensor engines
    /// are not consulted here — confirm that is intentional.
    pub fn has_model_any(&self, model: &str) -> bool {
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
    }

    /// Unique model names across all four engine types.
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
            .chain(self.list_tensor_models())
            .collect()
    }

    /// Names of all registered chat-completions models (arbitrary order).
    pub fn list_chat_completions_models(&self) -> Vec<String> {
        self.chat_completion_engines.read().list()
    }

    /// Names of all registered completions models (arbitrary order).
    pub fn list_completions_models(&self) -> Vec<String> {
        self.completion_engines.read().list()
    }

    /// Names of all registered embeddings models (arbitrary order).
    pub fn list_embeddings_models(&self) -> Vec<String> {
        self.embeddings_engines.read().list()
    }

    /// Names of all registered tensor models (arbitrary order).
    pub fn list_tensor_models(&self) -> Vec<String> {
        self.tensor_engines.read().list()
    }

    /// Register a completions engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelAlreadyExists`] if the name is taken.
    pub fn add_completions_model(
        &self,
        model: &str,
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.completion_engines.write();
        clients.add(model, engine)
    }

    /// Register a chat-completions engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelAlreadyExists`] if the name is taken.
    pub fn add_chat_completions_model(
        &self,
        model: &str,
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.chat_completion_engines.write();
        clients.add(model, engine)
    }

    /// Register an embeddings engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelAlreadyExists`] if the name is taken.
    pub fn add_embeddings_model(
        &self,
        model: &str,
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.embeddings_engines.write();
        clients.add(model, engine)
    }

    /// Register a tensor engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelAlreadyExists`] if the name is taken.
    pub fn add_tensor_model(
        &self,
        model: &str,
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.add(model, engine)
    }

    /// Deregister the completions engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.completion_engines.write();
        clients.remove(model)
    }

    /// Deregister the chat-completions engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.chat_completion_engines.write();
        clients.remove(model)
    }

    /// Deregister the embeddings engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.embeddings_engines.write();
        clients.remove(model)
    }

    /// Deregister the tensor engine under `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.remove(model)
    }

    /// Fetch a clone of the embeddings engine for `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn get_embeddings_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Fetch a clone of the completions engine for `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Fetch a clone of the chat-completions engine for `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Fetch a clone of the tensor engine for `model`.
    ///
    /// Errors with [`ModelManagerError::ModelNotFound`] if no such engine exists.
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
    /// deleted from etcd.
    ///
    /// An existing entry under the same key is silently replaced.
    pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
        self.entries.lock().insert(key.to_string(), entry);
    }

    /// Remove and return model entry for this instance's etcd key. We do this when the instance stops.
    pub fn remove_model_entry(&self, key: &str) -> Option<ModelEntry> {
        self.entries.lock().remove(key)
    }

    /// Get the cached KV router for `model_name`, creating one if none exists.
    ///
    /// NOTE(review): the lookup and the creation are not atomic — the
    /// `kv_choosers` lock is released before the `create_kv_chooser` await, so
    /// two concurrent callers for a new model may each create a router (and an
    /// etcd key); the later insert wins the cache slot. Confirm acceptable.
    pub async fn kv_chooser_for(
        &self,
        model_name: &str,
        component: &Component,
        kv_cache_block_size: u32,
        kv_router_config: Option<KvRouterConfig>,
    ) -> anyhow::Result<Arc<KvRouter>> {
        if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
                    model_name = %model_name,
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
                    "KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
            // Reuse the cached router even on a block-size mismatch (warned above).
            return Ok(kv_chooser);
        }
        self.create_kv_chooser(model_name, component, kv_cache_block_size, kv_router_config)
            .await
    }

    /// Clone the cached router for `model_name`, if one exists.
    fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(model_name).cloned()
    }

    /// Create and return a KV chooser for this component and model.
    ///
    /// Publishes the router's config to etcd under
    /// `KV_ROUTERS_ROOT_PATH/<component path>/<uuid>` (on the primary lease),
    /// builds the router, and caches it under `model_name`.
    ///
    /// Errors if etcd is unavailable (static mode), the etcd write fails, or
    /// router construction fails.
    async fn create_kv_chooser(
        &self,
        model_name: &str,
        component: &Component,
        kv_cache_block_size: u32,
        kv_router_config: Option<KvRouterConfig>,
    ) -> anyhow::Result<Arc<KvRouter>> {
        let etcd_client = component
            .drt()
            .etcd_client()
            .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
        // Each router instance gets a fresh UUID so multiple routers for the
        // same component get distinct etcd keys.
        let router_uuid = uuid::Uuid::new_v4();
        let router_key = format!(
            "{}/{}/{}",
            KV_ROUTERS_ROOT_PATH,
            component.path(),
            router_uuid
        );
        etcd_client
            .kv_create(
                &router_key,
                serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?,
                None, // use primary lease
            )
            .await?;

        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
        let chooser = KvRouter::new(
            component.clone(),
            kv_cache_block_size,
            Some(selector),
            kv_router_config,
            router_uuid.to_string(),
        )
        .await?;
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
            .insert(model_name.to_string(), new_kv_chooser.clone());
        Ok(new_kv_chooser)
    }

    /// Look up the tool-call parser configured for `model`, matching saved
    /// entries by display name. Returns `None` if the model has no entry,
    /// no runtime config, or no parser configured.
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
        self.entries
            .lock()
            .values()
            .find(|entry| entry.name == model)
            .and_then(|entry| entry.runtime_config.as_ref())
            .and_then(|config| config.tool_call_parser.clone())
            // NOTE(review): if `tool_call_parser` is already a String this
            // `to_string` is a redundant copy — confirm the field's type.
            .map(|parser| parser.to_string())
    }

    /// Creates parsing options with tool call parser and reasoning parser for the specified model.
    /// Currently reasoning parser is not implemented (returns None).
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        let tool_call_parser = self.get_model_tool_call_parser(model);
        let reasoning_parser = None; // TODO: Implement reasoning parser

        crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
    }
}
300
/// Name → engine map for a single engine type, with an optional default model.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Registered engines keyed by model display name.
    engines: HashMap<String, E>,
}
306
307impl<E> Default for ModelEngines<E> {
308    fn default() -> Self {
309        Self {
310            default: None,
311            engines: HashMap::new(),
312        }
313    }
314}
315
316impl<E> ModelEngines<E> {
317    #[allow(dead_code)]
318    fn set_default(&mut self, model: &str) {
319        self.default = Some(model.to_string());
320    }
321
322    #[allow(dead_code)]
323    fn clear_default(&mut self) {
324        self.default = None;
325    }
326
327    fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
328        if self.engines.contains_key(model) {
329            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
330        }
331        self.engines.insert(model.to_string(), engine);
332        Ok(())
333    }
334
335    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
336        if self.engines.remove(model).is_none() {
337            return Err(ModelManagerError::ModelNotFound(model.to_string()));
338        }
339        Ok(())
340    }
341
342    fn get(&self, model: &str) -> Option<&E> {
343        self.engines.get(model)
344    }
345
346    fn contains(&self, model: &str) -> bool {
347        self.engines.contains_key(model)
348    }
349
350    pub fn list(&self) -> Vec<String> {
351        self.engines.keys().map(|k| k.to_owned()).collect()
352    }
353}