1use std::fmt;
17use std::fs::File;
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::common::checked_file::CheckedFile;
23use crate::local_model::runtime_config::ModelRuntimeConfig;
24use anyhow::{Context, Result};
25use derive_builder::Builder;
26use dynamo_runtime::DistributedRuntime;
27use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
28use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
29use serde::{Deserialize, Serialize};
30use tokenizers::Tokenizer as HfTokenizer;
31
32use crate::gguf::{Content, ContentConfig, ModelConfigLike};
33use crate::protocols::TokenIdType;
34
/// Root key passed to the key-value store when loading model deployment cards
/// (see `ModelDeploymentCard::load_from_store`).
pub const ROOT_PATH: &str = "mdc";

/// How long after `last_published` a card is considered expired
/// (see `ModelDeploymentCard::is_expired` and `ModelDeploymentCard::expiry_check_period`).
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
40
/// Source of a model's architecture/config metadata.
///
/// NOTE: with `rename_all = "snake_case"`, serde spells the `GGUF` variant `g_g_u_f` on the
/// wire; renaming variants changes the serialized card format.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
    /// A Hugging Face `config.json` (a local file, or a NATS URL after `move_to_nats`).
    HfConfigJson(CheckedFile),
    /// A GGUF file whose embedded metadata table supplies the model info.
    GGUF(PathBuf),
}
47
/// Where the model's tokenizer comes from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
    /// A Hugging Face `tokenizer.json` (a local file, or a NATS URL after `move_to_nats`).
    HfTokenizerJson(CheckedFile),
    /// A tokenizer converted in memory from a GGUF file's embedded vocabulary. It is
    /// serialized inline with the card; boxed, presumably to keep the enum small.
    GGUF(Box<HfTokenizer>),
}
54
/// A file used to build the model's prompt formatter.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
    /// A Hugging Face `tokenizer_config.json`.
    HfTokenizerConfigJson(CheckedFile),
    /// A standalone chat template: a repo's `chat_template.jinja` or a user-supplied file.
    HfChatTemplate(CheckedFile),
    /// A GGUF file carrying its own template metadata.
    GGUF(PathBuf),
}
74
/// Optional extra context a prompt formatter may add to the template context.
///
/// NOTE(review): the exact behavior of each mixin is defined where these values are consumed,
/// which is not visible in this module — confirm there before relying on the summaries below.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
    /// Presumably OpenAI-chat-style context — TODO confirm.
    OaiChat,

    /// Presumably injects the current date/time in the format Llama 3 templates expect — TODO
    /// confirm.
    Llama3DateTime,
}
84
/// Source of the model's default generation parameters.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    /// A Hugging Face `generation_config.json` (a local file, or a NATS URL after
    /// `move_to_nats`).
    HfGenerationConfigJson(CheckedFile),
    /// A GGUF file.
    GGUF(PathBuf),
}
91
/// Describes a deployable model: its name and slug, the artifacts needed to serve it (config,
/// tokenizer, prompt templates, generation config) and a few runtime limits.
///
/// Cards can be published to a key-value store (`load_from_store` reads them back), with their
/// file artifacts moved into the NATS object store (`move_to_nats` / `move_from_nats`).
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
    /// Human-readable model name. Use `set_name` to change it so `slug` stays in sync.
    pub display_name: String,

    /// Identifier derived from `display_name`; used as the store key and the NATS bucket name.
    slug: Slug,

    /// Model architecture/config metadata.
    pub model_info: Option<ModelInfoType>,

    /// Tokenizer source.
    pub tokenizer: Option<TokenizerKind>,

    /// Prompt-formatting artifact (`tokenizer_config.json`, or the GGUF file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_formatter: Option<PromptFormatterArtifact>,

    /// Standalone chat template (`chat_template.jinja` or a user-supplied template), if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template_file: Option<PromptFormatterArtifact>,

    /// Default generation parameters (`generation_config.json`), if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gen_config: Option<GenerationConfig>,

    /// Extra context mixins for the prompt formatter, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_context: Option<Vec<PromptContextMixin>>,

    /// When the card was last given a store revision (set by `Versioned::set_revision`);
    /// `is_expired` measures age from this timestamp.
    pub last_published: Option<chrono::DateTime<chrono::Utc>>,

    /// Store revision, set via `Versioned::set_revision`. Never serialized, so it does not
    /// affect `to_json` / `mdcsum`.
    #[serde(default, skip_serializing)]
    pub revision: u64,

    /// Maximum context length in tokens; the local constructors leave it 0 when unknown.
    pub context_length: u32,

    /// KV-cache block size. The local constructors set 0; presumably filled in by the caller
    /// before publishing — TODO confirm.
    pub kv_cache_block_size: u32,

    /// Request migration limit. The local constructors set 0; presumably filled in by the
    /// caller — TODO confirm.
    pub migration_limit: u32,

    /// Arbitrary user-supplied JSON, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_data: Option<serde_json::Value>,

    /// Runtime configuration; defaults when absent from the serialized card.
    #[serde(default)]
    pub runtime_config: ModelRuntimeConfig,

    /// Temporary directory holding artifacts downloaded by `load_from_store`. Shared through
    /// `Arc` so clones of the card keep the directory alive; never serialized.
    #[serde(skip)]
    cache_dir: Option<Arc<tempfile::TempDir>>,
}
150
151impl ModelDeploymentCard {
152 pub fn builder() -> ModelDeploymentCardBuilder {
153 ModelDeploymentCardBuilder::default()
154 }
155
156 pub fn with_name_only(name: &str) -> ModelDeploymentCard {
162 ModelDeploymentCard {
163 display_name: name.to_string(),
164 slug: Slug::from_string(name),
165 ..Default::default()
166 }
167 }
168
169 pub fn expiry_check_period() -> Duration {
171 match CARD_MAX_AGE.to_std() {
172 Ok(duration) => duration / 3,
173 Err(_) => {
174 unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE");
176 }
177 }
178 }
179
180 pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
182 let contents = std::fs::read_to_string(&file)?;
183 Ok(serde_json::from_str(&contents).inspect_err(|err| {
184 crate::log_json_err(&file.as_ref().display().to_string(), &contents, err)
185 })?)
186 }
187
188 pub fn load_from_json_str(contents: &str) -> Result<Self, anyhow::Error> {
190 Ok(serde_json::from_str(contents)
191 .inspect_err(|err| crate::log_json_err("unknown", contents, err))?)
192 }
193
194 pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
200 std::fs::write(file, self.to_json()?)?;
201 Ok(())
202 }
203
204 pub fn slug(&self) -> &Slug {
205 &self.slug
206 }
207
208 pub fn to_json(&self) -> Result<String, anyhow::Error> {
210 Ok(serde_json::to_string(self)?)
211 }
212
213 pub fn mdcsum(&self) -> String {
214 let json = self.to_json().unwrap();
215 format!("{}", blake3::hash(json.as_bytes()))
216 }
217
218 pub fn is_expired(&self) -> bool {
220 if let Some(last_published) = self.last_published.as_ref() {
221 chrono::Utc::now() - last_published > CARD_MAX_AGE
222 } else {
223 false
224 }
225 }
226
227 pub fn has_tokenizer(&self) -> bool {
230 self.tokenizer.is_some()
231 }
232
233 pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
234 match &self.tokenizer {
235 Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
236 let p = checked_file.path().ok_or_else(|| {
237 anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
238 })?;
239 HfTokenizer::from_file(p)
240 .inspect_err(|err| {
241 if let Some(serde_err) = err.downcast_ref::<serde_json::Error>()
242 && let Ok(contents) = std::fs::read_to_string(p)
243 {
244 crate::log_json_err(&p.display().to_string(), &contents, serde_err);
245 }
246 })
247 .map_err(anyhow::Error::msg)
248 .with_context(|| p.display().to_string())
249 }
250 Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
251 None => {
252 anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
253 }
254 }
255 }
256
257 pub fn is_gguf(&self) -> bool {
258 match &self.model_info {
259 Some(info) => info.is_gguf(),
260 None => false,
261 }
262 }
263
264 pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
267 let nats_addr = nats_client.addr();
268 let bucket_name = self.slug().clone();
269 tracing::debug!(
270 nats_addr,
271 %bucket_name,
272 "Uploading model deployment card fields to NATS"
273 );
274
275 macro_rules! nats_upload {
276 ($field:expr, $enum_variant:path, $filename:literal) => {
277 if let Some($enum_variant(src_file)) = $field.as_mut()
278 && let Some(path) = src_file.path()
279 {
280 let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
281 let dest = url::Url::parse(&target)?;
282 nats_client.object_store_upload(path, &dest).await?;
283 src_file.move_to_url(dest);
284 }
285 };
286 }
287
288 nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
289 nats_upload!(
290 self.gen_config,
291 GenerationConfig::HfGenerationConfigJson,
292 "generation_config.json"
293 );
294 nats_upload!(
295 self.prompt_formatter,
296 PromptFormatterArtifact::HfTokenizerConfigJson,
297 "tokenizer_config.json"
298 );
299 nats_upload!(
300 self.chat_template_file,
301 PromptFormatterArtifact::HfChatTemplate,
302 "chat_template.jinja"
303 );
304 nats_upload!(
305 self.tokenizer,
306 TokenizerKind::HfTokenizerJson,
307 "tokenizer.json"
308 );
309
310 Ok(())
311 }
312
313 async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
318 let nats_addr = nats_client.addr();
319 let bucket_name = self.slug();
320 let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
321 tracing::debug!(
322 nats_addr,
323 %bucket_name,
324 target_dir = %target_dir.path().display(),
325 "Downloading model deployment card fields from NATS"
326 );
327
328 macro_rules! nats_download {
329 ($field:expr, $enum_variant:path, $filename:literal) => {
330 if let Some($enum_variant(src_file)) = $field.as_mut()
331 && let Some(src_url) = src_file.url()
332 {
333 let target = target_dir.path().join($filename);
334 nats_client.object_store_download(src_url, &target).await?;
335 if !src_file.checksum_matches(&target) {
336 anyhow::bail!(
337 "Invalid {} in NATS for {}, checksum does not match.",
338 $filename,
339 self.display_name
340 );
341 }
342 src_file.move_to_disk(target);
343 }
344 };
345 }
346
347 nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
348 nats_download!(
349 self.gen_config,
350 GenerationConfig::HfGenerationConfigJson,
351 "generation_config.json"
352 );
353 nats_download!(
354 self.prompt_formatter,
355 PromptFormatterArtifact::HfTokenizerConfigJson,
356 "tokenizer_config.json"
357 );
358 nats_download!(
359 self.chat_template_file,
360 PromptFormatterArtifact::HfChatTemplate,
361 "chat_template.jinja"
362 );
363 nats_download!(
364 self.tokenizer,
365 TokenizerKind::HfTokenizerJson,
366 "tokenizer.json"
367 );
368
369 Ok(target_dir)
370 }
371
372 pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
374 let nats_addr = nats_client.addr();
375 let bucket_name = self.slug();
376 tracing::trace!(
377 nats_addr,
378 %bucket_name,
379 "Delete model deployment card from NATS"
380 );
381 nats_client
382 .object_store_delete_bucket(bucket_name.as_ref())
383 .await
384 }
385
386 pub fn set_name(&mut self, name: &str) {
389 self.display_name = name.to_string();
390 self.slug = Slug::from_string(name);
391 }
392
393 pub fn load_from_disk(
398 config_path: impl AsRef<Path>,
399 custom_template_path: Option<&Path>,
400 ) -> anyhow::Result<ModelDeploymentCard> {
401 let config_path = config_path.as_ref();
402 if config_path.is_dir() {
403 Self::from_local_path(config_path, custom_template_path)
404 } else {
405 if custom_template_path.is_some() {
407 anyhow::bail!("Custom templates are not supported for GGUF files");
408 }
409 Self::from_gguf(config_path)
410 }
411 }
412
413 pub async fn load_from_store(
416 model_slug: &Slug,
417 drt: &DistributedRuntime,
418 ) -> anyhow::Result<Option<Self>> {
419 let Some(etcd_client) = drt.etcd_client() else {
420 anyhow::bail!("Missing etcd_client");
422 };
423 let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
424 let card_store = Arc::new(KeyValueStoreManager::new(store));
425 let Some(mut card) = card_store
426 .load::<ModelDeploymentCard>(ROOT_PATH, model_slug)
427 .await?
428 else {
429 return Ok(None);
430 };
431 card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
433 Ok(Some(card))
434 }
435
436 fn from_local_path(
452 local_root_dir: impl AsRef<Path>,
453 custom_template_path: Option<&Path>,
454 ) -> anyhow::Result<Self> {
455 let local_root_dir = local_root_dir.as_ref();
456 check_valid_local_repo_path(local_root_dir)?;
457 let repo_id = local_root_dir
458 .canonicalize()?
459 .to_str()
460 .ok_or_else(|| anyhow::anyhow!("Path contains invalid Unicode"))?
461 .to_string();
462 let model_name = local_root_dir
463 .file_name()
464 .and_then(|n| n.to_str())
465 .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
466
467 Self::from_repo(&repo_id, model_name, custom_template_path)
468 }
469
470 fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
471 let model_name = gguf_file
472 .iter()
473 .next_back()
474 .map(|n| n.to_string_lossy().to_string());
475 let Some(model_name) = model_name else {
476 anyhow::bail!(
478 "Could not extract model name from path '{}'",
479 gguf_file.display()
480 );
481 };
482
483 let content = load_gguf(gguf_file)?;
485 let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())]
486 .to_u32()
487 .unwrap_or(0);
488 tracing::debug!(context_length, "Loaded context length from GGUF");
489
490 Ok(Self {
491 display_name: model_name.to_string(),
492 slug: Slug::from_string(model_name),
493 model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())),
494 tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
495 gen_config: None, prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
497 chat_template_file: None,
498 prompt_context: None, revision: 0,
500 last_published: None,
501 context_length,
502 kv_cache_block_size: 0,
503 migration_limit: 0,
504 user_data: None,
505 runtime_config: ModelRuntimeConfig::default(),
506 cache_dir: None,
507 })
508 }
509
510 fn from_repo(
511 repo_id: &str,
512 model_name: &str,
513 custom_template_path: Option<&Path>,
514 ) -> anyhow::Result<Self> {
515 let context_length = crate::file_json_field(
517 &PathBuf::from(repo_id).join("config.json"),
518 "max_position_embeddings",
519 )
520 .or_else(|_| {
522 crate::file_json_field(
523 &PathBuf::from(repo_id).join("tokenizer_config.json"),
524 "model_max_length",
525 )
526 })
527 .unwrap_or(0);
529
530 let chat_template_file = if let Some(template_path) = custom_template_path {
532 if !template_path.exists() {
533 anyhow::bail!(
534 "Custom template file does not exist: {}",
535 template_path.display()
536 );
537 }
538
539 let _template_content = std::fs::read_to_string(template_path).with_context(|| {
541 format!(
542 "Failed to read custom template file: {}",
543 template_path.display()
544 )
545 })?;
546
547 Some(PromptFormatterArtifact::HfChatTemplate(
548 CheckedFile::from_disk(template_path)?,
549 ))
550 } else {
551 PromptFormatterArtifact::chat_template_from_repo(repo_id)?
552 };
553
554 Ok(Self {
555 display_name: model_name.to_string(),
556 slug: Slug::from_string(model_name),
557 model_info: Some(ModelInfoType::from_repo(repo_id)?),
558 tokenizer: Some(TokenizerKind::from_repo(repo_id)?),
559 gen_config: GenerationConfig::from_repo(repo_id).ok(), prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
561 chat_template_file,
562 prompt_context: None, revision: 0,
564 last_published: None,
565 context_length,
566 kv_cache_block_size: 0, migration_limit: 0,
568 user_data: None,
569 runtime_config: ModelRuntimeConfig::default(),
570 cache_dir: None,
571 })
572 }
573}
574
575impl Versioned for ModelDeploymentCard {
576 fn revision(&self) -> u64 {
577 self.revision
578 }
579
580 fn set_revision(&mut self, revision: u64) {
581 self.last_published = Some(chrono::Utc::now());
582 self.revision = revision;
583 }
584}
585
586impl fmt::Display for ModelDeploymentCard {
587 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588 write!(f, "{}", self.slug())
589 }
590}
/// Model metadata needed at runtime, independent of whether it was read from a Hugging Face
/// `config.json` or from GGUF metadata.
pub trait ModelInfo: Send + Sync {
    /// Model family identifier (e.g. `config.json`'s `model_type`, or the GGUF architecture).
    fn model_type(&self) -> String;

    /// Beginning-of-sequence token id.
    fn bos_token_id(&self) -> TokenIdType;

    /// End-of-sequence token ids; a model may declare several.
    fn eos_token_ids(&self) -> Vec<TokenIdType>;

    /// The model's `max_position_embeddings`, if known.
    fn max_position_embeddings(&self) -> Option<usize>;

    /// Vocabulary size, if known.
    fn vocab_size(&self) -> Option<usize>;
}
609
610impl ModelInfoType {
611 pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
612 match self {
613 Self::HfConfigJson(checked_file) => {
614 let Some(path) = checked_file.path() else {
615 anyhow::bail!("model info is not a local path: {checked_file:?}");
616 };
617 Ok(HFConfig::from_json_file(path)?)
618 }
619 Self::GGUF(path) => Ok(HFConfig::from_gguf(path)?),
620 }
621 }
622 pub fn is_gguf(&self) -> bool {
623 matches!(self, Self::GGUF(_))
624 }
625}
626
/// The subset of a Hugging Face `config.json` used to implement [`ModelInfo`]. `from_gguf`
/// also synthesizes one from GGUF metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
    /// Model class names (e.g. `LlamaForCausalLM`).
    architectures: Vec<String>,

    /// Model family identifier.
    model_type: String,

    /// Language-model fields. `from_json_file` always fills this, parsing the whole file as a
    /// text config when no nested `text_config` object is present.
    text_config: Option<HFTextConfig>,

    /// Top-level `eos_token_id`; when present it takes precedence over
    /// `text_config.eos_token_id`.
    eos_token_id: Option<serde_json::Value>,
}
641
/// Language-model fields of a `config.json`, read from its nested `text_config` object or,
/// when that is absent, from the top level of the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
    /// `bos_token_id` as read from the file; `from_json_file` moves it into
    /// `final_bos_token_id`, leaving this `None`.
    bos_token_id: Option<TokenIdType>,

    /// Resolved BOS token id, filled in after parsing (from this config or
    /// `generation_config.json`); `#[serde(default)]` lets it be absent from the file.
    #[serde(default)]
    final_bos_token_id: TokenIdType,

    /// Raw `eos_token_id`, expected to be a single id or a list of ids.
    eos_token_id: Option<serde_json::Value>,

    /// Resolved EOS token ids, filled in after parsing.
    #[serde(default)]
    final_eos_token_ids: Vec<TokenIdType>,

    /// Maximum position embeddings, if specified.
    max_position_embeddings: Option<usize>,

    /// Number of transformer layers. Required: deserialization fails if it is missing.
    num_hidden_layers: usize,

    /// Number of attention heads, if specified.
    num_attention_heads: Option<usize>,

    /// Vocabulary size, if specified.
    vocab_size: Option<usize>,
}
668
669impl HFConfig {
670 fn from_json_file<P: AsRef<Path>>(file: P) -> Result<Arc<dyn ModelInfo>> {
671 let file_path = file.as_ref();
672 let contents = std::fs::read_to_string(file_path)?;
673 let mut config: Self = json_five::from_str(&contents)
674 .inspect_err(|err| {
675 tracing::error!(path=%file_path.display(), %err, "Failed to parse config.json as JSON5");
676 })?;
677 if config.text_config.is_none() {
678 let text_config: HFTextConfig = json_five::from_str(&contents)
679 .inspect_err(|err| {
680 tracing::error!(path=%file_path.display(), %err, "Failed to parse text config from config.json as JSON5");
681 })?;
682 config.text_config = Some(text_config);
683 }
684
685 let Some(text_config) = config.text_config.as_mut() else {
687 anyhow::bail!(
688 "Missing text config fields (model_type, eos_token_ids, etc) in config.json"
689 );
690 };
691
692 let gencfg_path = file_path
693 .parent()
694 .unwrap_or_else(|| Path::new(""))
695 .join("generation_config.json");
696 if text_config.bos_token_id.is_none() {
697 let bos_token_id = crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id")
698 .context(
699 "missing bos_token_id in generation_config.json and config.json, cannot load",
700 )?;
701 text_config.bos_token_id = Some(bos_token_id);
702 }
703 let final_bos_token_id = text_config.bos_token_id.take().unwrap();
705 text_config.final_bos_token_id = final_bos_token_id;
706
707 let final_eos_token_ids: Vec<TokenIdType> = config
709 .eos_token_id
710 .as_ref()
711 .or(text_config.eos_token_id.as_ref())
712 .and_then(|v| {
713 if v.is_number() {
714 v.as_number()
715 .and_then(|n| n.as_u64())
716 .map(|n| vec![n as TokenIdType])
717 } else if v.is_array() {
718 let arr = v.as_array().unwrap(); Some(
720 arr.iter()
721 .filter_map(|inner_v| {
722 inner_v
723 .as_number()
724 .and_then(|n| n.as_u64())
725 .map(|n| n as TokenIdType)
726 })
727 .collect(),
728 )
729 } else {
730 tracing::error!(
731 ?v,
732 path = %file_path.display(),
733 "eos_token_id is not a number or an array, cannot use"
734 );
735 None
736 }
737 })
738 .or_else(|| {
739 crate::file_json_field(&gencfg_path, "eos_token_id")
741 .inspect_err(
742 |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
743 )
744 .ok()
745 })
746 .ok_or_else(|| {
747 anyhow::anyhow!(
748 "missing eos_token_id in config.json and generation_config.json, cannot load"
749 )
750 })?;
751 text_config.final_eos_token_ids = final_eos_token_ids;
752
753 Ok(Arc::new(config))
754 }
755 fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
756 let content = load_gguf(gguf_file)?;
757 let model_config_metadata: ContentConfig = (&content).into();
758 let num_hidden_layers =
759 content.get_metadata()[&format!("{}.block_count", content.arch())].to_u32()? as usize;
760
761 let bos_token_id = content.get_metadata()["tokenizer.ggml.bos_token_id"].to_u32()?;
762 let eos_token_id = content.get_metadata()["tokenizer.ggml.eos_token_id"].to_u32()?;
763
764 let vocab_size = content.get_metadata()["tokenizer.ggml.tokens"]
766 .to_vec()?
767 .len();
768
769 let arch = content.arch().to_string();
770 Ok(Arc::new(HFConfig {
771 architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
772 model_type: arch,
774 text_config: Some(HFTextConfig {
775 bos_token_id: None,
776 final_bos_token_id: bos_token_id,
777
778 eos_token_id: None,
779 final_eos_token_ids: vec![eos_token_id],
780
781 max_position_embeddings: Some(model_config_metadata.max_seq_len()),
783 num_hidden_layers,
785 num_attention_heads: Some(model_config_metadata.num_attn_heads()),
787 vocab_size: Some(vocab_size),
789 }),
790 eos_token_id: None,
791 }))
792 }
793}
794
795impl ModelInfo for HFConfig {
796 fn model_type(&self) -> String {
797 self.model_type.clone()
798 }
799
800 fn bos_token_id(&self) -> TokenIdType {
801 self.text_config.as_ref().unwrap().final_bos_token_id
802 }
803
804 fn eos_token_ids(&self) -> Vec<TokenIdType> {
805 self.text_config
806 .as_ref()
807 .unwrap()
808 .final_eos_token_ids
809 .clone()
810 }
811
812 fn max_position_embeddings(&self) -> Option<usize> {
813 self.text_config.as_ref().unwrap().max_position_embeddings
814 }
815
816 fn vocab_size(&self) -> Option<usize> {
817 self.text_config.as_ref().unwrap().vocab_size
818 }
819}
820
821impl TokenizerKind {
822 pub fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
823 let content = load_gguf(gguf_file)?;
824 let out = crate::gguf::convert_gguf_to_hf_tokenizer(&content)
825 .with_context(|| gguf_file.display().to_string())?;
826 Ok(TokenizerKind::GGUF(Box::new(out.tokenizer)))
827 }
828}
829
830pub(crate) fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
831 let filename = gguf_file.display().to_string();
832 let mut f = File::open(gguf_file).with_context(|| filename.clone())?;
833 let mut readers = vec![&mut f];
835 crate::gguf::Content::from_readers(&mut readers).with_context(|| filename.clone())
836}
837
/// Upper-cases the first character of `s` and lower-cases the rest, e.g. `"llama"` becomes
/// `"Llama"` and `"QWEN2"` becomes `"Qwen2"`. An empty input gives an empty string.
fn capitalize(s: &str) -> String {
    let mut chars = s.chars();
    let Some(head) = chars.next() else {
        return String::new();
    };
    // Uppercasing may expand one char into several (e.g. 'ß' -> "SS").
    let mut result: String = head.to_uppercase().collect();
    result.push_str(&chars.as_str().to_lowercase());
    result
}
845
846impl ModelInfoType {
847 pub fn from_repo(repo_id: &str) -> Result<Self> {
848 let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("config.json"))
849 .with_context(|| format!("unable to extract config.json from repo {repo_id}"))?;
850 Ok(Self::HfConfigJson(f))
851 }
852}
853
854impl GenerationConfig {
855 pub fn from_repo(repo_id: &str) -> Result<Self> {
856 let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("generation_config.json"))
857 .with_context(|| format!("unable to extract generation_config from repo {repo_id}"))?;
858 Ok(Self::HfGenerationConfigJson(f))
859 }
860}
861
862impl PromptFormatterArtifact {
863 pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
864 match CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer_config.json")) {
867 Ok(f) => Ok(Some(Self::HfTokenizerConfigJson(f))),
868 Err(_) => Ok(None),
869 }
870 }
871
872 pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
873 match CheckedFile::from_disk(PathBuf::from(repo_id).join("chat_template.jinja")) {
874 Ok(f) => Ok(Some(Self::HfChatTemplate(f))),
875 Err(_) => Ok(None),
876 }
877 }
878}
879
880impl TokenizerKind {
881 pub fn from_repo(repo_id: &str) -> Result<Self> {
882 let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer.json"))
883 .with_context(|| format!("unable to extract tokenizer kind from repo {repo_id}"))?;
884 Ok(Self::HfTokenizerJson(f))
885 }
886}
887
888fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
896 let path = path.as_ref();
897 if !path.exists() {
898 return Err(anyhow::anyhow!(
899 "Model path does not exist: {}",
900 path.display()
901 ));
902 }
903
904 if !path.is_dir() {
905 return Err(anyhow::anyhow!(
906 "Model path is not a directory: {}",
907 path.display()
908 ));
909 }
910 Ok(())
911}
912
#[cfg(test)]
mod tests {
    use super::{HFConfig, capitalize};
    use std::path::{Path, PathBuf};

    /// Path to a sample model's `config.json`, resolved against the crate root so the tests
    /// do not depend on the working directory `cargo test` runs in.
    fn sample_config(model_dir: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models")
            .join(model_dir)
            .join("config.json")
    }

    #[test]
    pub fn test_config_json_llama3() -> anyhow::Result<()> {
        let config = HFConfig::from_json_file(sample_config("mock-llama-3.1-8b-instruct"))?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    #[test]
    pub fn test_config_json_llama4() -> anyhow::Result<()> {
        let config = HFConfig::from_json_file(sample_config("Llama-4-Scout-17B-16E-Instruct"))?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }

    /// This `config.json` is not strict JSON (per the test name, Python's loader accepts it);
    /// the JSON5 parser used by `from_json_file` must accept it too.
    #[test]
    fn test_invalid_json_but_py_accepts_it() {
        dynamo_runtime::logging::init();
        let path = sample_config("NVIDIA-Nemotron-Nano-12B-v2-Base");
        let _ = HFConfig::from_json_file(&path).unwrap();
    }

    #[test]
    fn test_capitalize() {
        assert_eq!(capitalize("llama"), "Llama");
        assert_eq!(capitalize("QWEN2"), "Qwen2");
        assert_eq!(capitalize(""), "");
    }
}