use std::fmt;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::nats;
use either::Either;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
use url::Url;
use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType;
/// Name of the bucket associated with model deployment cards.
// NOTE(review): the consumer of this constant is not visible in this file;
// presumably it names the key-value store bucket holding cards — confirm.
pub const BUCKET_NAME: &str = "mdc";
/// Time-to-live paired with `BUCKET_NAME`: five minutes.
// NOTE(review): how the TTL is applied is not visible in this file.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
/// Maximum time since `last_published` before `ModelDeploymentCard::is_expired`
/// reports a card as expired. `ModelDeploymentCard::expiry_check_period` returns
/// one third of this value.
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
/// Source of a model's configuration metadata.
///
/// Variant names are serialized in snake_case.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
/// Location of a Hugging Face style `config.json`: a local file path, or a
/// NATS URL once `ModelDeploymentCard::move_to_nats` has uploaded the file.
HfConfigJson(String),
/// Path to a GGUF file; `get_model_info` reads the configuration from its metadata.
GGUF(PathBuf),
}
/// Source of a model's tokenizer.
///
/// Variant names are serialized in snake_case.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
/// Location of a Hugging Face `tokenizer.json`: a local file path, or a NATS
/// URL once `ModelDeploymentCard::move_to_nats` has uploaded the file.
HfTokenizerJson(String),
/// An in-memory Hugging Face tokenizer, as produced by `TokenizerKind::from_gguf`.
GGUF(Box<HfTokenizer>),
}
/// Source of the data used to build a prompt formatter.
///
/// Variant names are serialized in snake_case.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
/// Location of a Hugging Face `tokenizer_config.json`: a local file path, or
/// a NATS URL once `ModelDeploymentCard::move_to_nats` has uploaded the file.
HfTokenizerConfigJson(String),
/// Path to a GGUF file.
// NOTE(review): this variant is not consumed anywhere in this file; presumably
// the prompt template is read from the GGUF metadata elsewhere — confirm.
GGUF(PathBuf),
}
/// Optional extras that can be listed in a card's `prompt_context`.
///
/// Variant names are serialized in snake_case.
// NOTE(review): the meaning of each mixin is not established in this file; the
// descriptions below are inferred from the names — confirm against the consumer.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
/// Presumably OpenAI-chat style prompt context.
OaiChat,
/// Presumably date/time context for Llama 3 prompt templates.
Llama3DateTime,
}
/// Describes a deployed model: its names plus where to find its configuration,
/// tokenizer and prompt-formatting artifacts.
///
/// A builder (`ModelDeploymentCardBuilder`) is generated by the `Builder` derive.
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
/// Human-readable model name.
pub display_name: String,
/// Name from which the card's slug (and NATS bucket name) is derived; see `slug`.
pub service_name: String,
/// Where the model configuration lives, if known.
pub model_info: Option<ModelInfoType>,
/// Where the tokenizer lives, if known.
pub tokenizer: Option<TokenizerKind>,
/// Where the prompt formatter data lives; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>,
/// Prompt context mixins; omitted from the serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_context: Option<Vec<PromptContextMixin>>,
/// Timestamp written by `Versioned::set_revision`; `is_expired` compares it to the current time.
pub last_published: Option<chrono::DateTime<chrono::Utc>>,
/// Revision number exposed through `Versioned`. Never serialized; takes the
/// default value when absent on deserialization.
#[serde(default, skip_serializing)]
pub revision: u64,
}
impl ModelDeploymentCard {
    /// Returns an empty builder for assembling a card field by field.
    pub fn builder() -> ModelDeploymentCardBuilder {
        Default::default()
    }

    /// Creates a card that carries only names: `display_name` is `name` verbatim
    /// and `service_name` is its slugified form. Every other field is defaulted.
    pub fn with_name_only(name: &str) -> ModelDeploymentCard {
        Self {
            display_name: name.to_owned(),
            service_name: Slug::slugify(name).to_string(),
            ..Self::default()
        }
    }

    /// Builds a [`Slug`] from the given service name string.
    pub fn service_name_slug(s: &str) -> Slug {
        Slug::from_string(s)
    }

    /// How often card expiry should be checked: one third of `CARD_MAX_AGE`.
    ///
    /// # Panics
    ///
    /// Panics if `CARD_MAX_AGE` cannot be converted to a `std::time::Duration`.
    pub fn expiry_check_period() -> Duration {
        CARD_MAX_AGE
            .to_std()
            .map(|max_age| max_age / 3)
            .unwrap_or_else(|_| {
                unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE")
            })
    }

    /// Reads `file` and deserializes its JSON content into a card.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the file cannot be read; JSON errors are
    /// converted into `std::io::Error`.
    pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
        let contents = std::fs::read_to_string(file)?;
        let card: Self = serde_json::from_str(&contents)?;
        Ok(card)
    }

    /// Deserializes a card from a JSON string.
    pub fn load_from_json_str(json: &str) -> Result<Self, anyhow::Error> {
        let card: Self = serde_json::from_str(json)?;
        Ok(card)
    }

    /// Serializes the card to JSON and writes it to `file`.
    pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
        let json = self.to_json()?;
        std::fs::write(file, json)?;
        Ok(())
    }

    /// Replaces the service name with a copy of `service_name`.
    pub fn set_service_name(&mut self, service_name: &str) {
        self.service_name = service_name.to_owned();
    }

    /// The slug derived from this card's `service_name`.
    pub fn slug(&self) -> Slug {
        Self::service_name_slug(&self.service_name)
    }

    /// Serializes the card to a compact JSON string.
    pub fn to_json(&self) -> Result<String, anyhow::Error> {
        let encoded = serde_json::to_string(self)?;
        Ok(encoded)
    }

    /// BLAKE3 hash of the card's JSON form, rendered as a string.
    ///
    /// # Panics
    ///
    /// Panics if the card cannot be serialized to JSON.
    pub fn mdcsum(&self) -> String {
        let json = self.to_json().unwrap();
        let digest = blake3::hash(json.as_bytes());
        digest.to_string()
    }

    /// Whether more than `CARD_MAX_AGE` has elapsed since `last_published`.
    ///
    /// A card that has never been published (`last_published` is `None`) is
    /// never considered expired.
    pub fn is_expired(&self) -> bool {
        match &self.last_published {
            Some(published_at) => chrono::Utc::now() - published_at > CARD_MAX_AGE,
            None => false,
        }
    }

    /// Whether the card has a tokenizer entry.
    pub fn has_tokenizer(&self) -> bool {
        self.tokenizer.is_some()
    }

    /// Produces a Hugging Face tokenizer for this card.
    ///
    /// A `HfTokenizerJson` entry is loaded from the file it names; a `GGUF`
    /// entry returns a clone of the stored tokenizer.
    ///
    /// # Errors
    ///
    /// Fails if the card has no tokenizer or the tokenizer file cannot be loaded.
    pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
        let kind = match self.tokenizer.as_ref() {
            Some(kind) => kind,
            None => anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer"),
        };
        match kind {
            TokenizerKind::HfTokenizerJson(file) => {
                HfTokenizer::from_file(file).map_err(anyhow::Error::msg)
            }
            TokenizerKind::GGUF(t) => Ok((**t).clone()),
        }
    }

    /// Uploads the card's file-based artifacts to the NATS object store.
    ///
    /// For each of `model_info` (`HfConfigJson`), `prompt_formatter`
    /// (`HfTokenizerConfigJson`) and `tokenizer` (`HfTokenizerJson`) whose
    /// value is not already a NATS URL, the file is uploaded to
    /// `nats://{addr}/{slug}/<file name>` and the field is rewritten to that
    /// URL. Other variants and `None` fields are left untouched.
    ///
    /// # Errors
    ///
    /// Returns the first URL-parse or upload error; fields processed before
    /// the failure keep their rewritten values.
    pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        tracing::debug!(
            nats_addr,
            %bucket_name,
            "Uploading model deployment card fields to NATS"
        );

        // Model configuration.
        let local_config = match &self.model_info {
            Some(ModelInfoType::HfConfigJson(src)) if !nats::is_nats_url(src) => {
                Some(PathBuf::from(src))
            }
            _ => None,
        };
        if let Some(local_path) = local_config {
            let target = format!("nats://{nats_addr}/{bucket_name}/config.json");
            nats_client
                .object_store_upload(&local_path, Url::parse(&target)?)
                .await?;
            self.model_info = Some(ModelInfoType::HfConfigJson(target));
        }

        // Prompt formatter (tokenizer configuration).
        let local_formatter = match &self.prompt_formatter {
            Some(PromptFormatterArtifact::HfTokenizerConfigJson(src))
                if !nats::is_nats_url(src) =>
            {
                Some(PathBuf::from(src))
            }
            _ => None,
        };
        if let Some(local_path) = local_formatter {
            let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer_config.json");
            nats_client
                .object_store_upload(&local_path, Url::parse(&target)?)
                .await?;
            self.prompt_formatter = Some(PromptFormatterArtifact::HfTokenizerConfigJson(target));
        }

        // Tokenizer.
        let local_tokenizer = match &self.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(src)) if !nats::is_nats_url(src) => {
                Some(PathBuf::from(src))
            }
            _ => None,
        };
        if let Some(local_path) = local_tokenizer {
            let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer.json");
            nats_client
                .object_store_upload(&local_path, Url::parse(&target)?)
                .await?;
            self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target));
        }

        Ok(())
    }

    /// Downloads the card's NATS-hosted artifacts into a fresh temporary directory.
    ///
    /// For each of `model_info` (`HfConfigJson`), `prompt_formatter`
    /// (`HfTokenizerConfigJson`) and `tokenizer` (`HfTokenizerJson`) whose
    /// value is a NATS URL, the object is downloaded into the directory and
    /// the field is rewritten to the local file path.
    ///
    /// Returns the temporary directory, whose name is prefixed with the card's
    /// slug. The caller must keep the returned `TempDir` alive for as long as
    /// the files are needed, because `tempfile` removes the directory when the
    /// value is dropped.
    pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
        tracing::debug!(
            nats_addr,
            %bucket_name,
            target_dir = %target_dir.path().display(),
            "Downloading model deployment card fields from NATS"
        );

        // Model configuration.
        let remote_config = match &self.model_info {
            Some(ModelInfoType::HfConfigJson(src)) if nats::is_nats_url(src) => Some(src.clone()),
            _ => None,
        };
        if let Some(src_url) = remote_config {
            let target = target_dir.path().join("config.json");
            nats_client
                .object_store_download(Url::parse(&src_url)?, &target)
                .await?;
            self.model_info = Some(ModelInfoType::HfConfigJson(target.display().to_string()));
        }

        // Prompt formatter (tokenizer configuration).
        let remote_formatter = match &self.prompt_formatter {
            Some(PromptFormatterArtifact::HfTokenizerConfigJson(src))
                if nats::is_nats_url(src) =>
            {
                Some(src.clone())
            }
            _ => None,
        };
        if let Some(src_url) = remote_formatter {
            let target = target_dir.path().join("tokenizer_config.json");
            nats_client
                .object_store_download(Url::parse(&src_url)?, &target)
                .await?;
            self.prompt_formatter = Some(PromptFormatterArtifact::HfTokenizerConfigJson(
                target.display().to_string(),
            ));
        }

        // Tokenizer.
        let remote_tokenizer = match &self.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(src)) if nats::is_nats_url(src) => {
                Some(src.clone())
            }
            _ => None,
        };
        if let Some(src_url) = remote_tokenizer {
            let target = target_dir.path().join("tokenizer.json");
            nats_client
                .object_store_download(Url::parse(&src_url)?, &target)
                .await?;
            self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target.display().to_string()));
        }

        Ok(target_dir)
    }

    /// Deletes the NATS object store bucket named after this card's slug.
    pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
        let nats_addr = nats_client.addr();
        let bucket_name = self.slug();
        tracing::trace!(
            nats_addr,
            %bucket_name,
            "Delete model deployment card from NATS"
        );
        let deletion = nats_client.object_store_delete_bucket(bucket_name.as_ref());
        deletion.await
    }
}
impl Versioned for ModelDeploymentCard {
    /// Current revision of this card.
    fn revision(&self) -> u64 {
        self.revision
    }

    /// Records a new revision and stamps `last_published` with the current UTC time.
    fn set_revision(&mut self, revision: u64) {
        let published_at = chrono::Utc::now();
        self.revision = revision;
        self.last_published = Some(published_at);
    }
}
impl fmt::Display for ModelDeploymentCard {
    /// Displays the card as its slug.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let slug = self.slug();
        f.write_fmt(format_args!("{}", slug))
    }
}
/// Read-only view of the model properties exposed by a model configuration.
pub trait ModelInfo: Send + Sync {
/// The model type identifier (for example the `model_type` entry of a `config.json`).
fn model_type(&self) -> String;
/// Token id of the beginning-of-sequence token.
fn bos_token_id(&self) -> TokenIdType;
/// Token ids of the end-of-sequence token(s).
fn eos_token_ids(&self) -> Vec<TokenIdType>;
/// The configured maximum number of position embeddings.
fn max_position_embeddings(&self) -> usize;
/// Number of entries in the vocabulary.
fn vocab_size(&self) -> usize;
}
impl ModelInfoType {
    /// Loads the model information this value points at.
    ///
    /// `HfConfigJson` is read as a JSON config file; `GGUF` is read from the
    /// metadata of the GGUF file.
    pub async fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
        let info = match self {
            Self::GGUF(gguf_path) => HFConfig::from_gguf(gguf_path)?,
            Self::HfConfigJson(config_path) => HFConfig::from_json_file(config_path).await?,
        };
        Ok(info)
    }
}
/// Subset of a Hugging Face style model configuration used by this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
/// Architecture names, e.g. as listed under `architectures` in `config.json`.
architectures: Vec<String>,
/// Model type identifier.
model_type: String,
/// Text-model settings. `from_json_file` fills this from the top-level object
/// when the file has no `text_config` entry; `from_gguf` always sets it.
text_config: Option<HFTextConfig>,
}
/// Text-model settings of a Hugging Face style configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
/// Token id of the beginning-of-sequence token.
bos_token_id: TokenIdType,
/// End-of-sequence token id(s): either a single id or a list of ids
/// (deserialized untagged, so both JSON shapes are accepted).
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
/// Configured maximum number of position embeddings.
max_position_embeddings: usize,
/// Number of hidden layers.
num_hidden_layers: usize,
/// Number of attention heads.
num_attention_heads: usize,
/// Number of entries in the vocabulary.
vocab_size: usize,
}
impl HFConfig {
    /// Loads a Hugging Face style `config.json` from `file`.
    ///
    /// The top-level object supplies `architectures` and `model_type`. When it
    /// has no `text_config` entry, the same top-level object is parsed a second
    /// time as the text configuration, so `text_config` is always populated on
    /// success.
    ///
    /// # Errors
    ///
    /// Returns an error naming `file` if it cannot be read or if its JSON does
    /// not contain the expected fields.
    async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
        let contents = std::fs::read_to_string(file)
            .with_context(|| format!("Failed to read model config file {file}"))?;
        let mut config: Self = serde_json::from_str(&contents)
            .with_context(|| format!("Failed to parse model config file {file}"))?;
        if config.text_config.is_none() {
            // No nested `text_config`: the text settings live at the top level.
            let text_config: HFTextConfig = serde_json::from_str(&contents)
                .with_context(|| format!("Failed to parse text config from {file}"))?;
            config.text_config = Some(text_config);
        }
        Ok(Arc::new(config))
    }

    /// Builds the model configuration from the metadata of a GGUF file.
    ///
    /// Reads `<arch>.block_count`, `tokenizer.ggml.bos_token_id`,
    /// `tokenizer.ggml.eos_token_id` and `tokenizer.ggml.tokens` from the
    /// metadata; the remaining values come from the file's `ContentConfig`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be loaded, if one of the metadata
    /// keys above is missing (instead of panicking), or if a value has an
    /// unexpected type.
    fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
        let content = load_gguf(gguf_file)?;
        let model_config_metadata: ContentConfig = (&content).into();
        let metadata = content.get_metadata();
        // Checked lookup: a GGUF file lacking a key yields an error that names
        // both the file and the key rather than an index panic.
        let lookup = |key: &str| {
            metadata.get(key).with_context(|| {
                format!(
                    "{}: missing GGUF metadata key '{key}'",
                    gguf_file.display()
                )
            })
        };
        // u32 -> usize is a widening conversion on supported targets.
        let num_hidden_layers =
            lookup(&format!("{}.block_count", content.arch()))?.to_u32()? as usize;
        let bos_token_id = lookup("tokenizer.ggml.bos_token_id")?.to_u32()?;
        let eos_token_id = lookup("tokenizer.ggml.eos_token_id")?.to_u32()?;
        let vocab_size = lookup("tokenizer.ggml.tokens")?.to_vec()?.len();
        let arch = content.arch().to_string();
        Ok(Arc::new(HFConfig {
            architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
            model_type: arch,
            text_config: Some(HFTextConfig {
                bos_token_id,
                eos_token_id: Either::Left(eos_token_id),
                max_position_embeddings: model_config_metadata.max_seq_len(),
                num_hidden_layers,
                num_attention_heads: model_config_metadata.num_attn_heads(),
                vocab_size,
            }),
        }))
    }
}
impl ModelInfo for HFConfig {
    /// The configured model type.
    fn model_type(&self) -> String {
        self.model_type.to_owned()
    }

    /// Beginning-of-sequence token id. Panics if `text_config` is `None`.
    fn bos_token_id(&self) -> TokenIdType {
        let text = self.text_config.as_ref().unwrap();
        text.bos_token_id
    }

    /// End-of-sequence token ids; a single configured id is returned as a
    /// one-element vector. Panics if `text_config` is `None`.
    fn eos_token_ids(&self) -> Vec<TokenIdType> {
        let text = self.text_config.as_ref().unwrap();
        match &text.eos_token_id {
            Either::Right(many) => many.to_vec(),
            Either::Left(single) => vec![*single],
        }
    }

    /// Maximum number of position embeddings. Panics if `text_config` is `None`.
    fn max_position_embeddings(&self) -> usize {
        let text = self.text_config.as_ref().unwrap();
        text.max_position_embeddings
    }

    /// Vocabulary size. Panics if `text_config` is `None`.
    fn vocab_size(&self) -> usize {
        let text = self.text_config.as_ref().unwrap();
        text.vocab_size
    }
}
impl TokenizerKind {
    /// Builds an in-memory tokenizer from the GGUF file at `gguf_file`.
    ///
    /// Conversion errors carry the file path as context.
    pub fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
        let content = load_gguf(gguf_file)?;
        let converted = crate::gguf::convert_gguf_to_hf_tokenizer(&content)
            .with_context(|| gguf_file.display().to_string())?;
        let tokenizer = Box::new(converted.tokenizer);
        Ok(Self::GGUF(tokenizer))
    }
}
/// Opens `gguf_file` and parses it into GGUF [`Content`].
///
/// Both the open and the parse error carry the file path as context.
fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
    let mut file = File::open(gguf_file).with_context(|| gguf_file.display().to_string())?;
    let mut readers = vec![&mut file];
    let parsed = crate::gguf::Content::from_readers(&mut readers);
    parsed.with_context(|| gguf_file.display().to_string())
}
/// Returns `s` with its first character upper-cased and every following
/// character lower-cased, using per-character Unicode case mapping.
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    let mut out = String::with_capacity(s.len());
    if let Some(first) = rest.next() {
        out.extend(first.to_uppercase());
    }
    for c in rest {
        out.extend(c.to_lowercase());
    }
    out
}
#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::path::Path;

    /// Resolves `relative` against the crate's manifest directory and returns
    /// the resulting path as a string.
    fn fixture(relative: &str) -> String {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join(relative)
            .display()
            .to_string()
    }

    /// A config with top-level text settings (no `text_config` entry) loads.
    #[tokio::test]
    pub async fn test_config_json_llama3() -> anyhow::Result<()> {
        let path = fixture("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&path).await?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    /// The Llama 4 sample config loads and reports its BOS token id.
    #[tokio::test]
    pub async fn test_config_json_llama4() -> anyhow::Result<()> {
        let path = fixture("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&path).await?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }
}