1use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7
8use anyhow::Context as _;
9use dynamo_runtime::protocols::EndpointId;
10use dynamo_runtime::slug::Slug;
11use dynamo_runtime::traits::DistributedRuntimeProvider;
12use dynamo_runtime::{
13 component::Endpoint,
14 storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
15};
16
17use crate::discovery::ModelEntry;
18use crate::entrypoint::RouterConfig;
19use crate::mocker::protocols::MockEngineArgs;
20use crate::model_card::{self, ModelDeploymentCard};
21use crate::model_type::{ModelInput, ModelType};
22use crate::request_template::RequestTemplate;
23
24mod network_name;
25pub use network_name::ModelNetworkName;
26pub mod runtime_config;
27
28use runtime_config::ModelRuntimeConfig;
29
/// URL-style prefix that marks a model path as a hub repository rather than a
/// local path. `LocalModelBuilder::build` strips it before resolving the model.
const HF_SCHEME: &str = "hf://";

/// Model name used by `LocalModelBuilder::build` when neither a model path nor
/// a model name was supplied.
const DEFAULT_NAME: &str = "dynamo";

/// KV cache block size applied when the caller does not provide one.
const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16;

/// HTTP port a `LocalModelBuilder` starts with unless `http_port` is called.
pub const DEFAULT_HTTP_PORT: u16 = 8080;
43
/// Collects the settings needed to produce a [`LocalModel`].
///
/// Configure it through the chained setter methods, then call `build`.
/// `build` moves several of the optional fields out of the builder, so a
/// builder is not generally reusable after a build.
pub struct LocalModelBuilder {
    /// Local path or `hf://`-prefixed repository; `None` yields a name-only model.
    model_path: Option<PathBuf>,
    /// Explicit model name; when absent `build` derives one from the path.
    model_name: Option<String>,
    /// Location the deployment card is loaded from instead of the resolved model path.
    model_config: Option<PathBuf>,
    /// Endpoint to associate with the model; a random internal one is generated if `None`.
    endpoint_id: Option<EndpointId>,
    /// Overrides the context length stored on the loaded card when set.
    context_length: Option<u32>,
    /// File handed to `RequestTemplate::load` during `build`.
    template_file: Option<PathBuf>,
    /// Router settings; `build` falls back to `RouterConfig::default()`.
    router_config: Option<RouterConfig>,
    /// Value written to the card's `kv_cache_block_size`.
    kv_cache_block_size: u32,
    /// HTTP host passed through unchanged to the built model.
    http_host: Option<String>,
    /// HTTP port passed through unchanged to the built model.
    http_port: u16,
    /// TLS certificate path passed through unchanged to the built model.
    tls_cert_path: Option<PathBuf>,
    /// TLS private key path passed through unchanged to the built model.
    tls_key_path: Option<PathBuf>,
    /// Value written to the card's `migration_limit`.
    migration_limit: u32,
    /// Forwarded to `hub::from_hf` and gates the `extra_engine_args` override.
    is_mocker: bool,
    /// JSON file read with `MockEngineArgs::from_json_file` when `is_mocker` is set.
    extra_engine_args: Option<PathBuf>,
    /// Runtime configuration copied onto both the card and the built model.
    runtime_config: ModelRuntimeConfig,
    /// Arbitrary JSON moved onto the card's `user_data`.
    user_data: Option<serde_json::Value>,
    /// Passed to `ModelDeploymentCard::load_from_disk` as the custom template path.
    custom_template_path: Option<PathBuf>,
    /// Namespace copied onto the built model.
    namespace: Option<String>,
}
65
66impl Default for LocalModelBuilder {
67 fn default() -> Self {
68 LocalModelBuilder {
69 kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
70 http_host: Default::default(),
71 http_port: DEFAULT_HTTP_PORT,
72 tls_cert_path: Default::default(),
73 tls_key_path: Default::default(),
74 model_path: Default::default(),
75 model_name: Default::default(),
76 model_config: Default::default(),
77 endpoint_id: Default::default(),
78 context_length: Default::default(),
79 template_file: Default::default(),
80 router_config: Default::default(),
81 migration_limit: Default::default(),
82 is_mocker: Default::default(),
83 extra_engine_args: Default::default(),
84 runtime_config: Default::default(),
85 user_data: Default::default(),
86 custom_template_path: Default::default(),
87 namespace: Default::default(),
88 }
89 }
90}
91
92impl LocalModelBuilder {
93 pub fn model_path(&mut self, model_path: Option<PathBuf>) -> &mut Self {
94 self.model_path = model_path;
95 self
96 }
97
98 pub fn model_name(&mut self, model_name: Option<String>) -> &mut Self {
99 self.model_name = model_name;
100 self
101 }
102
103 pub fn model_config(&mut self, model_config: Option<PathBuf>) -> &mut Self {
104 self.model_config = model_config;
105 self
106 }
107
108 pub fn endpoint_id(&mut self, endpoint_id: Option<EndpointId>) -> &mut Self {
109 self.endpoint_id = endpoint_id;
110 self
111 }
112
113 pub fn context_length(&mut self, context_length: Option<u32>) -> &mut Self {
114 self.context_length = context_length;
115 self
116 }
117
118 pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option<u32>) -> &mut Self {
120 self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE);
121 self
122 }
123
124 pub fn http_host(&mut self, host: Option<String>) -> &mut Self {
125 self.http_host = host;
126 self
127 }
128
129 pub fn http_port(&mut self, port: u16) -> &mut Self {
130 self.http_port = port;
131 self
132 }
133
134 pub fn tls_cert_path(&mut self, p: Option<PathBuf>) -> &mut Self {
135 self.tls_cert_path = p;
136 self
137 }
138
139 pub fn tls_key_path(&mut self, p: Option<PathBuf>) -> &mut Self {
140 self.tls_key_path = p;
141 self
142 }
143
144 pub fn router_config(&mut self, router_config: Option<RouterConfig>) -> &mut Self {
145 self.router_config = router_config;
146 self
147 }
148
149 pub fn namespace(&mut self, namespace: Option<String>) -> &mut Self {
150 self.namespace = namespace;
151 self
152 }
153
154 pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
155 self.template_file = template_file;
156 self
157 }
158
159 pub fn custom_template_path(&mut self, custom_template_path: Option<PathBuf>) -> &mut Self {
160 self.custom_template_path = custom_template_path;
161 self
162 }
163
164 pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
165 self.migration_limit = migration_limit.unwrap_or(0);
166 self
167 }
168
169 pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
170 self.is_mocker = is_mocker;
171 self
172 }
173
174 pub fn extra_engine_args(&mut self, extra_engine_args: Option<PathBuf>) -> &mut Self {
175 self.extra_engine_args = extra_engine_args;
176 self
177 }
178
179 pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self {
180 self.runtime_config = runtime_config;
181 self
182 }
183
184 pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
185 self.user_data = user_data;
186 self
187 }
188
189 pub async fn build(&mut self) -> anyhow::Result<LocalModel> {
200 let endpoint_id = self
203 .endpoint_id
204 .take()
205 .unwrap_or_else(|| internal_endpoint("local_model"));
206
207 let template = self
208 .template_file
209 .as_deref()
210 .map(RequestTemplate::load)
211 .transpose()?;
212
213 if self.model_path.is_none() {
215 let mut card = ModelDeploymentCard::with_name_only(
216 self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
217 );
218 card.migration_limit = self.migration_limit;
219 card.user_data = self.user_data.take();
220 card.runtime_config = self.runtime_config.clone();
221
222 return Ok(LocalModel {
223 card,
224 full_path: PathBuf::new(),
225 endpoint_id,
226 template,
227 http_host: self.http_host.take(),
228 http_port: self.http_port,
229 tls_cert_path: self.tls_cert_path.take(),
230 tls_key_path: self.tls_key_path.take(),
231 router_config: self.router_config.take().unwrap_or_default(),
232 runtime_config: self.runtime_config.clone(),
233 namespace: self.namespace.clone(),
234 });
235 }
236
237 let model_path = self.model_path.take().unwrap();
239 let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?;
240
241 let is_hf_repo =
244 model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false);
245 let relative_path = model_path.trim_start_matches(HF_SCHEME);
246 let full_path = if is_hf_repo {
247 super::hub::from_hf(relative_path, self.is_mocker).await?
249 } else {
250 fs::canonicalize(relative_path)?
251 };
252 let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
254
255 let mut card = ModelDeploymentCard::load_from_disk(
256 model_config_path,
257 self.custom_template_path.as_deref(),
258 )?;
259
260 let model_name = self.model_name.take().unwrap_or_else(|| {
262 if is_hf_repo {
263 relative_path.to_string()
265 } else {
266 full_path
267 .iter()
268 .next_back()
269 .map(|n| n.to_string_lossy().into_owned())
270 .unwrap_or_else(|| {
271 panic!("Invalid model path, too short: '{}'", full_path.display())
273 })
274 }
275 });
276 card.set_name(&model_name);
277
278 card.kv_cache_block_size = self.kv_cache_block_size;
279
280 if let Some(context_length) = self.context_length {
282 card.context_length = context_length;
283 }
284
285 if self.is_mocker
287 && let Some(path) = &self.extra_engine_args
288 {
289 let mocker_engine_args = MockEngineArgs::from_json_file(path)
290 .expect("Failed to load mocker engine args for runtime config overriding.");
291 self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
292 self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
293 self.runtime_config.max_num_batched_tokens =
294 mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
295 }
296
297 card.migration_limit = self.migration_limit;
298 card.user_data = self.user_data.take();
299 card.runtime_config = self.runtime_config.clone();
300
301 Ok(LocalModel {
302 card,
303 full_path,
304 endpoint_id,
305 template,
306 http_host: self.http_host.take(),
307 http_port: self.http_port,
308 tls_cert_path: self.tls_cert_path.take(),
309 tls_key_path: self.tls_key_path.take(),
310 router_config: self.router_config.take().unwrap_or_default(),
311 runtime_config: self.runtime_config.clone(),
312 namespace: self.namespace.clone(),
313 })
314 }
315}
316
/// A model resolved by [`LocalModelBuilder`], bundling its deployment card
/// with the endpoint, HTTP/TLS, router and runtime settings it was built with.
#[derive(Debug, Clone)]
pub struct LocalModel {
    /// Resolved model location; empty when the model was built without a path.
    full_path: PathBuf,
    /// Deployment card describing the model.
    card: ModelDeploymentCard,
    /// Endpoint associated with the model at build time.
    endpoint_id: EndpointId,
    /// Request template loaded at build time, if one was configured.
    template: Option<RequestTemplate>,
    /// HTTP host configured on the builder.
    http_host: Option<String>,
    /// HTTP port configured on the builder.
    http_port: u16,
    /// TLS certificate path configured on the builder.
    tls_cert_path: Option<PathBuf>,
    /// TLS private key path configured on the builder.
    tls_key_path: Option<PathBuf>,
    /// Router configuration, defaulted when the builder had none.
    router_config: RouterConfig,
    /// Runtime configuration; the same value is also stored on `card`.
    runtime_config: ModelRuntimeConfig,
    /// Namespace configured on the builder.
    namespace: Option<String>,
}
331
impl LocalModel {
    /// Borrows the model's deployment card.
    pub fn card(&self) -> &ModelDeploymentCard {
        &self.card
    }

    /// Borrows the resolved model path (empty for a name-only model).
    pub fn path(&self) -> &Path {
        &self.full_path
    }

    /// Returns the card's `display_name`.
    pub fn display_name(&self) -> &str {
        &self.card.display_name
    }

    /// Returns the card's slug as a string slice.
    pub fn service_name(&self) -> &str {
        self.card.slug().as_ref()
    }

    /// Returns a clone of the request template, if one was loaded.
    pub fn request_template(&self) -> Option<RequestTemplate> {
        self.template.clone()
    }

    /// Returns a clone of the configured HTTP host.
    pub fn http_host(&self) -> Option<String> {
        self.http_host.clone()
    }

    /// Returns the configured HTTP port.
    pub fn http_port(&self) -> u16 {
        self.http_port
    }

    /// Borrows the configured TLS certificate path.
    pub fn tls_cert_path(&self) -> Option<&Path> {
        self.tls_cert_path.as_deref()
    }

    /// Borrows the configured TLS private key path.
    pub fn tls_key_path(&self) -> Option<&Path> {
        self.tls_key_path.as_deref()
    }

    /// Borrows the router configuration.
    pub fn router_config(&self) -> &RouterConfig {
        &self.router_config
    }

    /// Borrows the runtime configuration.
    pub fn runtime_config(&self) -> &ModelRuntimeConfig {
        &self.runtime_config
    }

    /// Borrows the configured namespace.
    pub fn namespace(&self) -> Option<&str> {
        self.namespace.as_deref()
    }

    /// Reports whether the resolved model path is a regular file on disk.
    ///
    /// NOTE(review): the name implies a single file is taken to be a GGUF
    /// model and a directory is not; the file contents are not inspected here.
    pub fn is_gguf(&self) -> bool {
        self.full_path.is_file()
    }

    /// Borrows the endpoint associated with the model at build time.
    pub fn endpoint_id(&self) -> &EndpointId {
        &self.endpoint_id
    }

    /// Consumes the model and returns its deployment card.
    pub fn into_card(self) -> ModelDeploymentCard {
        self.card
    }

    /// Publishes this model so it is associated with `endpoint`.
    ///
    /// In order, this:
    /// 1. hands the card to `move_to_nats` with the runtime's NATS client,
    /// 2. publishes the card to an etcd-backed key-value store under
    ///    `model_card::ROOT_PATH`, keyed by the card's slug,
    /// 3. creates an etcd entry under a fresh `ModelNetworkName` holding a
    ///    JSON-serialized `ModelEntry` for the endpoint.
    ///
    /// # Errors
    ///
    /// Fails with "Cannot attach to static endpoint" when the endpoint's
    /// runtime has no etcd client, and propagates any error from the NATS
    /// move, the card publish, JSON serialization or the etcd create.
    pub async fn attach(
        &mut self,
        endpoint: &Endpoint,
        model_type: ModelType,
        model_input: ModelInput,
    ) -> anyhow::Result<()> {
        // Registration needs etcd; without it there is nothing to attach to.
        let Some(etcd_client) = endpoint.drt().etcd_client() else {
            anyhow::bail!("Cannot attach to static endpoint");
        };

        let nats_client = endpoint.drt().nats_client();
        self.card.move_to_nats(nats_client.clone()).await?;

        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
        let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
        // Computed up front so the card can be borrowed mutably by `publish`.
        let key = self.card.slug().to_string();

        card_store
            .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
            .await?;

        let network_name = ModelNetworkName::new();
        tracing::debug!("Registering with etcd as {network_name}");
        let model_registration = ModelEntry {
            name: self.display_name().to_string(),
            endpoint_id: endpoint.id(),
            model_type,
            runtime_config: Some(self.runtime_config.clone()),
            model_input,
        };
        etcd_client
            .kv_create(
                &network_name,
                serde_json::to_vec_pretty(&model_registration)?,
                // NOTE(review): `None` presumably selects the client's default
                // lease; confirm against `kv_create`.
                None,
            )
            .await
    }
}
447
448fn internal_endpoint(engine: &str) -> EndpointId {
451 EndpointId {
452 namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(),
453 component: engine.to_string(),
454 name: "generate".to_string(),
455 }
456}