1use std::fmt;
29use std::fs::File;
30use std::path::{Path, PathBuf};
31use std::sync::Arc;
32use std::time::Duration;
33
34use anyhow::{Context, Result};
35use derive_builder::Builder;
36use dynamo_runtime::slug::Slug;
37use dynamo_runtime::transports::nats;
38use either::Either;
39use serde::{Deserialize, Serialize};
40use tokenizers::Tokenizer as HfTokenizer;
41use url::Url;
42
43use crate::gguf::{Content, ContentConfig};
44use crate::key_value_store::Versioned;
45use crate::protocols::TokenIdType;
46
/// NATS object-store bucket name under which model deployment cards live.
pub const BUCKET_NAME: &str = "mdc";

/// Time-to-live for entries in the card bucket (applied where the bucket
/// is created — not visible in this file).
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);

/// A card whose `last_published` is older than this is considered expired
/// (see `ModelDeploymentCard::is_expired`).
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
55
/// Where a model's structural configuration can be loaded from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
    /// Path (or `nats://` URL, after `move_to_nats`) of a HuggingFace-style
    /// `config.json`.
    HfConfigJson(String),
    /// Path to a GGUF file whose metadata carries the model configuration.
    GGUF(PathBuf),
}
62
/// Source of the tokenizer for this model.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
    /// Path (or `nats://` URL, after `move_to_nats`) of a HuggingFace
    /// `tokenizer.json`.
    HfTokenizerJson(String),
    /// An already-built in-memory tokenizer, converted from GGUF metadata
    /// (boxed to keep this enum small).
    GGUF(Box<HfTokenizer>),
}
69
/// Where the prompt-formatting template can be loaded from.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
    /// Path (or `nats://` URL, after `move_to_nats`) of a HuggingFace
    /// `tokenizer_config.json`.
    HfTokenizerConfigJson(String),
    /// Path to a GGUF file.
    GGUF(PathBuf),
}
88
/// Extra context that can be mixed into prompt construction.
/// (The concrete semantics of each variant are applied wherever
/// `prompt_context` is consumed — not visible in this file.)
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
    /// OpenAI-style chat context.
    OaiChat,

    /// Llama 3 date/time context.
    Llama3DateTime,
}
98
/// A model deployment card: the identity and artifact references
/// (config, tokenizer, prompt formatter) needed to serve a model.
/// Cards are JSON-serializable and can be published to NATS for
/// discovery by other components.
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
    // Human-friendly model name.
    pub display_name: String,

    // Slug-ified name, safe for use in URLs and NATS bucket names.
    pub service_name: String,

    // Where to find the model's structural configuration.
    pub model_info: Option<ModelInfoType>,

    // Where to find, or how to build, the tokenizer.
    pub tokenizer: Option<TokenizerKind>,

    // Optional prompt/chat template artifact.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_formatter: Option<PromptFormatterArtifact>,

    // Optional context mixins applied when building prompts.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_context: Option<Vec<PromptContextMixin>>,

    // When the card was last published; stamped by `Versioned::set_revision`
    // and compared against CARD_MAX_AGE in `is_expired`.
    pub last_published: Option<chrono::DateTime<chrono::Utc>>,

    // Revision counter managed through the `Versioned` impl; deliberately
    // never serialized out.
    #[serde(default, skip_serializing)]
    pub revision: u64,

    // Whether input must be preprocessed before reaching the engine.
    // NOTE(review): force-cleared by `load_from_json_file` — confirm intent.
    #[serde(default)]
    pub requires_preprocessing: bool,
}
135
136impl ModelDeploymentCard {
137 pub fn builder() -> ModelDeploymentCardBuilder {
138 ModelDeploymentCardBuilder::default()
139 }
140
141 pub fn with_name_only(name: &str) -> ModelDeploymentCard {
147 ModelDeploymentCard {
148 display_name: name.to_string(),
149 service_name: Slug::from_string(name).to_string(),
150 ..Default::default()
151 }
152 }
153
154 pub fn service_name_slug(s: &str) -> Slug {
158 Slug::from_string(s)
159 }
160
161 pub fn expiry_check_period() -> Duration {
163 match CARD_MAX_AGE.to_std() {
164 Ok(duration) => duration / 3,
165 Err(_) => {
166 unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE");
168 }
169 }
170 }
171
172 pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
174 let mut card: ModelDeploymentCard = serde_json::from_str(&std::fs::read_to_string(file)?)?;
175 card.requires_preprocessing = false;
176 Ok(card)
177 }
178
179 pub fn load_from_json_str(json: &str) -> Result<Self, anyhow::Error> {
181 Ok(serde_json::from_str(json)?)
182 }
183
184 pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
190 std::fs::write(file, self.to_json()?)?;
191 Ok(())
192 }
193
194 pub fn set_service_name(&mut self, service_name: &str) {
195 self.service_name = service_name.to_string();
196 }
197
198 pub fn slug(&self) -> Slug {
199 ModelDeploymentCard::service_name_slug(&self.service_name)
200 }
201
202 pub fn to_json(&self) -> Result<String, anyhow::Error> {
204 Ok(serde_json::to_string(self)?)
205 }
206
207 pub fn mdcsum(&self) -> String {
208 let json = self.to_json().unwrap();
209 format!("{}", blake3::hash(json.as_bytes()))
210 }
211
212 pub fn is_expired(&self) -> bool {
214 if let Some(last_published) = self.last_published.as_ref() {
215 chrono::Utc::now() - last_published > CARD_MAX_AGE
216 } else {
217 false
218 }
219 }
220
221 pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
222 match &self.tokenizer {
223 Some(TokenizerKind::HfTokenizerJson(file)) => {
224 HfTokenizer::from_file(file).map_err(anyhow::Error::msg)
225 }
226 Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
227 None => {
228 anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
229 }
230 }
231 }
232
233 pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
236 let nats_addr = nats_client.addr();
237 let bucket_name = self.slug();
238 tracing::debug!(
239 nats_addr,
240 %bucket_name,
241 "Uploading model deployment card to NATS"
242 );
243
244 if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info {
245 if !nats::is_nats_url(src_file) {
246 let target = format!("nats://{nats_addr}/{bucket_name}/config.json");
247 nats_client
248 .object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
249 .await?;
250 self.model_info = Some(ModelInfoType::HfConfigJson(target));
251 }
252 }
253
254 if let Some(PromptFormatterArtifact::HfTokenizerConfigJson(ref src_file)) =
255 self.prompt_formatter
256 {
257 if !nats::is_nats_url(src_file) {
258 let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer_config.json");
259 nats_client
260 .object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
261 .await?;
262 self.prompt_formatter =
263 Some(PromptFormatterArtifact::HfTokenizerConfigJson(target));
264 }
265 }
266
267 if let Some(TokenizerKind::HfTokenizerJson(ref src_file)) = self.tokenizer {
268 if !nats::is_nats_url(src_file) {
269 let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer.json");
270 nats_client
271 .object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
272 .await?;
273 self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target));
274 }
275 }
276
277 Ok(())
278 }
279
280 pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
282 let nats_addr = nats_client.addr();
283 let bucket_name = self.slug();
284 tracing::trace!(
285 nats_addr,
286 %bucket_name,
287 "Delete model deployment card from NATS"
288 );
289 nats_client
290 .object_store_delete_bucket(bucket_name.as_ref())
291 .await
292 }
293}
294
impl Versioned for ModelDeploymentCard {
    /// Current revision of this card in the key-value store.
    fn revision(&self) -> u64 {
        self.revision
    }

    /// Record a new revision; also stamps `last_published` with the
    /// current time so expiry is measured from the most recent publish.
    fn set_revision(&mut self, revision: u64) {
        self.last_published = Some(chrono::Utc::now());
        self.revision = revision;
    }
}
305
impl fmt::Display for ModelDeploymentCard {
    // A card displays as its slug (the canonical service name).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.slug())
    }
}
/// Minimal model metadata needed by the runtime, independent of whether
/// it came from a HuggingFace `config.json` or from GGUF metadata.
pub trait ModelInfo: Send + Sync {
    /// Model family, e.g. "llama".
    fn model_type(&self) -> String;

    /// Beginning-of-sequence token id.
    fn bos_token_id(&self) -> TokenIdType;

    /// End-of-sequence token id(s).
    fn eos_token_ids(&self) -> Vec<TokenIdType>;

    /// Maximum sequence length the model supports.
    fn max_position_embeddings(&self) -> usize;

    /// Size of the token vocabulary.
    fn vocab_size(&self) -> usize;
}
327
impl ModelInfoType {
    /// Load the concrete model metadata from whichever source this
    /// variant points at.
    pub async fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
        match self {
            Self::HfConfigJson(info) => HFConfig::from_json_file(info).await,
            Self::GGUF(path) => HFConfig::from_gguf(path),
        }
    }
}
336
/// The subset of a HuggingFace `config.json` this module needs.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
    // Beginning-of-sequence token id.
    bos_token_id: TokenIdType,

    // HF configs store either a single EOS id or a list of them, hence
    // the untagged Either.
    #[serde(with = "either::serde_untagged")]
    eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,

    // e.g. ["LlamaForCausalLM"]
    architectures: Vec<String>,

    // e.g. "llama"
    model_type: String,

    // Maximum sequence length the model supports.
    max_position_embeddings: usize,

    num_hidden_layers: usize,

    num_attention_heads: usize,

    // Size of the token vocabulary.
    vocab_size: usize,
}
363
364impl HFConfig {
365 async fn from_json_file(file: &String) -> Result<Arc<dyn ModelInfo>> {
366 let contents = std::fs::read_to_string(file)?;
367 let config: Self = serde_json::from_str(&contents)?;
368 Ok(Arc::new(config))
369 }
370 fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
371 let content = load_gguf(gguf_file)?;
372 let model_config_metadata: ContentConfig = (&content).into();
373 let num_hidden_layers =
374 content.get_metadata()[&format!("{}.block_count", content.arch())].to_u32()? as usize;
375
376 let bos_token_id = content.get_metadata()["tokenizer.ggml.bos_token_id"].to_u32()?;
377 let eos_token_id = content.get_metadata()["tokenizer.ggml.eos_token_id"].to_u32()?;
378
379 let vocab_size = content.get_metadata()["tokenizer.ggml.tokens"]
381 .to_vec()?
382 .len();
383
384 let arch = content.arch().to_string();
385 Ok(Arc::new(HFConfig {
386 bos_token_id,
387 eos_token_id: Either::Left(eos_token_id),
388 architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
389 model_type: arch,
391 max_position_embeddings: model_config_metadata.max_seq_len(),
393 num_hidden_layers,
395 num_attention_heads: model_config_metadata.num_attn_heads(),
397 vocab_size,
399 }))
400 }
401}
402
403impl ModelInfo for HFConfig {
404 fn model_type(&self) -> String {
405 self.model_type.clone()
406 }
407
408 fn bos_token_id(&self) -> TokenIdType {
409 self.bos_token_id
410 }
411
412 fn eos_token_ids(&self) -> Vec<TokenIdType> {
413 match &self.eos_token_id {
414 Either::Left(eos_token_id) => vec![*eos_token_id],
415 Either::Right(eos_token_ids) => eos_token_ids.clone(),
416 }
417 }
418
419 fn max_position_embeddings(&self) -> usize {
420 self.max_position_embeddings
421 }
422
423 fn vocab_size(&self) -> usize {
424 self.vocab_size
425 }
426}
427
impl TokenizerKind {
    /// Build an in-memory HuggingFace tokenizer from GGUF metadata.
    ///
    /// # Errors
    /// Fails if the GGUF file cannot be read or its tokenizer metadata
    /// cannot be converted; errors carry the file path as context.
    pub fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
        let content = load_gguf(gguf_file)?;
        let out = crate::gguf::convert_gguf_to_hf_tokenizer(&content)
            .with_context(|| gguf_file.display().to_string())?;
        Ok(TokenizerKind::GGUF(Box::new(out.tokenizer)))
    }
}
436
/// Open a GGUF file and parse its content/metadata.
fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
    // Keep the rendered path for error context on both open and parse.
    let filename = gguf_file.display().to_string();
    let mut f = File::open(gguf_file).with_context(|| filename.clone())?;
    // from_readers takes a collection of readers; we have a single file.
    let mut readers = vec![&mut f];
    crate::gguf::Content::from_readers(&mut readers).with_context(|| filename.clone())
}
444
/// Upper-case the first character of `s` and lower-case the rest,
/// e.g. "llama" -> "Llama". Multi-char Unicode case mappings are kept
/// (e.g. 'ß' -> "SS").
fn capitalize(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        None => String::new(),
        Some(head) => head
            .to_uppercase()
            .chain(chars.flat_map(char::to_lowercase))
            .collect(),
    }
}