use std::fmt;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
use url::Url;

use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::protocols::TokenIdType;

/// A published card older than this is considered stale by `is_expired`.
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);

/// Where the model's configuration comes from: a Hugging Face `config.json`
/// (local path or `nats://` URL) or a GGUF file.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
    HfConfigJson(String),
    GGUF(PathBuf),
}

/// Where the tokenizer comes from: a Hugging Face `tokenizer.json`
/// (local path or `nats://` URL) or a tokenizer converted from GGUF metadata.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
    HfTokenizerJson(String),
    GGUF(Box<HfTokenizer>),
}

/// Where the prompt/chat template comes from: a Hugging Face
/// `tokenizer_config.json` (local path or `nats://` URL) or a GGUF file.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
    HfTokenizerConfigJson(String),
    GGUF(PathBuf),
}

/// Extra context mixed into the rendered prompt.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
    OaiChat,

    Llama3DateTime,
}

/// Where generation defaults come from: a Hugging Face `generation_config.json`
/// (local path or `nats://` URL) or a GGUF file.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    HfGenerationConfigJson(String),
    GGUF(PathBuf),
}

/// Describes how to load and serve a model: where its config, tokenizer,
/// prompt formatter, and generation defaults live, plus serving parameters.
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
    /// Human-readable model name.
    pub display_name: String,

    /// Name used to address this model as a service.
    pub service_name: String,

    pub model_info: Option<ModelInfoType>,

    pub tokenizer: Option<TokenizerKind>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_formatter: Option<PromptFormatterArtifact>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gen_config: Option<GenerationConfig>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_context: Option<Vec<PromptContextMixin>>,

    /// When this card was last published to the key-value store.
    pub last_published: Option<chrono::DateTime<chrono::Utc>>,

    /// Storage revision; set by the key-value store and never serialized.
    #[serde(default, skip_serializing)]
    pub revision: u64,

    pub context_length: usize,

    pub kv_cache_block_size: usize,
}

impl ModelDeploymentCard {
    pub fn builder() -> ModelDeploymentCardBuilder {
        ModelDeploymentCardBuilder::default()
    }

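    /// Create a card that carries only a human-readable name; the service name
    /// is the slugified form and every other field is left at its default.
    ///
    /// A minimal usage sketch (the exact slug format is an implementation
    /// detail of `Slug::slugify`, so this is illustrative rather than a doctest):
    ///
    /// ```ignore
    /// let card = ModelDeploymentCard::with_name_only("meta-llama/Llama-3.1-8B-Instruct");
    /// assert_eq!(card.display_name, "meta-llama/Llama-3.1-8B-Instruct");
    /// assert!(card.model_info.is_none());
    /// assert!(!card.is_expired());
    /// ```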
    pub fn with_name_only(name: &str) -> ModelDeploymentCard {
        ModelDeploymentCard {
            display_name: name.to_string(),
            service_name: Slug::slugify(name).to_string(),
            ..Default::default()
        }
    }

    /// How often to check whether cards have expired: three checks per
    /// `CARD_MAX_AGE` window.
    pub fn expiry_check_period() -> Duration {
        match CARD_MAX_AGE.to_std() {
            Ok(duration) => duration / 3,
            Err(_) => {
                unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE");
            }
        }
    }

    pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
        Ok(serde_json::from_str(&std::fs::read_to_string(file)?)?)
    }

    pub fn load_from_json_str(json: &str) -> Result<Self, anyhow::Error> {
        Ok(serde_json::from_str(json)?)
    }

    pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
        std::fs::write(file, self.to_json()?)?;
        Ok(())
    }

    pub fn set_service_name(&mut self, service_name: &str) {
        self.service_name = service_name.to_string();
    }

    pub fn slug(&self) -> Slug {
        Slug::from_string(&self.display_name)
    }

    pub fn to_json(&self) -> Result<String, anyhow::Error> {
        Ok(serde_json::to_string(self)?)
    }

    /// blake3 hash of the card's JSON; used as a content checksum.
    pub fn mdcsum(&self) -> String {
        let json = self.to_json().unwrap();
        format!("{}", blake3::hash(json.as_bytes()))
    }

    pub fn is_expired(&self) -> bool {
        if let Some(last_published) = self.last_published.as_ref() {
            chrono::Utc::now() - last_published > CARD_MAX_AGE
        } else {
            false
        }
    }

    pub fn has_tokenizer(&self) -> bool {
        self.tokenizer.is_some()
    }

    pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
        match &self.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => {
                HfTokenizer::from_file(file).map_err(anyhow::Error::msg)
            }
            Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
            None => {
                anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
            }
        }
    }

    pub fn is_gguf(&self) -> bool {
        match &self.model_info {
            Some(info) => info.is_gguf(),
            None => false,
        }
    }

    /// Upload any locally file-backed fields to the NATS object store and
    /// rewrite those fields to point at the resulting `nats://` URLs.
    pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        tracing::debug!(
            nats_addr,
            %bucket_name,
            "Uploading model deployment card fields to NATS"
        );

        macro_rules! nats_upload {
            ($field:expr, $enum_variant:path, $filename:literal) => {
                match $field.take() {
                    Some($enum_variant(src_file)) if !nats::is_nats_url(&src_file) => {
                        let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
                        nats_client
                            .object_store_upload(
                                &std::path::PathBuf::from(&src_file),
                                url::Url::parse(&target)?,
                            )
                            .await?;
                        $field = Some($enum_variant(target));
                    }
                    // Values already on NATS, other variants, and None are put
                    // back untouched instead of being silently dropped.
                    other => $field = other,
                }
            };
        }

        nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
        nats_upload!(
            self.prompt_formatter,
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
        nats_upload!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
            "tokenizer.json"
        );
        nats_upload!(
            self.gen_config,
            GenerationConfig::HfGenerationConfigJson,
            "generation_config.json"
        );

        Ok(())
    }

    /// Download any `nats://`-backed fields into a temporary directory and
    /// rewrite those fields to the local paths. The returned `TempDir` owns the
    /// files, so keep it alive for as long as the local paths are in use.
    pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
        tracing::debug!(
            nats_addr,
            %bucket_name,
            target_dir = %target_dir.path().display(),
            "Downloading model deployment card fields from NATS"
        );

        macro_rules! nats_download {
            ($field:expr, $enum_variant:path, $filename:literal) => {
                match $field.take() {
                    Some($enum_variant(src_url)) if nats::is_nats_url(&src_url) => {
                        let target = target_dir.path().join($filename);
                        nats_client
                            .object_store_download(Url::parse(&src_url)?, &target)
                            .await?;
                        $field = Some($enum_variant(target.display().to_string()));
                    }
                    // Local paths, other variants, and None are kept as-is.
                    other => $field = other,
                }
            };
        }

        nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
        nats_download!(
            self.prompt_formatter,
            PromptFormatterArtifact::HfTokenizerConfigJson,
            "tokenizer_config.json"
        );
        nats_download!(
            self.tokenizer,
            TokenizerKind::HfTokenizerJson,
            "tokenizer.json"
        );
        nats_download!(
            self.gen_config,
            GenerationConfig::HfGenerationConfigJson,
            "generation_config.json"
        );

        Ok(target_dir)
    }

    pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        tracing::trace!(
            nats_addr,
            %bucket_name,
            "Delete model deployment card from NATS"
        );
        nats_client
            .object_store_delete_bucket(bucket_name.as_ref())
            .await
    }
}

impl Versioned for ModelDeploymentCard {
    fn revision(&self) -> u64 {
        self.revision
    }

    fn set_revision(&mut self, revision: u64) {
        self.last_published = Some(chrono::Utc::now());
        self.revision = revision;
    }
}

impl fmt::Display for ModelDeploymentCard {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.slug())
    }
}

/// Basic model metadata extracted from an HF `config.json` or a GGUF file.
pub trait ModelInfo: Send + Sync {
    fn model_type(&self) -> String;

    fn bos_token_id(&self) -> TokenIdType;

    fn eos_token_ids(&self) -> Vec<TokenIdType>;

    fn max_position_embeddings(&self) -> Option<usize>;

    fn vocab_size(&self) -> Option<usize>;
}

impl ModelInfoType {
    pub async fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
        match self {
            Self::HfConfigJson(info) => HFConfig::from_json_file(info).await,
            Self::GGUF(path) => HFConfig::from_gguf(path),
        }
    }

    pub fn is_gguf(&self) -> bool {
        matches!(self, Self::GGUF(_))
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
    architectures: Vec<String>,

    model_type: String,

    /// Present when the config nests the text-model fields (e.g. multimodal
    /// configs); otherwise the same fields are parsed from the top level.
    text_config: Option<HFTextConfig>,

    eos_token_id: Option<serde_json::Value>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
    bos_token_id: Option<TokenIdType>,

    /// Resolved after parsing, falling back to `generation_config.json`.
    #[serde(default)]
    final_bos_token_id: TokenIdType,

    eos_token_id: Option<serde_json::Value>,

    /// Resolved after parsing, falling back to `generation_config.json`.
    #[serde(default)]
    final_eos_token_ids: Vec<TokenIdType>,

    max_position_embeddings: Option<usize>,

    num_hidden_layers: usize,

    num_attention_heads: Option<usize>,

    vocab_size: Option<usize>,
}

impl HFConfig {
    async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
        let file_pathbuf = PathBuf::from(file);
        let contents = std::fs::read_to_string(file)?;
        let mut config: Self = serde_json::from_str(&contents)?;
        // Some configs nest the text-model fields under `text_config`; if not,
        // parse them from the top level of config.json.
        if config.text_config.is_none() {
            let text_config: HFTextConfig = serde_json::from_str(&contents)?;
            config.text_config = Some(text_config);
        }
        let Some(text_config) = config.text_config.as_mut() else {
            anyhow::bail!(
                "Missing text config fields (model_type, eos_token_ids, etc) in config.json"
            );
        };

        // Fall back to generation_config.json when config.json lacks a bos_token_id.
        if text_config.bos_token_id.is_none() {
            let bos_token_id = crate::file_json_field::<TokenIdType>(
                &Path::join(
                    file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
                    "generation_config.json",
                ),
                "bos_token_id",
            )
            .context(
                "missing bos_token_id in generation_config.json and config.json, cannot load",
            )?;
            text_config.bos_token_id = Some(bos_token_id);
        }
        let final_bos_token_id = text_config.bos_token_id.take().unwrap();
        text_config.final_bos_token_id = final_bos_token_id;

        // eos_token_id may be a single number or an array; fall back to
        // generation_config.json when it is missing from config.json.
        let final_eos_token_ids: Vec<TokenIdType> = config
            .eos_token_id
            .as_ref()
            .or(text_config.eos_token_id.as_ref())
            .and_then(|v| {
                if v.is_number() {
                    v.as_number()
                        .and_then(|n| n.as_u64())
                        .map(|n| vec![n as TokenIdType])
                } else if v.is_array() {
                    let arr = v.as_array().unwrap();
                    Some(
                        arr.iter()
                            .filter_map(|inner_v| {
                                inner_v
                                    .as_number()
                                    .and_then(|n| n.as_u64())
                                    .map(|n| n as TokenIdType)
                            })
                            .collect(),
                    )
                } else {
                    tracing::error!(
                        ?v,
                        file,
                        "eos_token_id is not a number or an array, cannot use"
                    );
                    None
                }
            })
            .or_else(|| {
                crate::file_json_field(
                    &Path::join(
                        file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
                        "generation_config.json",
                    ),
                    "eos_token_id",
                )
                .inspect_err(
                    |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
                )
                .ok()
            })
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "missing eos_token_id in config.json and generation_config.json, cannot load"
                )
            })?;
        text_config.final_eos_token_ids = final_eos_token_ids;

        Ok(Arc::new(config))
    }

    fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
        let content = load_gguf(gguf_file)?;
        let model_config_metadata: ContentConfig = (&content).into();
        let num_hidden_layers =
            content.get_metadata()[&format!("{}.block_count", content.arch())].to_u32()? as usize;

        let bos_token_id = content.get_metadata()["tokenizer.ggml.bos_token_id"].to_u32()?;
        let eos_token_id = content.get_metadata()["tokenizer.ggml.eos_token_id"].to_u32()?;

        let vocab_size = content.get_metadata()["tokenizer.ggml.tokens"]
            .to_vec()?
            .len();

        let arch = content.arch().to_string();
        Ok(Arc::new(HFConfig {
            architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
            model_type: arch,
            text_config: Some(HFTextConfig {
                bos_token_id: None,
                final_bos_token_id: bos_token_id,

                eos_token_id: None,
                final_eos_token_ids: vec![eos_token_id],

                max_position_embeddings: Some(model_config_metadata.max_seq_len()),
                num_hidden_layers,
                num_attention_heads: Some(model_config_metadata.num_attn_heads()),
                vocab_size: Some(vocab_size),
            }),
            eos_token_id: None,
        }))
    }
}

impl ModelInfo for HFConfig {
    fn model_type(&self) -> String {
        self.model_type.clone()
    }

    fn bos_token_id(&self) -> TokenIdType {
        self.text_config.as_ref().unwrap().final_bos_token_id
    }

    fn eos_token_ids(&self) -> Vec<TokenIdType> {
        self.text_config
            .as_ref()
            .unwrap()
            .final_eos_token_ids
            .clone()
    }

    fn max_position_embeddings(&self) -> Option<usize> {
        self.text_config.as_ref().unwrap().max_position_embeddings
    }

    fn vocab_size(&self) -> Option<usize> {
        self.text_config.as_ref().unwrap().vocab_size
    }
}

impl TokenizerKind {
    pub fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
        let content = load_gguf(gguf_file)?;
        let out = crate::gguf::convert_gguf_to_hf_tokenizer(&content)
            .with_context(|| gguf_file.display().to_string())?;
        Ok(TokenizerKind::GGUF(Box::new(out.tokenizer)))
    }
}

pub(crate) fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
    let filename = gguf_file.display().to_string();
    let mut f = File::open(gguf_file).with_context(|| filename.clone())?;
    let mut readers = vec![&mut f];
    crate::gguf::Content::from_readers(&mut readers).with_context(|| filename.clone())
}

/// Uppercase the first character and lowercase the rest; used to derive an
/// architecture name like "LlamaForCausalLM" from a GGUF arch string.
fn capitalize(s: &str) -> String {
    s.chars()
        .enumerate()
        .map(|(i, c)| {
            if i == 0 {
                c.to_uppercase().to_string()
            } else {
                c.to_lowercase().to_string()
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::path::Path;

    #[tokio::test]
    pub async fn test_config_json_llama3() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    #[tokio::test]
    pub async fn test_config_json_llama4() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }
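
    // The tests below are small additions that only exercise behavior visible
    // in this file; they use no external test data.

    // `with_name_only` plus a JSON round-trip: a card serialized with
    // `to_json` should deserialize back with the same names, and a card that
    // was never published should not be expired.
    #[test]
    fn test_with_name_only_json_roundtrip() -> anyhow::Result<()> {
        let card = super::ModelDeploymentCard::with_name_only("example-org/example-model");
        let json = card.to_json()?;
        let restored = super::ModelDeploymentCard::load_from_json_str(&json)?;
        assert_eq!(restored.display_name, card.display_name);
        assert_eq!(restored.service_name, card.service_name);
        assert!(!restored.is_expired());
        Ok(())
    }

    // `capitalize` upper-cases the first character and lower-cases the rest,
    // as used to build e.g. "LlamaForCausalLM" from a GGUF arch string.
    #[test]
    fn test_capitalize() {
        assert_eq!(super::capitalize("llama"), "Llama");
        assert_eq!(super::capitalize("LLAMA"), "Llama");
        assert_eq!(super::capitalize(""), "");
    }

    // `expiry_check_period` is one third of CARD_MAX_AGE (5 minutes -> 100s).
    #[test]
    fn test_expiry_check_period() {
        assert_eq!(
            super::ModelDeploymentCard::expiry_check_period(),
            std::time::Duration::from_secs(100)
        );
    }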
}