dynamo_llm/
model_card.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Model Deployment Card
5//!
6//! The ModelDeploymentCard (MDC) is the primary model configuration structure that will be available to any
7//! component that needs to interact with the model or its dependent artifacts.
8//!
9//! The ModelDeploymentCard contains LLM model deployment configuration information:
10//! - Display name and service name for the model
11//! - Model information (ModelInfoType)
12//! - Tokenizer configuration (TokenizerKind)
13//! - Prompt formatter settings (PromptFormatterArtifact)
14//! - Various metadata like revision, publish time, etc.
15
16use std::fmt;
17use std::fs::File;
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::common::checked_file::CheckedFile;
23use crate::local_model::runtime_config::ModelRuntimeConfig;
24use anyhow::{Context, Result};
25use derive_builder::Builder;
26use dynamo_runtime::DistributedRuntime;
27use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
28use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
29use serde::{Deserialize, Serialize};
30use tokenizers::Tokenizer as HfTokenizer;
31
32use crate::gguf::{Content, ContentConfig, ModelConfigLike};
33use crate::protocols::TokenIdType;
34
35/// Identify model deployment cards in the key-value store
36pub const ROOT_PATH: &str = "mdc";
37
38/// If a model deployment card hasn't been refreshed in this much time the worker is likely gone
39const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
40
/// Where the model's configuration (config.json equivalent) comes from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
    /// A HuggingFace-style config.json, tracked with a checksum (local path or URL)
    HfConfigJson(CheckedFile),
    /// A GGUF file on disk; the config is embedded in its metadata
    GGUF(PathBuf),
}
47
/// Where the model's tokenizer comes from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
    /// A HuggingFace tokenizer.json, tracked with a checksum (local path or URL)
    HfTokenizerJson(CheckedFile),
    /// A tokenizer already converted out of GGUF metadata (see `TokenizerKind::from_gguf`)
    GGUF(Box<HfTokenizer>),
}
54
/// Supported types of prompt formatters.
///
/// We need a way to associate the prompt formatter template definition with an associated
/// data model which is expected for rendering.
///
/// All current prompt formatters are Jinja2 templates which use the OpenAI ChatCompletionRequest
/// format. However, we currently do not have a discovery path to know if the model supports tool use
/// unless we inspect the template.
///
/// TODO(): Add an enum for the PromptFormatDataModel with at minimum arms for:
/// - OaiChat
/// - OaiChatToolUse
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
    /// tokenizer_config.json, which may embed a chat template
    HfTokenizerConfigJson(CheckedFile),
    /// A stand-alone chat template file (e.g. chat_template.jinja)
    HfChatTemplate(CheckedFile),
    /// Template embedded in a GGUF file's metadata
    GGUF(PathBuf),
}
74
/// Extra context made available to the prompt template during rendering.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
    /// Support OAI Chat Messages and Tools
    OaiChat,

    /// Enables templates with `{{datetime}}` to be rendered with the current date and time.
    Llama3DateTime,
}
84
/// Where the model's default sampling parameters come from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    /// A HuggingFace-style generation_config.json, tracked with a checksum
    HfGenerationConfigJson(CheckedFile),
    /// A GGUF file on disk
    GGUF(PathBuf),
}
91
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
    /// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
    pub display_name: String,

    // Cache the Slugified display_name so we can share references to it.
    // Kept in sync with display_name by `with_name_only` / `set_name`.
    slug: Slug,

    /// Model information
    pub model_info: Option<ModelInfoType>,

    /// Tokenizer configuration
    pub tokenizer: Option<TokenizerKind>,

    /// Prompt Formatter configuration
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_formatter: Option<PromptFormatterArtifact>,

    /// chat template may be stored as a separate file instead of in `prompt_formatter`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template_file: Option<PromptFormatterArtifact>,

    /// Generation config - default sampling params
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gen_config: Option<GenerationConfig>,

    /// Prompt Formatter Config
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_context: Option<Vec<PromptContextMixin>>,

    /// When this card was last advertised by a worker. None if not yet published.
    pub last_published: Option<chrono::DateTime<chrono::Utc>>,

    /// Incrementing count of how many times we published this card
    #[serde(default, skip_serializing)]
    pub revision: u64,

    /// Max context (in number of tokens) this model can handle
    pub context_length: u32,

    /// Size of a KV cache block - vllm only currently
    /// Passed to the engine and the KV router.
    pub kv_cache_block_size: u32,

    /// How many times a request can be migrated to another worker if the HTTP server lost
    /// connection to the current worker.
    pub migration_limit: u32,

    /// User-defined metadata for custom worker behavior
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_data: Option<serde_json::Value>,

    /// Runtime configuration for the model (see `ModelRuntimeConfig`)
    #[serde(default)]
    pub runtime_config: ModelRuntimeConfig,

    // Keeps files downloaded from the NATS object store alive; the TempDir is
    // deleted on drop (see `load_from_store`). Never serialized.
    #[serde(skip)]
    cache_dir: Option<Arc<tempfile::TempDir>>,
}
150
151impl ModelDeploymentCard {
    /// Builder for constructing a card field-by-field (via `derive_builder`).
    pub fn builder() -> ModelDeploymentCardBuilder {
        ModelDeploymentCardBuilder::default()
    }

    /// Create a ModelDeploymentCard where only the name is filled in.
    ///
    /// Single-process setups don't need an MDC to communicate model details, but it
    /// simplifies the code to assume we always have one. This is how you get one in those
    /// cases. A quasi-null object: <https://en.wikipedia.org/wiki/Null_object_pattern>
    pub fn with_name_only(name: &str) -> ModelDeploymentCard {
        ModelDeploymentCard {
            display_name: name.to_string(),
            slug: Slug::from_string(name),
            ..Default::default()
        }
    }

    /// How often we should check if a model deployment card expired because its workers are gone
    pub fn expiry_check_period() -> Duration {
        // Poll several times within the expiry window so we notice promptly
        match CARD_MAX_AGE.to_std() {
            Ok(duration) => duration / 3,
            Err(_) => {
                // Only happens if CARD_MAX_AGE is negative, which it isn't
                unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE");
            }
        }
    }

    /// Load a model deployment card from a JSON file
    ///
    /// On parse failure the malformed JSON is logged with location context before
    /// the error is returned.
    pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
        let contents = std::fs::read_to_string(&file)?;
        Ok(serde_json::from_str(&contents).inspect_err(|err| {
            crate::log_json_err(&file.as_ref().display().to_string(), &contents, err)
        })?)
    }

    /// Load a model deployment card from a JSON string
    pub fn load_from_json_str(contents: &str) -> Result<Self, anyhow::Error> {
        Ok(serde_json::from_str(contents)
            .inspect_err(|err| crate::log_json_err("unknown", contents, err))?)
    }
193
194    //
195    // Methods
196    //
197
198    /// Save the model deployment card to a JSON file
199    pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
200        std::fs::write(file, self.to_json()?)?;
201        Ok(())
202    }
203
204    pub fn slug(&self) -> &Slug {
205        &self.slug
206    }
207
208    /// Serialize the model deployment card to a JSON string
209    pub fn to_json(&self) -> Result<String, anyhow::Error> {
210        Ok(serde_json::to_string(self)?)
211    }
212
213    pub fn mdcsum(&self) -> String {
214        let json = self.to_json().unwrap();
215        format!("{}", blake3::hash(json.as_bytes()))
216    }
217
218    /// Was this card last published a long time ago, suggesting the worker is gone?
219    pub fn is_expired(&self) -> bool {
220        if let Some(last_published) = self.last_published.as_ref() {
221            chrono::Utc::now() - last_published > CARD_MAX_AGE
222        } else {
223            false
224        }
225    }
226
    /// Is this a full model card with tokenizer?
    /// There are cases where we have a placeholder card (see `with_name_only`).
    pub fn has_tokenizer(&self) -> bool {
        self.tokenizer.is_some()
    }
232
    /// Load the HuggingFace tokenizer for this model.
    ///
    /// # Errors
    /// - if the card has no tokenizer (placeholder card, see `with_name_only`)
    /// - if the tokenizer is still URL-backed (not yet downloaded to disk)
    /// - if tokenizer.json fails to parse (the JSON error is also logged with location)
    pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
        match &self.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
                let p = checked_file.path().ok_or_else(|| {
                    anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
                })?;
                HfTokenizer::from_file(p)
                    .inspect_err(|err| {
                        // If the failure was a JSON syntax error, re-read the file so we
                        // can log the error with file/line context
                        if let Some(serde_err) = err.downcast_ref::<serde_json::Error>()
                            && let Ok(contents) = std::fs::read_to_string(p)
                        {
                            crate::log_json_err(&p.display().to_string(), &contents, serde_err);
                        }
                    })
                    .map_err(anyhow::Error::msg)
                    .with_context(|| p.display().to_string())
            }
            // GGUF-derived tokenizer is already in memory; hand back a copy
            Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
            None => {
                anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
            }
        }
    }
256
257    pub fn is_gguf(&self) -> bool {
258        match &self.model_info {
259            Some(info) => info.is_gguf(),
260            None => false,
261        }
262    }
263
    /// Move the files this MDC uses into the NATS object store.
    /// Updates the URI's to point to NATS.
    pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        // One object store bucket per model, named after the slug
        let bucket_name = self.slug().clone();
        tracing::debug!(
            nats_addr,
            %bucket_name,
            "Uploading model deployment card fields to NATS"
        );

        // Upload one optional artifact field. Only fields that currently point at a
        // local file are uploaded; absent or already-URL-backed fields are skipped.
        macro_rules! nats_upload {
            ($field:expr, $enum_variant:path, $filename:literal) => {
                if let Some($enum_variant(src_file)) = $field.as_mut()
                    && let Some(path) = src_file.path()
                {
                    let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
                    let dest = url::Url::parse(&target)?;
                    nats_client.object_store_upload(path, &dest).await?;
                    // Rewrite the field in-place to reference the uploaded object
                    src_file.move_to_url(dest);
                }
            };
        }

        nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
        nats_upload!(
            self.gen_config,
            GenerationConfig::HfGenerationConfigJson,
            "generation_config.json"
        );
        nats_upload!(
            self.prompt_formatter,
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
        nats_upload!(
            self.chat_template_file,
            PromptFormatterArtifact::HfChatTemplate,
            "chat_template.jinja"
        );
        nats_upload!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
            "tokenizer.json"
        );

        Ok(())
    }
312
    /// Move the files this MDC uses from the NATS object store to local disk.
    /// Updates the URI's to point to the created files.
    ///
    /// The returned TempDir must be kept alive, it cleans up on drop.
    async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        // All downloaded artifacts land in one temp dir named after the slug
        let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
        tracing::debug!(
            nats_addr,
            %bucket_name,
            target_dir = %target_dir.path().display(),
            "Downloading model deployment card fields from NATS"
        );

        // Download one optional artifact field. Only fields that are URL-backed are
        // fetched; absent or already-local fields are skipped. The downloaded file
        // is verified against the recorded checksum before the field is rewritten.
        macro_rules! nats_download {
            ($field:expr, $enum_variant:path, $filename:literal) => {
                if let Some($enum_variant(src_file)) = $field.as_mut()
                    && let Some(src_url) = src_file.url()
                {
                    let target = target_dir.path().join($filename);
                    nats_client.object_store_download(src_url, &target).await?;
                    if !src_file.checksum_matches(&target) {
                        anyhow::bail!(
                            "Invalid {} in NATS for {}, checksum does not match.",
                            $filename,
                            self.display_name
                        );
                    }
                    src_file.move_to_disk(target);
                }
            };
        }

        nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
        nats_download!(
            self.gen_config,
            GenerationConfig::HfGenerationConfigJson,
            "generation_config.json"
        );
        nats_download!(
            self.prompt_formatter,
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
        nats_download!(
            self.chat_template_file,
            PromptFormatterArtifact::HfChatTemplate,
            "chat_template.jinja"
        );
        nats_download!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
            "tokenizer.json"
        );

        Ok(target_dir)
    }
371
    /// Delete this card's uploaded artifact files (its object store bucket) from NATS.
    ///
    /// NOTE(review): despite what the old comment said, this only removes the
    /// object store bucket — it does not touch the key-value store entry.
    /// Confirm the KV delete happens elsewhere.
    pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        tracing::trace!(
            nats_addr,
            %bucket_name,
            "Delete model deployment card from NATS"
        );
        nats_client
            .object_store_delete_bucket(bucket_name.as_ref())
            .await
    }
385
    /// Allow user to override the name we register this model under.
    /// Corresponds to vllm's `--served-model-name`.
    /// Keeps the cached slug in sync with the new display name.
    pub fn set_name(&mut self, name: &str) {
        self.display_name = name.to_string();
        self.slug = Slug::from_string(name);
    }
392
393    /// Build an in-memory ModelDeploymentCard from either:
394    /// - a folder containing config.json, tokenizer.json and token_config.json
395    /// - a GGUF file
396    ///   With an optional custom template
397    pub fn load_from_disk(
398        config_path: impl AsRef<Path>,
399        custom_template_path: Option<&Path>,
400    ) -> anyhow::Result<ModelDeploymentCard> {
401        let config_path = config_path.as_ref();
402        if config_path.is_dir() {
403            Self::from_local_path(config_path, custom_template_path)
404        } else {
405            // GGUF files don't support custom templates yet
406            if custom_template_path.is_some() {
407                anyhow::bail!("Custom templates are not supported for GGUF files");
408            }
409            Self::from_gguf(config_path)
410        }
411    }
412
    /// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use.
    /// Card should be fully local and ready to use when the call returns.
    ///
    /// Returns Ok(None) if no card is stored under `model_slug`.
    pub async fn load_from_store(
        model_slug: &Slug,
        drt: &DistributedRuntime,
    ) -> anyhow::Result<Option<Self>> {
        let Some(etcd_client) = drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
        let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
        let card_store = Arc::new(KeyValueStoreManager::new(store));
        let Some(mut card) = card_store
            .load::<ModelDeploymentCard>(ROOT_PATH, model_slug)
            .await?
        else {
            return Ok(None);
        };
        // cache_dir is a tempfile::TempDir which is deleted on drop, so keep it
        // alive on the card for as long as the card lives.
        card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
        Ok(Some(card))
    }
435
    /// Creates a ModelDeploymentCard from a local directory path.
    ///
    /// Currently HuggingFace format is supported and following files are expected:
    /// - config.json: Model configuration in HuggingFace format
    /// - tokenizer.json: Tokenizer configuration in HuggingFace format
    /// - tokenizer_config.json: Optional prompt formatter configuration
    ///
    /// # Arguments
    /// * `local_root_dir` - Path to the local model directory
    /// * `custom_template_path` - Optional chat template overriding the repo's own
    ///
    /// # Errors
    /// Returns an error if:
    /// - The path doesn't exist or isn't a directory
    /// - The path contains invalid Unicode characters
    /// - Required model files are missing or invalid
    fn from_local_path(
        local_root_dir: impl AsRef<Path>,
        custom_template_path: Option<&Path>,
    ) -> anyhow::Result<Self> {
        let local_root_dir = local_root_dir.as_ref();
        check_valid_local_repo_path(local_root_dir)?;
        // The canonical absolute path doubles as the repo id
        let repo_id = local_root_dir
            .canonicalize()?
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("Path contains invalid Unicode"))?
            .to_string();
        // The directory name becomes the model's display name
        let model_name = local_root_dir
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;

        Self::from_repo(&repo_id, model_name, custom_template_path)
    }
469
    /// Build a card from a single GGUF file. Model info, tokenizer and prompt
    /// formatter are all read from the GGUF itself.
    fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
        // Use the last path component (the file name) as the model name
        let model_name = gguf_file
            .iter()
            .next_back()
            .map(|n| n.to_string_lossy().to_string());
        let Some(model_name) = model_name else {
            // I think this would only happen on an empty path
            anyhow::bail!(
                "Could not extract model name from path '{}'",
                gguf_file.display()
            );
        };

        // TODO: we do this in HFConfig also, unify
        let content = load_gguf(gguf_file)?;
        // NOTE(review): indexing get_metadata() panics if the
        // "<arch>.context_length" key is absent; unwrap_or(0) only covers a
        // failed u32 conversion. Confirm the key is always present.
        let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())]
            .to_u32()
            .unwrap_or(0);
        tracing::debug!(context_length, "Loaded context length from GGUF");

        Ok(Self {
            display_name: model_name.to_string(),
            slug: Slug::from_string(model_name),
            model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())),
            tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
            gen_config: None, // AFAICT there is no equivalent in a GGUF
            prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
            chat_template_file: None,
            prompt_context: None, // TODO - auto-detect prompt context
            revision: 0,
            last_published: None,
            context_length,
            kv_cache_block_size: 0,
            migration_limit: 0,
            user_data: None,
            runtime_config: ModelRuntimeConfig::default(),
            cache_dir: None,
        })
    }
509
    /// Build a card from a local HuggingFace-style repo directory.
    ///
    /// `repo_id` is the (canonicalized) directory path; `model_name` becomes the
    /// display name; `custom_template_path` optionally overrides the repo's own
    /// chat template.
    fn from_repo(
        repo_id: &str,
        model_name: &str,
        custom_template_path: Option<&Path>,
    ) -> anyhow::Result<Self> {
        // This is usually the right choice
        let context_length = crate::file_json_field(
            &PathBuf::from(repo_id).join("config.json"),
            "max_position_embeddings",
        )
        // But sometimes this is
        .or_else(|_| {
            crate::file_json_field(
                &PathBuf::from(repo_id).join("tokenizer_config.json"),
                "model_max_length",
            )
        })
        // If neither of those are present let the engine default it
        .unwrap_or(0);

        // Load chat template - either custom or from repo
        let chat_template_file = if let Some(template_path) = custom_template_path {
            if !template_path.exists() {
                anyhow::bail!(
                    "Custom template file does not exist: {}",
                    template_path.display()
                );
            }

            // Verify the file is readable
            let _template_content = std::fs::read_to_string(template_path).with_context(|| {
                format!(
                    "Failed to read custom template file: {}",
                    template_path.display()
                )
            })?;

            Some(PromptFormatterArtifact::HfChatTemplate(
                CheckedFile::from_disk(template_path)?,
            ))
        } else {
            PromptFormatterArtifact::chat_template_from_repo(repo_id)?
        };

        Ok(Self {
            display_name: model_name.to_string(),
            slug: Slug::from_string(model_name),
            model_info: Some(ModelInfoType::from_repo(repo_id)?),
            tokenizer: Some(TokenizerKind::from_repo(repo_id)?),
            gen_config: GenerationConfig::from_repo(repo_id).ok(), // optional
            prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
            chat_template_file,
            prompt_context: None, // TODO - auto-detect prompt context
            revision: 0,
            last_published: None,
            context_length,
            kv_cache_block_size: 0, // set later
            migration_limit: 0,
            user_data: None,
            runtime_config: ModelRuntimeConfig::default(),
            cache_dir: None,
        })
    }
573}
574
/// Versioning hooks for the key-value store: each publish bumps the revision
/// and refreshes the `last_published` timestamp.
impl Versioned for ModelDeploymentCard {
    fn revision(&self) -> u64 {
        self.revision
    }

    fn set_revision(&mut self, revision: u64) {
        // Setting a new revision doubles as a liveness signal (see `is_expired`)
        self.last_published = Some(chrono::Utc::now());
        self.revision = revision;
    }
}
585
impl fmt::Display for ModelDeploymentCard {
    /// Display as the slugified model name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.slug())
    }
}
/// Read-only access to basic model configuration, whether it was loaded from a
/// HuggingFace config.json or from GGUF metadata.
pub trait ModelInfo: Send + Sync {
    /// Model type
    fn model_type(&self) -> String;

    /// Token ID for the beginning of sequence
    fn bos_token_id(&self) -> TokenIdType;

    /// Token ID for the end of sequence
    fn eos_token_ids(&self) -> Vec<TokenIdType>;

    /// Maximum position embeddings / max sequence length
    /// TODO: This is only used in a single test, no other code. Remove?
    fn max_position_embeddings(&self) -> Option<usize>;

    /// Vocabulary size
    /// TODO: This is only used in a single test, no other code. Remove?
    fn vocab_size(&self) -> Option<usize>;
}
609
impl ModelInfoType {
    /// Load the model configuration this variant points at.
    ///
    /// # Errors
    /// - if an `HfConfigJson` is URL-backed rather than a local file
    /// - if the underlying config fails to load or parse
    pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
        match self {
            Self::HfConfigJson(checked_file) => {
                let Some(path) = checked_file.path() else {
                    anyhow::bail!("model info is not a local path: {checked_file:?}");
                };
                Ok(HFConfig::from_json_file(path)?)
            }
            Self::GGUF(path) => Ok(HFConfig::from_gguf(path)?),
        }
    }

    /// True if the model configuration comes from a GGUF file
    pub fn is_gguf(&self) -> bool {
        matches!(self, Self::GGUF(_))
    }
}
626
/// Subset of a HuggingFace config.json that we care about.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
    /// denotes the mixin to the flattened data model which can be present
    /// in the config.json file
    architectures: Vec<String>,

    /// general model type
    model_type: String,

    // Text fields may be nested here, or flattened at the top level of
    // config.json (see `from_json_file`, which handles both layouts)
    text_config: Option<HFTextConfig>,

    // Sometimes it's inside HFTextConfig, sometimes it's here
    eos_token_id: Option<serde_json::Value>,
}
641
/// Text-model fields of a HuggingFace config.json.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
    // It can take multiple attempts to load this, so Option
    bos_token_id: Option<TokenIdType>,

    // We set this once bos_token_id is loaded so we don't have to deal with Option
    #[serde(default)]
    final_bos_token_id: TokenIdType,

    // Raw value from JSON: may be a single number or an array (see `from_json_file`)
    eos_token_id: Option<serde_json::Value>,

    // Normalized form of eos_token_id, filled in during load
    #[serde(default)]
    final_eos_token_ids: Vec<TokenIdType>,

    /// max sequence length
    max_position_embeddings: Option<usize>,

    /// number of layers in the model
    num_hidden_layers: usize,

    /// number of attention heads in the model
    num_attention_heads: Option<usize>,

    /// Vocabulary size
    vocab_size: Option<usize>,
}
668
impl HFConfig {
    /// Parse a config.json (as JSON5) into model info.
    ///
    /// Handles several real-world layout quirks:
    /// - text fields may be nested under `text_config` or flattened at top level
    /// - `bos_token_id` / `eos_token_id` may live in a sibling
    ///   generation_config.json instead of config.json
    /// - `eos_token_id` may be a single number or an array of numbers
    fn from_json_file<P: AsRef<Path>>(file: P) -> Result<Arc<dyn ModelInfo>> {
        let file_path = file.as_ref();
        let contents = std::fs::read_to_string(file_path)?;
        let mut config: Self = json_five::from_str(&contents)
            .inspect_err(|err| {
                tracing::error!(path=%file_path.display(), %err, "Failed to parse config.json as JSON5");
            })?;
        // Flattened layout: re-parse the same file as the text config itself
        if config.text_config.is_none() {
            let text_config: HFTextConfig = json_five::from_str(&contents)
                .inspect_err(|err| {
                    tracing::error!(path=%file_path.display(), %err, "Failed to parse text config from config.json as JSON5");
                })?;
            config.text_config = Some(text_config);
        }

        // Sometimes bos_token_id is in generation_config.json not config.json
        let Some(text_config) = config.text_config.as_mut() else {
            anyhow::bail!(
                "Missing text config fields (model_type, eos_token_ids, etc) in config.json"
            );
        };

        // generation_config.json sits next to config.json
        let gencfg_path = file_path
            .parent()
            .unwrap_or_else(|| Path::new(""))
            .join("generation_config.json");
        if text_config.bos_token_id.is_none() {
            let bos_token_id = crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id")
                .context(
                    "missing bos_token_id in generation_config.json and config.json, cannot load",
                )?;
            text_config.bos_token_id = Some(bos_token_id);
        }
        // Now that we have it for sure, set it in the non-Option field
        let final_bos_token_id = text_config.bos_token_id.take().unwrap();
        text_config.final_bos_token_id = final_bos_token_id;

        // TODO: refactor this when we switch to per-architecture tokenization
        // Normalize eos_token_id (number or array, top-level or nested) into a Vec
        let final_eos_token_ids: Vec<TokenIdType> = config
            .eos_token_id
            .as_ref()
            .or(text_config.eos_token_id.as_ref())
            .and_then(|v| {
                if v.is_number() {
                    v.as_number()
                        .and_then(|n| n.as_u64())
                        .map(|n| vec![n as TokenIdType])
                } else if v.is_array() {
                    let arr = v.as_array().unwrap(); // Safety: We just checked
                    Some(
                        arr.iter()
                            .filter_map(|inner_v| {
                                inner_v
                                    .as_number()
                                    .and_then(|n| n.as_u64())
                                    .map(|n| n as TokenIdType)
                            })
                            .collect(),
                    )
                } else {
                    tracing::error!(
                        ?v,
                        path = %file_path.display(),
                        "eos_token_id is not a number or an array, cannot use"
                    );
                    None
                }
            })
            .or_else(|| {
                // Maybe it's in generation_config.json
                crate::file_json_field(&gencfg_path, "eos_token_id")
                .inspect_err(
                    |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
                )
                .ok()
            })
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "missing eos_token_id in config.json and generation_config.json, cannot load"
                )
            })?;
        text_config.final_eos_token_ids = final_eos_token_ids;

        Ok(Arc::new(config))
    }

    /// Build model info from GGUF metadata. All fields come from the GGUF's
    /// key-value metadata; `text_config` is always populated here.
    fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
        let content = load_gguf(gguf_file)?;
        let model_config_metadata: ContentConfig = (&content).into();
        let num_hidden_layers =
            content.get_metadata()[&format!("{}.block_count", content.arch())].to_u32()? as usize;

        let bos_token_id = content.get_metadata()["tokenizer.ggml.bos_token_id"].to_u32()?;
        let eos_token_id = content.get_metadata()["tokenizer.ggml.eos_token_id"].to_u32()?;

        // to_vec returns a Vec that's already there, so it's cheap
        let vocab_size = content.get_metadata()["tokenizer.ggml.tokens"]
            .to_vec()?
            .len();

        let arch = content.arch().to_string();
        Ok(Arc::new(HFConfig {
            architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
            // "general.architecture"
            model_type: arch,
            text_config: Some(HFTextConfig {
                bos_token_id: None,
                final_bos_token_id: bos_token_id,

                eos_token_id: None,
                final_eos_token_ids: vec![eos_token_id],

                // "llama.context_length"
                max_position_embeddings: Some(model_config_metadata.max_seq_len()),
                // "llama.block_count"
                num_hidden_layers,
                // "llama.attention.head_count"
                num_attention_heads: Some(model_config_metadata.num_attn_heads()),
                // "tokenizer.ggml.tokens".len()
                vocab_size: Some(vocab_size),
            }),
            eos_token_id: None,
        }))
    }
}
794
impl ModelInfo for HFConfig {
    // The unwrap()s on text_config below are safe: both constructors
    // (`from_json_file` bails if text_config is still None; `from_gguf` always
    // sets it) guarantee text_config is Some before an HFConfig escapes.

    fn model_type(&self) -> String {
        self.model_type.clone()
    }

    fn bos_token_id(&self) -> TokenIdType {
        self.text_config.as_ref().unwrap().final_bos_token_id
    }

    fn eos_token_ids(&self) -> Vec<TokenIdType> {
        self.text_config
            .as_ref()
            .unwrap()
            .final_eos_token_ids
            .clone()
    }

    fn max_position_embeddings(&self) -> Option<usize> {
        self.text_config.as_ref().unwrap().max_position_embeddings
    }

    fn vocab_size(&self) -> Option<usize> {
        self.text_config.as_ref().unwrap().vocab_size
    }
}
820
821impl TokenizerKind {
822    pub fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
823        let content = load_gguf(gguf_file)?;
824        let out = crate::gguf::convert_gguf_to_hf_tokenizer(&content)
825            .with_context(|| gguf_file.display().to_string())?;
826        Ok(TokenizerKind::GGUF(Box::new(out.tokenizer)))
827    }
828}
829
830pub(crate) fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
831    let filename = gguf_file.display().to_string();
832    let mut f = File::open(gguf_file).with_context(|| filename.clone())?;
833    // vec because GGUF can be split into multiple files (shards)
834    let mut readers = vec![&mut f];
835    crate::gguf::Content::from_readers(&mut readers).with_context(|| filename.clone())
836}
837
/// Upper-case the first character of `s` and lower-case the remainder,
/// e.g. `"llama"` -> `"Llama"`. An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    if let Some(first) = chars.next() {
        // `to_uppercase` may expand to multiple chars (e.g. 'ß' -> "SS").
        out.extend(first.to_uppercase());
        out.push_str(&chars.as_str().to_lowercase());
    }
    out
}
845
846impl ModelInfoType {
847    pub fn from_repo(repo_id: &str) -> Result<Self> {
848        let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("config.json"))
849            .with_context(|| format!("unable to extract config.json from repo {repo_id}"))?;
850        Ok(Self::HfConfigJson(f))
851    }
852}
853
854impl GenerationConfig {
855    pub fn from_repo(repo_id: &str) -> Result<Self> {
856        let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("generation_config.json"))
857            .with_context(|| format!("unable to extract generation_config from repo {repo_id}"))?;
858        Ok(Self::HfGenerationConfigJson(f))
859    }
860}
861
862impl PromptFormatterArtifact {
863    pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
864        // we should only error if we expect a prompt formatter and it's not found
865        // right now, we don't know when to expect it, so we just return Ok(Some/None)
866        match CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer_config.json")) {
867            Ok(f) => Ok(Some(Self::HfTokenizerConfigJson(f))),
868            Err(_) => Ok(None),
869        }
870    }
871
872    pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
873        match CheckedFile::from_disk(PathBuf::from(repo_id).join("chat_template.jinja")) {
874            Ok(f) => Ok(Some(Self::HfChatTemplate(f))),
875            Err(_) => Ok(None),
876        }
877    }
878}
879
880impl TokenizerKind {
881    pub fn from_repo(repo_id: &str) -> Result<Self> {
882        let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer.json"))
883            .with_context(|| format!("unable to extract tokenizer kind from repo {repo_id}"))?;
884        Ok(Self::HfTokenizerJson(f))
885    }
886}
887
888/// Checks if the provided path is a valid local repository path.
889///
890/// # Arguments
891/// * `path` - Path to validate
892///
893/// # Errors
894/// Returns an error if the path doesn't exist or isn't a directory
895fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
896    let path = path.as_ref();
897    if !path.exists() {
898        return Err(anyhow::anyhow!(
899            "Model path does not exist: {}",
900            path.display()
901        ));
902    }
903
904    if !path.is_dir() {
905        return Err(anyhow::anyhow!(
906            "Model path is not a directory: {}",
907            path.display()
908        ));
909    }
910    Ok(())
911}
912
#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::path::Path;

    /// Llama 3.1 keeps token ids at the top level of `config.json`.
    #[test]
    pub fn test_config_json_llama3() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&config_file)?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    /// Llama 4 nests token ids under `text_config`; exercises that code path.
    #[test]
    pub fn test_config_json_llama4() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&config_file)?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }

    /// The Python JSON parser accepts `Infinity` as a numeric value. This is explicitly against the
    /// JSON spec, but inevitably people rely on it, so we have to allow it.
    /// We treat that file as JSON5 (a lenient superset of JSON) to be able to parse it.
    #[test]
    fn test_invalid_json_but_py_accepts_it() {
        dynamo_runtime::logging::init();
        // Anchor to the crate root like the tests above: the previous
        // cwd-relative path only worked when the test was run from the
        // crate directory.
        let path = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/NVIDIA-Nemotron-Nano-12B-v2-Base/config.json");
        let _ = HFConfig::from_json_file(&path).unwrap();
    }
}