dynamo_llm/http/service/
discovery.rs1use std::sync::Arc;
17
18use serde::{Deserialize, Serialize};
19use tokio::sync::mpsc::Receiver;
20
21use dynamo_runtime::{
22 protocols::{self, annotated::Annotated},
23 raise,
24 transports::etcd::{KeyValue, WatchEvent},
25 DistributedRuntime,
26};
27
28use super::ModelManager;
29use crate::model_type::ModelType;
30use crate::protocols::openai::chat_completions::{
31 NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
32};
33use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
34use tracing;
35#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
38pub struct ModelEntry {
39 pub name: String,
43
44 pub endpoint: protocols::Endpoint,
46
47 pub model_type: ModelType,
49}
50
51pub struct ModelWatchState {
52 pub prefix: String,
53 pub model_type: ModelType,
54 pub manager: ModelManager,
55 pub drt: DistributedRuntime,
56}
57
58pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
59 tracing::debug!("model watcher started");
60
61 while let Some(event) = events_rx.recv().await {
62 match event {
63 WatchEvent::Put(kv) => {
64 let key = match kv.key_str() {
65 Ok(key) => key,
66 Err(err) => {
67 tracing::error!(%err, ?kv, "Invalid UTF8 in model key");
68 continue;
69 }
70 };
71 tracing::debug!(key, "adding model");
72
73 let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
75 Ok(model_entry) => model_entry,
76 Err(err) => {
77 tracing::error!(%err, ?kv, "Invalid JSON in model entry");
78 continue;
79 }
80 };
81 if state.manager.has_model_any(&model_entry.name) {
82 tracing::trace!(
83 service_name = model_entry.name,
84 "New endpoint for existing model"
85 );
86 continue;
87 }
88
89 match handle_put(model_entry, state.clone()).await {
90 Ok((model_name, model_type)) => {
91 tracing::info!("added {} model: {}", model_type, model_name);
92 }
93 Err(e) => {
94 tracing::error!("error adding model: {}", e);
95 }
96 }
97 }
98 WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
99 Ok((model_name, model_type)) => {
100 tracing::info!("removed {} model: {}", model_type, model_name);
101 }
102 Err(e) => {
103 tracing::error!("error removing model: {}", e);
104 }
105 },
106 }
107 }
108}
109
110async fn handle_delete(
111 kv: &KeyValue,
112 state: Arc<ModelWatchState>,
113) -> anyhow::Result<(&str, ModelType)> {
114 let key = kv.key_str()?;
115 tracing::debug!(key, "removing model");
116
117 let model_name = key.trim_start_matches(&state.prefix);
118
119 match state.model_type {
120 ModelType::Chat => state.manager.remove_chat_completions_model(model_name)?,
121 ModelType::Completion => state.manager.remove_completions_model(model_name)?,
122 };
123
124 Ok((model_name, state.model_type))
125}
126
127async fn handle_put(
132 model_entry: ModelEntry,
133 state: Arc<ModelWatchState>,
134) -> anyhow::Result<(String, ModelType)> {
135 if model_entry.model_type != state.model_type {
136 raise!(
137 "model type mismatch: {} != {}",
138 model_entry.model_type,
139 state.model_type
140 );
141 }
142
143 match state.model_type {
144 ModelType::Chat => {
145 let client = state
146 .drt
147 .namespace(model_entry.endpoint.namespace)?
148 .component(model_entry.endpoint.component)?
149 .endpoint(model_entry.endpoint.name)
150 .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
151 .await?;
152 state
153 .manager
154 .add_chat_completions_model(&model_entry.name, Arc::new(client))?;
155 }
156 ModelType::Completion => {
157 let client = state
158 .drt
159 .namespace(model_entry.endpoint.namespace)?
160 .component(model_entry.endpoint.component)?
161 .endpoint(model_entry.endpoint.name)
162 .client::<CompletionRequest, Annotated<CompletionResponse>>()
163 .await?;
164 state
165 .manager
166 .add_completions_model(&model_entry.name, Arc::new(client))?;
167 }
168 }
169
170 Ok((model_entry.name, state.model_type))
171}