dynamo_llm/
local_model.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7
8use anyhow::Context as _;
9use dynamo_runtime::protocols::EndpointId;
10use dynamo_runtime::slug::Slug;
11use dynamo_runtime::traits::DistributedRuntimeProvider;
12use dynamo_runtime::{
13    component::Endpoint,
14    storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
15};
16
17use crate::discovery::ModelEntry;
18use crate::entrypoint::RouterConfig;
19use crate::mocker::protocols::MockEngineArgs;
20use crate::model_card::{self, ModelDeploymentCard};
21use crate::model_type::{ModelInput, ModelType};
22use crate::request_template::RequestTemplate;
23
24mod network_name;
25pub use network_name::ModelNetworkName;
26pub mod runtime_config;
27
28use runtime_config::ModelRuntimeConfig;
29
/// URI scheme that explicitly marks a model path as a Hugging Face repository.
const HF_SCHEME: &str = "hf://";

/// Name given to a model when the user supplied neither a path nor a name.
/// In that situation the name is usually invisible to the user, for example
/// in a text chat.
const DEFAULT_NAME: &str = "dynamo";

/// KV cache block size used when none is configured. Engines don't usually
/// provide a default, so we do.
const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16;

/// HTTP port used when none is configured. It cannot default to 0, so we pick one.
/// This is `pub` so that the bindings can use the same value for consistency.
pub const DEFAULT_HTTP_PORT: u16 = 8080;
43
44pub struct LocalModelBuilder {
45    model_path: Option<PathBuf>,
46    model_name: Option<String>,
47    model_config: Option<PathBuf>,
48    endpoint_id: Option<EndpointId>,
49    context_length: Option<u32>,
50    template_file: Option<PathBuf>,
51    router_config: Option<RouterConfig>,
52    kv_cache_block_size: u32,
53    http_host: Option<String>,
54    http_port: u16,
55    tls_cert_path: Option<PathBuf>,
56    tls_key_path: Option<PathBuf>,
57    migration_limit: u32,
58    is_mocker: bool,
59    extra_engine_args: Option<PathBuf>,
60    runtime_config: ModelRuntimeConfig,
61    user_data: Option<serde_json::Value>,
62    custom_template_path: Option<PathBuf>,
63    namespace: Option<String>,
64}
65
66impl Default for LocalModelBuilder {
67    fn default() -> Self {
68        LocalModelBuilder {
69            kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
70            http_host: Default::default(),
71            http_port: DEFAULT_HTTP_PORT,
72            tls_cert_path: Default::default(),
73            tls_key_path: Default::default(),
74            model_path: Default::default(),
75            model_name: Default::default(),
76            model_config: Default::default(),
77            endpoint_id: Default::default(),
78            context_length: Default::default(),
79            template_file: Default::default(),
80            router_config: Default::default(),
81            migration_limit: Default::default(),
82            is_mocker: Default::default(),
83            extra_engine_args: Default::default(),
84            runtime_config: Default::default(),
85            user_data: Default::default(),
86            custom_template_path: Default::default(),
87            namespace: Default::default(),
88        }
89    }
90}
91
92impl LocalModelBuilder {
93    pub fn model_path(&mut self, model_path: Option<PathBuf>) -> &mut Self {
94        self.model_path = model_path;
95        self
96    }
97
98    pub fn model_name(&mut self, model_name: Option<String>) -> &mut Self {
99        self.model_name = model_name;
100        self
101    }
102
103    pub fn model_config(&mut self, model_config: Option<PathBuf>) -> &mut Self {
104        self.model_config = model_config;
105        self
106    }
107
108    pub fn endpoint_id(&mut self, endpoint_id: Option<EndpointId>) -> &mut Self {
109        self.endpoint_id = endpoint_id;
110        self
111    }
112
113    pub fn context_length(&mut self, context_length: Option<u32>) -> &mut Self {
114        self.context_length = context_length;
115        self
116    }
117
118    /// Passing None resets it to default
119    pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option<u32>) -> &mut Self {
120        self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE);
121        self
122    }
123
124    pub fn http_host(&mut self, host: Option<String>) -> &mut Self {
125        self.http_host = host;
126        self
127    }
128
129    pub fn http_port(&mut self, port: u16) -> &mut Self {
130        self.http_port = port;
131        self
132    }
133
134    pub fn tls_cert_path(&mut self, p: Option<PathBuf>) -> &mut Self {
135        self.tls_cert_path = p;
136        self
137    }
138
139    pub fn tls_key_path(&mut self, p: Option<PathBuf>) -> &mut Self {
140        self.tls_key_path = p;
141        self
142    }
143
144    pub fn router_config(&mut self, router_config: Option<RouterConfig>) -> &mut Self {
145        self.router_config = router_config;
146        self
147    }
148
149    pub fn namespace(&mut self, namespace: Option<String>) -> &mut Self {
150        self.namespace = namespace;
151        self
152    }
153
154    pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
155        self.template_file = template_file;
156        self
157    }
158
159    pub fn custom_template_path(&mut self, custom_template_path: Option<PathBuf>) -> &mut Self {
160        self.custom_template_path = custom_template_path;
161        self
162    }
163
164    pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
165        self.migration_limit = migration_limit.unwrap_or(0);
166        self
167    }
168
169    pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
170        self.is_mocker = is_mocker;
171        self
172    }
173
174    pub fn extra_engine_args(&mut self, extra_engine_args: Option<PathBuf>) -> &mut Self {
175        self.extra_engine_args = extra_engine_args;
176        self
177    }
178
179    pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self {
180        self.runtime_config = runtime_config;
181        self
182    }
183
184    pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
185        self.user_data = user_data;
186        self
187    }
188
189    /// Make an LLM ready for use:
190    /// - Download it from Hugging Face (and NGC in future) if necessary
191    /// - Resolve the path
192    /// - Load it's ModelDeploymentCard card
193    /// - Name it correctly
194    ///
195    /// The model name will depend on what "model_path" is:
196    /// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct"
197    /// - A file: The GGUF filename: "/data/llms/Qwen2.5-3B-Instruct-Q6_K.gguf" -> "Qwen2.5-3B-Instruct-Q6_K.gguf"
198    /// - An HF repo: The HF repo name: "Qwen/Qwen3-0.6B" stays the same
199    pub async fn build(&mut self) -> anyhow::Result<LocalModel> {
200        // Generate an endpoint ID for this model if the user didn't provide one.
201        // The user only provides one if exposing the model.
202        let endpoint_id = self
203            .endpoint_id
204            .take()
205            .unwrap_or_else(|| internal_endpoint("local_model"));
206
207        let template = self
208            .template_file
209            .as_deref()
210            .map(RequestTemplate::load)
211            .transpose()?;
212
213        // echo engine doesn't need a path. It's an edge case, move it out of the way.
214        if self.model_path.is_none() {
215            let mut card = ModelDeploymentCard::with_name_only(
216                self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
217            );
218            card.migration_limit = self.migration_limit;
219            card.user_data = self.user_data.take();
220            card.runtime_config = self.runtime_config.clone();
221
222            return Ok(LocalModel {
223                card,
224                full_path: PathBuf::new(),
225                endpoint_id,
226                template,
227                http_host: self.http_host.take(),
228                http_port: self.http_port,
229                tls_cert_path: self.tls_cert_path.take(),
230                tls_key_path: self.tls_key_path.take(),
231                router_config: self.router_config.take().unwrap_or_default(),
232                runtime_config: self.runtime_config.clone(),
233                namespace: self.namespace.clone(),
234            });
235        }
236
237        // Main logic. We are running a model.
238        let model_path = self.model_path.take().unwrap();
239        let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?;
240
241        // Check for hf:// prefix first, in case we really want an HF repo but it conflicts
242        // with a relative path.
243        let is_hf_repo =
244            model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false);
245        let relative_path = model_path.trim_start_matches(HF_SCHEME);
246        let full_path = if is_hf_repo {
247            // HF download if necessary
248            super::hub::from_hf(relative_path, self.is_mocker).await?
249        } else {
250            fs::canonicalize(relative_path)?
251        };
252        // --model-config takes precedence over --model-path
253        let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
254
255        let mut card = ModelDeploymentCard::load_from_disk(
256            model_config_path,
257            self.custom_template_path.as_deref(),
258        )?;
259
260        // Usually we infer from the path, self.model_name is user override
261        let model_name = self.model_name.take().unwrap_or_else(|| {
262            if is_hf_repo {
263                // HF repos use their full name ("org/name") not the folder name
264                relative_path.to_string()
265            } else {
266                full_path
267                    .iter()
268                    .next_back()
269                    .map(|n| n.to_string_lossy().into_owned())
270                    .unwrap_or_else(|| {
271                        // Panic because we can't do anything without a model
272                        panic!("Invalid model path, too short: '{}'", full_path.display())
273                    })
274            }
275        });
276        card.set_name(&model_name);
277
278        card.kv_cache_block_size = self.kv_cache_block_size;
279
280        // Override max number of tokens in context. We usually only do this to limit kv cache allocation.
281        if let Some(context_length) = self.context_length {
282            card.context_length = context_length;
283        }
284
285        // Override runtime configs with mocker engine args
286        if self.is_mocker
287            && let Some(path) = &self.extra_engine_args
288        {
289            let mocker_engine_args = MockEngineArgs::from_json_file(path)
290                .expect("Failed to load mocker engine args for runtime config overriding.");
291            self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
292            self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
293            self.runtime_config.max_num_batched_tokens =
294                mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
295        }
296
297        card.migration_limit = self.migration_limit;
298        card.user_data = self.user_data.take();
299        card.runtime_config = self.runtime_config.clone();
300
301        Ok(LocalModel {
302            card,
303            full_path,
304            endpoint_id,
305            template,
306            http_host: self.http_host.take(),
307            http_port: self.http_port,
308            tls_cert_path: self.tls_cert_path.take(),
309            tls_key_path: self.tls_key_path.take(),
310            router_config: self.router_config.take().unwrap_or_default(),
311            runtime_config: self.runtime_config.clone(),
312            namespace: self.namespace.clone(),
313        })
314    }
315}
316
317#[derive(Debug, Clone)]
318pub struct LocalModel {
319    full_path: PathBuf,
320    card: ModelDeploymentCard,
321    endpoint_id: EndpointId,
322    template: Option<RequestTemplate>,
323    http_host: Option<String>,
324    http_port: u16,
325    tls_cert_path: Option<PathBuf>,
326    tls_key_path: Option<PathBuf>,
327    router_config: RouterConfig,
328    runtime_config: ModelRuntimeConfig,
329    namespace: Option<String>,
330}
331
332impl LocalModel {
333    pub fn card(&self) -> &ModelDeploymentCard {
334        &self.card
335    }
336
337    pub fn path(&self) -> &Path {
338        &self.full_path
339    }
340
341    /// Human friendly model name. This is the correct name.
342    pub fn display_name(&self) -> &str {
343        &self.card.display_name
344    }
345
346    /// The name under which we make this model available over HTTP.
347    /// A slugified version of the model's name, for use in NATS, etcd, etc.
348    pub fn service_name(&self) -> &str {
349        self.card.slug().as_ref()
350    }
351
352    pub fn request_template(&self) -> Option<RequestTemplate> {
353        self.template.clone()
354    }
355
356    pub fn http_host(&self) -> Option<String> {
357        self.http_host.clone()
358    }
359
360    pub fn http_port(&self) -> u16 {
361        self.http_port
362    }
363
364    pub fn tls_cert_path(&self) -> Option<&Path> {
365        self.tls_cert_path.as_deref()
366    }
367
368    pub fn tls_key_path(&self) -> Option<&Path> {
369        self.tls_key_path.as_deref()
370    }
371
372    pub fn router_config(&self) -> &RouterConfig {
373        &self.router_config
374    }
375
376    pub fn runtime_config(&self) -> &ModelRuntimeConfig {
377        &self.runtime_config
378    }
379
380    pub fn namespace(&self) -> Option<&str> {
381        self.namespace.as_deref()
382    }
383
384    pub fn is_gguf(&self) -> bool {
385        // GGUF is the only file (not-folder) we accept, so we don't need to check the extension
386        // We will error when we come to parse it
387        self.full_path.is_file()
388    }
389
390    /// An endpoint to identify this model by.
391    pub fn endpoint_id(&self) -> &EndpointId {
392        &self.endpoint_id
393    }
394
395    /// Drop the LocalModel returning it's ModelDeploymentCard.
396    /// For the case where we only need the card and don't want to clone it.
397    pub fn into_card(self) -> ModelDeploymentCard {
398        self.card
399    }
400
401    /// Attach this model the endpoint. This registers it on the network
402    /// allowing ingress to discover it.
403    pub async fn attach(
404        &mut self,
405        endpoint: &Endpoint,
406        model_type: ModelType,
407        model_input: ModelInput,
408    ) -> anyhow::Result<()> {
409        // A static component doesn't have an etcd_client because it doesn't need to register
410        let Some(etcd_client) = endpoint.drt().etcd_client() else {
411            anyhow::bail!("Cannot attach to static endpoint");
412        };
413
414        // Store model config files in NATS object store
415        let nats_client = endpoint.drt().nats_client();
416        self.card.move_to_nats(nats_client.clone()).await?;
417
418        // Publish the Model Deployment Card to etcd
419        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
420        let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
421        let key = self.card.slug().to_string();
422
423        card_store
424            .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
425            .await?;
426
427        // Publish our ModelEntry to etcd. This allows ingress to find the model card.
428        // (Why don't we put the model card directly under this key?)
429        let network_name = ModelNetworkName::new();
430        tracing::debug!("Registering with etcd as {network_name}");
431        let model_registration = ModelEntry {
432            name: self.display_name().to_string(),
433            endpoint_id: endpoint.id(),
434            model_type,
435            runtime_config: Some(self.runtime_config.clone()),
436            model_input,
437        };
438        etcd_client
439            .kv_create(
440                &network_name,
441                serde_json::to_vec_pretty(&model_registration)?,
442                None, // use primary lease
443            )
444            .await
445    }
446}
447
448/// A random endpoint to use for internal communication
449/// We can't hard code because we may be running several on the same machine (GPUs 0-3 and 4-7)
450fn internal_endpoint(engine: &str) -> EndpointId {
451    EndpointId {
452        namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(),
453        component: engine.to_string(),
454        name: "generate".to_string(),
455    }
456}