dynamo_llm/discovery/watcher.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
use tokio::sync::mpsc::Sender;

use anyhow::Context as _;
use tokio::sync::{Notify, mpsc::Receiver};

use dynamo_runtime::{
    DistributedRuntime,
    pipeline::{
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
    },
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
};

use crate::{
    backend::Backend,
    entrypoint,
    kv_router::KvRouterConfig,
    model_card::ModelDeploymentCard,
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
    },
};

use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
use crate::namespace::is_global_namespace;

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
    Added(ModelType),
    Removed(ModelType),
}

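/// Watches the etcd model registry for `ModelEntry` PUT and DELETE events and keeps the
/// shared [`ModelManager`] in sync, wiring up the appropriate engine pipeline per model.
///
/// A minimal construction sketch; it assumes a `DistributedRuntime` and a shared
/// `ModelManager` already exist, and the router mode and `None` options shown here are
/// illustrative choices only:
///
/// ```ignore
/// let watcher = ModelWatcher::new(
///     runtime.clone(),       // DistributedRuntime, assumed to be set up elsewhere
///     model_manager.clone(), // Arc<ModelManager>, the shared model registry
///     RouterMode::KV,        // any RouterMode works; KV enables the KV chooser path
///     None,                  // kv_router_config
///     None,                  // busy_threshold
/// );
/// ```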
pub struct ModelWatcher {
    manager: Arc<ModelManager>,
    drt: DistributedRuntime,
    router_mode: RouterMode,
    notify_on_model: Notify,
    model_update_tx: Option<Sender<ModelUpdate>>,
    kv_router_config: Option<KvRouterConfig>,
    busy_threshold: Option<f64>,
}

const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];

impl ModelWatcher {
    pub fn new(
        runtime: DistributedRuntime,
        model_manager: Arc<ModelManager>,
        router_mode: RouterMode,
        kv_router_config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> ModelWatcher {
        Self {
            manager: model_manager,
            drt: runtime,
            router_mode,
            notify_on_model: Notify::new(),
            model_update_tx: None,
            kv_router_config,
            busy_threshold,
        }
    }

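    /// Register an mpsc sender that is notified with a [`ModelUpdate`] whenever a model
    /// type is added or removed by this watcher.
    ///
    /// A hedged usage sketch; the channel capacity and the logging are arbitrary choices:
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel::<ModelUpdate>(16);
    /// watcher.set_notify_on_model_update(tx); // `watcher` must still be mutable here
    /// tokio::spawn(async move {
    ///     while let Some(update) = rx.recv().await {
    ///         match update {
    ///             ModelUpdate::Added(model_type) => tracing::info!(?model_type, "model type added"),
    ///             ModelUpdate::Removed(model_type) => tracing::info!(?model_type, "model type removed"),
    ///         }
    ///     }
    /// });
    /// ```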
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        self.model_update_tx = Some(tx);
    }

    /// Wait until we have at least one chat completions model and return its name.
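    ///
    /// A minimal sketch of waiting for the first chat model; it assumes `watcher` is
    /// already receiving watch events:
    ///
    /// ```ignore
    /// let model_name = watcher.wait_for_chat_model().await;
    /// tracing::info!(model_name, "first chat completions model is available");
    /// ```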
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

    /// Common watch logic with optional namespace filtering
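    ///
    /// A hypothetical wiring sketch; `events_rx` is assumed to be a `Receiver<WatchEvent>`
    /// obtained from an etcd prefix watch on `MODEL_ROOT_PATH` (that setup is elided here):
    ///
    /// ```ignore
    /// let watcher = Arc::new(watcher);
    /// tokio::spawn({
    ///     let watcher = watcher.clone();
    ///     async move {
    ///         // `None` means no namespace filter: accept models from every namespace.
    ///         watcher.watch(events_rx, None).await;
    ///     }
    /// });
    /// ```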
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
                            continue;
                        }
                    };

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
                        && model_entry.endpoint_id.namespace != target_ns
                    {
                        tracing::debug!(
                            model_namespace = model_entry.endpoint_id.namespace,
                            target_namespace = target_ns,
                            model_name = model_entry.name,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

                    if let Some(tx) = &self.model_update_tx {
                        tx.send(ModelUpdate::Added(model_entry.model_type))
                            .await
                            .ok();
                    }

                    if self.manager.has_model_any(&model_entry.name) {
                        tracing::trace!(
                            name = model_entry.name,
                            namespace = model_entry.endpoint_id.namespace,
                            "New endpoint for existing model"
                        );
                        self.notify_on_model.notify_waiters();
                        continue;
                    }

                    match self.handle_put(&model_entry).await {
                        Ok(()) => {
                            tracing::info!(
                                model_name = model_entry.name,
                                namespace = model_entry.endpoint_id.namespace,
                                "added model"
                            );
                            self.notify_on_model.notify_waiters();
                        }
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
                                "error adding model {} from namespace {}",
                                model_entry.name,
                                model_entry.endpoint_id.namespace,
                            );
                        }
                    }
                }
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
                        tracing::info!(model_name, "removed model");
                    }
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "error removing model");
                    }
                },
            }
        }
    }

    /// If the last instance running this model has gone, delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
        let key = kv.key_str()?;
        let model_entry = match self.manager.remove_model_entry(key) {
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }

        // Ignore the errors: the model may be registered under any subset of these types,
        // so some of these removals are expected to fail.
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
        let mut tensor_model_removed = false;

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }

        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
            tracing::debug!(
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
                model_name,
                chat_model_removed,
                completions_model_removed,
                embeddings_model_removed,
                tensor_model_removed
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
                if ((chat_model_removed && *model_type == ModelType::Chat)
                    || (completions_model_removed && *model_type == ModelType::Completions)
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
                    && let Some(tx) = &self.model_update_tx
                {
                    tx.send(ModelUpdate::Removed(*model_type)).await.ok();
                }
            }
        }

        Ok(Some(model_name))
    }

    // Handles a PUT event from etcd; this usually means adding a new model to the list of
    // served models.
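    //
    // Dispatch summary (model_input x model_type), mirroring the branches below:
    //   Tokens + Chat and/or Completions -> routed pipelines with preprocessing (Case 1)
    //   Text   + Chat                    -> plain PushRouter (Case 3)
    //   Text   + Completions             -> plain PushRouter (Case 2)
    //   Tokens + Embedding               -> preprocessor + backend embedding pipeline (Case 4)
    //   Tensor + TensorBased             -> plain PushRouter (Case 5)
    //   anything else                    -> rejected with an error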
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
        let endpoint_id = &model_entry.endpoint_id;
        let component = self
            .drt
            .namespace(&endpoint_id.namespace)?
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
        let model_slug = model_entry.slug();
        let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
            Ok(Some(mut card)) => {
                tracing::debug!(card.display_name, "adding model");
                // Ensure runtime_config is populated
                if let Some(rc) = model_entry.runtime_config.clone() {
                    card.runtime_config = rc;
                }
                card
            }
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
            }
            Err(err) => {
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
                );
            }
        };

        if model_entry.model_input == ModelInput::Tokens
            && (model_entry.model_type.supports_chat()
                || model_entry.model_type.supports_completions())
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // The model expects pre-processed requests, so it is up to us whether we handle
            // Chat or Completions requests; register whichever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
                            &model_entry.name,
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };

            // This is expensive (it loads ~10MiB of JSON), so only do it once
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;

            // Add chat engine only if the model supports chat
            if model_entry.model_type.supports_chat() {
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
                    tokenizer_hf.clone(),
                )
                .await
                .context("build_routed_pipeline")?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)
                    .context("add_chat_completions_model")?;
                tracing::info!("Chat completions is ready");
            }

            // Add completions engine only if the model supports completions
            if model_entry.model_type.supports_completions() {
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
                    preprocessor,
                    tokenizer_hf,
                )
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)
                    .context("add_completions_model")?;
                tracing::info!("Completions is ready");
            }
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_chat()
        {
            // Case 3: Text + Chat
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_chat_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_completions()
        {
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Tokens
            && model_entry.model_type.supports_embedding()
        {
            // Case 4: Tokens + Embeddings

            // Create a preprocessing pipeline, similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
            let backend = Backend::from_mdc(&card).into_operator();

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            // Note: Embeddings don't need KV routing complexity
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
                .add_embeddings_model(&model_entry.name, embedding_engine)?;
        } else if model_entry.model_input == ModelInput::Tensor
            && model_entry.model_type.supports_tensor()
        {
            // Case 5: Tensor + TensorBased (non-LLM)
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager.add_tensor_model(&model_entry.name, engine)?;
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
                model_entry.model_type,
                model_entry.model_input.as_str()
            );
        }

        Ok(())
    }

    /// All the registered ModelEntry values, one per instance.
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
        }
        Ok(entries)
    }

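    /// All registered `ModelEntry` values whose name matches `model_name`.
    ///
    /// A small usage sketch; the model name is purely illustrative:
    ///
    /// ```ignore
    /// let instances = watcher.entries_for_model("example-model").await?;
    /// tracing::info!(count = instances.len(), "active instances for example-model");
    /// ```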
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
    }
}