use std::sync::Arc;

use tokio::sync::mpsc::Sender;

use anyhow::Context as _;
use tokio::sync::{Notify, mpsc::Receiver};

use dynamo_runtime::{
    DistributedRuntime,
    pipeline::{
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
    },
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
};

use crate::{
    backend::Backend,
    entrypoint,
    kv_router::KvRouterConfig,
    model_card::ModelDeploymentCard,
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
    },
};

use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
use crate::namespace::is_global_namespace;

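/// Notification sent on the optional `model_update_tx` channel whenever a model of a given
/// [`ModelType`] is added to or removed from the [`ModelManager`].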
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
    Added(ModelType),
    Removed(ModelType),
}

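/// Watches the etcd model registry and keeps a [`ModelManager`] in sync with it:
/// `Put` events register new model engines, `Delete` events tear them down once the
/// last instance of a model disappears.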
pub struct ModelWatcher {
    manager: Arc<ModelManager>,
    drt: DistributedRuntime,
    router_mode: RouterMode,
    notify_on_model: Notify,
    model_update_tx: Option<Sender<ModelUpdate>>,
    kv_router_config: Option<KvRouterConfig>,
    busy_threshold: Option<f64>,
}

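/// Every model type the watcher can announce; used to fan out `ModelUpdate::Removed`
/// notifications in `handle_delete`.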
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];

impl ModelWatcher {
    pub fn new(
        runtime: DistributedRuntime,
        model_manager: Arc<ModelManager>,
        router_mode: RouterMode,
        kv_router_config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> ModelWatcher {
        Self {
            manager: model_manager,
            drt: runtime,
            router_mode,
            notify_on_model: Notify::new(),
            model_update_tx: None,
            kv_router_config,
            busy_threshold,
        }
    }

    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        self.model_update_tx = Some(tx);
    }

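    /// Block until at least one chat completions model is registered, then return its name.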
    pub async fn wait_for_chat_model(&self) -> String {
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

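    /// Consume etcd `WatchEvent`s for the model registry. `Put` events register the model with
    /// the [`ModelManager`] (building its engine pipeline the first time a name is seen),
    /// `Delete` events remove it once no instances remain. If `target_namespace` is set and is
    /// not the global namespace, entries from other namespaces are ignored.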
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
                            continue;
                        }
                    };

                    if !global_namespace
                        && let Some(target_ns) = target_namespace
                        && model_entry.endpoint_id.namespace != target_ns
                    {
                        tracing::debug!(
                            model_namespace = model_entry.endpoint_id.namespace,
                            target_namespace = target_ns,
                            model_name = model_entry.name,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

                    if let Some(tx) = &self.model_update_tx {
                        tx.send(ModelUpdate::Added(model_entry.model_type))
                            .await
                            .ok();
                    }

                    if self.manager.has_model_any(&model_entry.name) {
                        tracing::trace!(
                            name = model_entry.name,
                            namespace = model_entry.endpoint_id.namespace,
                            "New endpoint for existing model"
                        );
                        self.notify_on_model.notify_waiters();
                        continue;
                    }

                    match self.handle_put(&model_entry).await {
                        Ok(()) => {
                            tracing::info!(
                                model_name = model_entry.name,
                                namespace = model_entry.endpoint_id.namespace,
                                "added model"
                            );
                            self.notify_on_model.notify_waiters();
                        }
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
                                "error adding model {} from namespace {}",
                                model_entry.name,
                                model_entry.endpoint_id.namespace,
                            );
                        }
                    }
                }
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
                        tracing::info!(model_name, "removed model");
                    }
                    Ok(None) => {
                        // Other instances of this model are still registered; nothing more to remove.
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "error removing model");
                    }
                },
            }
        }
    }

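    /// Handle a `Delete` event. Returns `Ok(Some(name))` when the last instance of a model was
    /// removed and its engines were dropped from the [`ModelManager`], or `Ok(None)` when other
    /// instances are still serving the model.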
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
        let key = kv.key_str()?;
        let model_entry = match self.manager.remove_model_entry(key) {
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }

        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);

        // A model type only counts as removed when the removal above succeeded and no other
        // model of that type remains registered; only those types get a Removed update.
        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
        let mut tensor_model_removed = false;

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }

        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
            tracing::debug!(
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
                model_name,
                chat_model_removed,
                completions_model_removed,
                embeddings_model_removed,
                tensor_model_removed
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
                if ((chat_model_removed && *model_type == ModelType::Chat)
                    || (completions_model_removed && *model_type == ModelType::Completions)
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
                    && let Some(tx) = &self.model_update_tx
                {
                    tx.send(ModelUpdate::Removed(*model_type)).await.ok();
                }
            }
        }

        Ok(Some(model_name))
    }

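    /// Handle a `Put` event: load the model's [`ModelDeploymentCard`] from storage and build the
    /// engine pipeline that matches its `ModelInput`/`ModelType` combination, registering the
    /// result with the [`ModelManager`].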
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
        let endpoint_id = &model_entry.endpoint_id;
        let component = self
            .drt
            .namespace(&endpoint_id.namespace)?
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
        let model_slug = model_entry.slug();
        let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
            Ok(Some(mut card)) => {
                tracing::debug!(card.display_name, "adding model");
                if let Some(rc) = model_entry.runtime_config.clone() {
                    card.runtime_config = rc;
                }
                card
            }
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
            }
            Err(err) => {
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
                );
            }
        };

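        // Token-input chat/completions models get the full routed pipeline: an optional KV-aware
        // chooser (only when the router mode is KV) plus tokenization via the card's HF tokenizer.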
        if model_entry.model_input == ModelInput::Tokens
            && (model_entry.model_type.supports_chat()
                || model_entry.model_type.supports_completions())
        {
            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
                            &model_entry.name,
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };

            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;

            if model_entry.model_type.supports_chat() {
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
                    tokenizer_hf.clone(),
                )
                .await
                .context("build_routed_pipeline")?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)
                    .context("add_chat_completions_model")?;
                tracing::info!("Chat completions is ready");
            }

            if model_entry.model_type.supports_completions() {
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
                    preprocessor,
                    tokenizer_hf,
                )
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)
                    .context("add_completions_model")?;
                tracing::info!("Completions is ready");
            }
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_chat()
        {
            // Text-input models take the OpenAI request as-is; route it straight to the worker.
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_chat_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_completions()
        {
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Tokens
            && model_entry.model_type.supports_embedding()
        {
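            // Embeddings are served by a locally assembled segment pipeline: the frontend feeds
            // the OpenAI preprocessor and backend operators, which hand off to the remote engine
            // through a PushRouter wrapped in a ServiceBackend.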
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
            let backend = Backend::from_mdc(&card).into_operator();

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            let service_backend = ServiceBackend::from_engine(Arc::new(router));

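            // Forward edges carry the request toward the engine; backward edges carry the
            // response back out through the same operators, ending at the frontend.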
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
                .add_embeddings_model(&model_entry.name, embedding_engine)?;
        } else if model_entry.model_input == ModelInput::Tensor
            && model_entry.model_type.supports_tensor()
        {
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager.add_tensor_model(&model_entry.name, engine)?;
        } else {
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
                 Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
                model_entry.model_type,
                model_entry.model_input.as_str()
            );
        }

        Ok(())
    }

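    /// Fetch every [`ModelEntry`] currently registered under `MODEL_ROOT_PATH` in etcd, skipping
    /// entries whose value is not valid JSON.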
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
        }
        Ok(entries)
    }

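    /// Fetch all registry entries for a single model name. An empty result means no instance of
    /// that model is currently registered.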
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
    }
}