use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::{slug::Slug, storage::kv};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
use crate::protocols::TokenIdType;
/// KV-store key prefix under which model deployment cards are published.
pub const ROOT_PATH: &str = "v1/mdc";
/// Source of a model's structural configuration.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ModelInfoType {
    /// HuggingFace-style `config.json`, tracked as a checksummed file.
    HfConfigJson(CheckedFile),
}
impl ModelInfoType {
pub fn checksum(&self) -> String {
match self {
ModelInfoType::HfConfigJson(c) => c.checksum().to_string(),
}
}
pub fn is_local(&self) -> bool {
match self {
ModelInfoType::HfConfigJson(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
ModelInfoType::HfConfigJson(c) => c.update_dir(dir),
}
}
}
/// Which kind of tokenizer artifact the model ships with.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
    /// HuggingFace `tokenizer.json`.
    HfTokenizerJson(CheckedFile),
    /// A tiktoken model file (`tiktoken.model` or `*.tiktoken`).
    TikTokenModel(CheckedFile),
}
impl TokenizerKind {
pub fn checksum(&self) -> String {
match self {
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => {
c.checksum().to_string()
}
}
}
pub fn is_local(&self) -> bool {
match self {
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
TokenizerKind::HfTokenizerJson(c) | TokenizerKind::TikTokenModel(c) => {
c.update_dir(dir)
}
}
}
}
/// Files describing how prompts are formatted / templated for the model.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
    /// HuggingFace `tokenizer_config.json`.
    HfTokenizerConfigJson(CheckedFile),
    /// A chat template; `is_custom` marks a user-supplied template file
    /// (as opposed to one shipped inside the model repo).
    HfChatTemplate { is_custom: bool, file: CheckedFile },
}
impl PromptFormatterArtifact {
pub fn checksum(&self) -> String {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.checksum().to_string(),
}
}
pub fn is_local(&self) -> bool {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.is_local(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.update_dir(dir),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.update_dir(dir),
}
}
pub fn is_custom(&self) -> bool {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(_) => false,
PromptFormatterArtifact::HfChatTemplate { is_custom, .. } => *is_custom,
}
}
}
/// Extra context mixed into the prompt at render time.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
    /// Standard OpenAI chat context.
    OaiChat,
    /// Llama-3 style date/time injection.
    Llama3DateTime,
}
/// Source of the model's generation defaults.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    /// HuggingFace `generation_config.json`, tracked as a checksummed file.
    HfGenerationConfigJson(CheckedFile),
}
impl GenerationConfig {
pub fn checksum(&self) -> String {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.checksum().to_string(),
}
}
pub fn is_local(&self) -> bool {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.update_dir(dir),
}
}
}
/// Heuristic: a Mistral-format checkout ships `params.json` and carries no
/// HuggingFace `config.json`.
fn is_exclusively_mistral_model(directory: &Path) -> bool {
    let has_hf_config = directory.join("config.json").exists();
    let has_mistral_params = directory.join("params.json").exists();
    !has_hf_config && has_mistral_params
}
/// A Model Deployment Card (MDC): everything the runtime needs to serve a
/// model — identity, config/tokenizer/template artifacts, limits, and
/// runtime settings. Serializable so it can be published to KV storage
/// under [`ROOT_PATH`].
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
    /// Human-facing model name; the slug is derived from it.
    pub display_name: String,
    // Key-safe identifier derived from `display_name`; kept private so it
    // stays in sync via `set_name`.
    slug: Slug,
    /// Where the model came from (local path or hub identifier —
    /// NOTE(review): used as the HF source in `download_config`; confirm).
    /// Falls back to `display_name` when absent (see `source_path()`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_path: Option<String>,
    /// Structural config (`config.json`), if known.
    pub model_info: Option<ModelInfoType>,
    /// Tokenizer artifact, if known.
    pub tokenizer: Option<TokenizerKind>,
    /// `tokenizer_config.json`-style prompt formatter, if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_formatter: Option<PromptFormatterArtifact>,
    /// Chat template file (repo-provided or user-supplied), if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template_file: Option<PromptFormatterArtifact>,
    /// Generation defaults (`generation_config.json`), if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gen_config: Option<GenerationConfig>,
    /// Extra prompt context mixins to apply at render time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_context: Option<Vec<PromptContextMixin>>,
    /// Maximum context length in tokens; 0 means unknown.
    pub context_length: u32,
    /// KV-cache block size used by the engine.
    pub kv_cache_block_size: u32,
    /// How many times a request may migrate between workers.
    pub migration_limit: u32,
    /// What kind of model this is (chat, completions, tensor, ...).
    pub model_type: ModelType,
    /// What the engine accepts as input (tokens vs text).
    pub model_input: ModelInput,
    /// LoRA adapter metadata, when this card describes a LoRA.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub lora: Option<LoraInfo>,
    /// Free-form user-attached metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_data: Option<serde_json::Value>,
    /// Engine runtime configuration.
    #[serde(default)]
    pub runtime_config: ModelRuntimeConfig,
    /// Media decoding settings for multimodal inputs.
    #[serde(default)]
    pub media_decoder: Option<MediaDecoder>,
    /// Media fetching settings for multimodal inputs.
    #[serde(default)]
    pub media_fetcher: Option<MediaFetcher>,
    // Lazily-computed blake3 content hash; never serialized (see `mdcsum`).
    #[serde(skip, default)]
    checksum: OnceLock<String>,
}
/// Metadata for a LoRA adapter attached to a base model.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct LoraInfo {
    /// Adapter name.
    pub name: String,
    /// Maximum number of LoRAs resident on GPU at once, if limited.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_gpu_lora_count: Option<u32>,
}
impl ModelDeploymentCard {
pub fn builder() -> ModelDeploymentCardBuilder {
ModelDeploymentCardBuilder::default()
}
pub fn with_name_only(name: &str) -> ModelDeploymentCard {
ModelDeploymentCard {
display_name: name.to_string(),
slug: Slug::from_string(name),
..Default::default()
}
}
/// Deserialize a card from a JSON file, logging parse failures with file
/// context before propagating them as `io::Error`.
pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
    let contents = std::fs::read_to_string(&file)?;
    let parsed = serde_json::from_str(&contents);
    if let Err(err) = &parsed {
        crate::log_json_err(&file.as_ref().display().to_string(), &contents, err);
    }
    // serde_json::Error converts into io::Error via `?`.
    Ok(parsed?)
}
/// Deserialize a card from an in-memory JSON string, logging parse
/// failures (with an "unknown" source label) before returning them.
pub fn load_from_json_str(contents: &str) -> Result<Self, anyhow::Error> {
    match serde_json::from_str(contents) {
        Ok(card) => Ok(card),
        Err(err) => {
            crate::log_json_err("unknown", contents, &err);
            Err(err.into())
        }
    }
}
/// Serialize this card to JSON and write it to `file`.
pub fn save_to_json_file(&self, file: &str) -> Result<(), anyhow::Error> {
    let json = self.to_json()?;
    std::fs::write(file, json)?;
    Ok(())
}
/// The model's human-readable name.
#[inline]
pub fn name(&self) -> &str {
    self.display_name.as_str()
}

/// Key-safe identifier derived from the display name.
#[inline]
pub fn slug(&self) -> &Slug {
    &self.slug
}

/// Serialize the full card as a JSON string.
pub fn to_json(&self) -> Result<String, anyhow::Error> {
    let serialized = serde_json::to_string(self)?;
    Ok(serialized)
}
/// Content checksum identifying this card, computed lazily on first call
/// and cached for the card's lifetime.
///
/// Hashes the identity-bearing fields: name, source path, artifact
/// checksums, prompt context, context length and KV block size.
/// NOTE(review): model_type, model_input, migration_limit and
/// runtime_config are NOT part of the hash — presumably deliberate
/// (they don't change the model content); confirm.
pub fn mdcsum(&self) -> &str {
    self.checksum
        .get_or_init(|| {
            let mut bytes_to_hash: Vec<u8> = Vec::with_capacity(512);
            bytes_to_hash.extend(self.display_name.as_bytes());
            if let Some(source_path) = self.source_path.as_ref() {
                bytes_to_hash.extend(source_path.as_bytes());
            }
            if let Some(model_info) = self.model_info.as_ref() {
                bytes_to_hash.extend(model_info.checksum().as_bytes());
            }
            if let Some(tokenizer) = self.tokenizer.as_ref() {
                bytes_to_hash.extend(tokenizer.checksum().as_bytes());
            }
            if let Some(prompt_formatter) = self.prompt_formatter.as_ref() {
                bytes_to_hash.extend(prompt_formatter.checksum().as_bytes());
            }
            if let Some(chat_template) = self.chat_template_file.as_ref() {
                bytes_to_hash.extend(chat_template.checksum().as_bytes());
            }
            if let Some(gen_config) = self.gen_config.as_ref() {
                bytes_to_hash.extend(gen_config.checksum().as_bytes());
            }
            if let Some(prompt_context_vec) = self.prompt_context.as_ref() {
                // Debug formatting is stable enough here for hashing the mixin list.
                bytes_to_hash.extend(format!("{prompt_context_vec:?}").as_bytes());
            }
            bytes_to_hash.extend(self.context_length.to_be_bytes());
            bytes_to_hash.extend(self.kv_cache_block_size.to_be_bytes());
            blake3::hash(&bytes_to_hash).to_string()
        })
        .as_ref()
}
/// Whether tokenizer information is available on this card.
pub fn has_tokenizer(&self) -> bool {
    self.tokenizer.is_some()
}

/// Load and construct the tokenizer described by this card.
///
/// # Errors
/// - the tokenizer file is URL-backed (not downloaded yet)
/// - the tokenizer file cannot be parsed
/// - the card carries no tokenizer at all (e.g. Mistral-format models)
pub fn tokenizer(&self) -> anyhow::Result<crate::tokenizers::Tokenizer> {
    match &self.tokenizer {
        Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
            let p = checked_file.path().ok_or_else(|| {
                anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
            })?;
            let hf = HfTokenizer::from_file(p)
                .inspect_err(|err| {
                    // When the failure is a JSON parse error, re-read the file so
                    // the log can include pretty diagnostics with line context.
                    if let Some(serde_err) = err.downcast_ref::<serde_json::Error>()
                        && let Ok(contents) = std::fs::read_to_string(p)
                    {
                        crate::log_json_err(&p.display().to_string(), &contents, serde_err);
                    }
                })
                .map_err(anyhow::Error::msg)
                .with_context(|| p.display().to_string())?;
            Ok(crate::tokenizers::Tokenizer::from(Arc::new(
                crate::tokenizers::HuggingFaceTokenizer::from_tokenizer(hf),
            )))
        }
        Some(TokenizerKind::TikTokenModel(checked_file)) => {
            let p = checked_file.path().ok_or_else(|| {
                anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
            })?;
            // TikToken loader takes a &str path, so reject non-UTF-8 paths explicitly.
            let path_str = p.to_str().ok_or_else(|| {
                anyhow::anyhow!("Tokenizer path contains invalid UTF-8: {}", p.display())
            })?;
            let tokenizer = crate::tokenizers::TikTokenTokenizer::from_file_auto(path_str)
                .with_context(|| {
                    format!("Failed to load tiktoken tokenizer from {}", p.display())
                })?;
            Ok(crate::tokenizers::Tokenizer::from(Arc::new(tokenizer)))
        }
        None => {
            anyhow::bail!(
                "Blank ModelDeploymentCard does not have a tokenizer. Is this a mistral model? If so, the `--use-<framework>-tokenizer` flag in the engine command is required."
            );
        }
    }
}
/// Record where this card's model was loaded from.
pub(crate) fn set_source_path(&mut self, source_path: PathBuf) {
    self.source_path = Some(format!("{}", source_path.display()));
}

/// Rename the card, keeping the slug in sync with the new name.
pub fn set_name(&mut self, name: &str) {
    self.display_name = name.to_owned();
    self.slug = Slug::from_string(name);
}

/// The model's source location, falling back to the display name when no
/// explicit source path was recorded.
pub fn source_path(&self) -> &str {
    match self.source_path.as_deref() {
        Some(path) => path,
        None => &self.display_name,
    }
}
/// Build a card by scanning a local model checkout on disk, optionally
/// with a user-supplied chat template.
pub fn load_from_disk(
    config_path: impl AsRef<Path>,
    custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> {
    let path = config_path.as_ref();
    Self::from_local_path(path, custom_template_path)
}

/// Token-input models need the frontend to run tokenization/templating.
pub fn requires_preprocessing(&self) -> bool {
    match self.model_input {
        ModelInput::Tokens => true,
        _ => false,
    }
}
/// Fetch any non-local config files from the hub.
///
/// No-op when everything is already on disk, or when the model is
/// tensor-based (no HF config to fetch). Weights are never downloaded
/// here — only config files.
pub async fn download_config(&mut self) -> anyhow::Result<()> {
    if self.has_local_files() {
        tracing::trace!("All model config is local, not downloading");
    } else if self.model_type.supports_tensor() {
        tracing::debug!(
            display_name = %self.display_name,
            "Skipping config download for TensorBased model"
        );
    } else {
        let downloaded_dir =
            crate::hub::from_hf(self.source_path(), /* ignore_weights */ true).await?;
        self.update_dir(&downloaded_dir);
    }
    Ok(())
}
/// Rewrite every artifact reference to point at `base_url` joined with the
/// artifact's file name, so another host can fetch the files over HTTP.
///
/// Custom chat templates are deliberately NOT rewritten: their path was
/// supplied by the user and is expected to exist locally on every host.
pub fn move_to_url(&mut self, base_url: &str) -> anyhow::Result<()> {
    // Shared rewrite logic for the tuple-variant artifacts; a macro because
    // each call matches a different enum variant.
    macro_rules! change {
        ($field:expr, $enum_variant:path) => {
            if let Some($enum_variant(src_file)) = $field.as_mut()
                && let Some(filename) = src_file
                .path()
                .and_then(|p| p.file_name())
                .and_then(|f| f.to_str())
                .map(|f| f.to_string())
            {
                let hf_url = url::Url::parse(base_url)
                    .and_then(|u| u.join(filename.as_ref()))
                    .context(filename)?;
                src_file.move_to_url(hf_url);
            }
        };
    }
    change!(self.model_info, ModelInfoType::HfConfigJson);
    change!(self.gen_config, GenerationConfig::HfGenerationConfigJson);
    change!(
        self.prompt_formatter,
        PromptFormatterArtifact::HfTokenizerConfigJson
    );
    // The tokenizer may be either variant; only the one that matches rewrites.
    change!(self.tokenizer, TokenizerKind::HfTokenizerJson);
    change!(self.tokenizer, TokenizerKind::TikTokenModel);
    // Struct variant can't go through the macro; handled explicitly so the
    // custom-template case can opt out.
    if let Some(PromptFormatterArtifact::HfChatTemplate {
        file: src_file,
        is_custom,
    }) = self.chat_template_file.as_mut()
    {
        if *is_custom {
            tracing::info!(
                "Detected custom chat template. Ensure file exists in the same location on all hosts."
            );
        } else if let Some(filename) = src_file
            .path()
            .and_then(|p| p.file_name())
            .and_then(|f| f.to_str())
            .map(|f| f.to_string())
        {
            let hf_url = url::Url::parse(base_url)
                .and_then(|u| u.join(filename.as_ref()))
                .context(filename)?;
            src_file.move_to_url(hf_url);
        }
    }
    Ok(())
}
fn has_local_files(&self) -> bool {
let has_model_info = self
.model_info
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_tokenizer = self
.tokenizer
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_prompt_formatter = self
.prompt_formatter
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_chat_template_file = self
.chat_template_file
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_gen_config = self
.gen_config
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
has_model_info
&& has_tokenizer
&& has_prompt_formatter
&& has_chat_template_file
&& has_gen_config
}
fn update_dir(&mut self, dir: &Path) {
if let Some(model_info) = self.model_info.as_mut() {
model_info.update_dir(dir);
}
if let Some(tk) = self.tokenizer.as_mut() {
tk.update_dir(dir);
}
if let Some(pf) = self.prompt_formatter.as_mut() {
pf.update_dir(dir);
}
if let Some(gc) = self.gen_config.as_mut() {
gc.update_dir(dir);
}
if let Some(ct) = self.chat_template_file.as_mut()
&& !ct.is_custom()
{
ct.update_dir(dir);
}
}
/// Validate that `local_path` is an existing directory, then scan it.
fn from_local_path(
    local_path: impl AsRef<Path>,
    custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
    let path = local_path.as_ref();
    check_valid_local_repo_path(path)?;
    Self::from_repo_checkout(path, custom_template_path)
}
/// Assemble a card from a validated local repo checkout.
///
/// Mistral-format checkouts (`params.json`, no `config.json`) carry none
/// of the HuggingFace artifacts, so those fields stay `None`.
///
/// # Errors
/// - required HF artifacts are missing from a non-Mistral checkout
/// - `custom_template_path` is given but does not exist or is unreadable
fn from_repo_checkout(
    local_path: impl AsRef<Path>,
    custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
    let local_path = local_path.as_ref();
    // Context length: prefer config.json's max_position_embeddings, fall
    // back to tokenizer_config.json's model_max_length, else 0 (unknown).
    let context_length =
        crate::file_json_field(&local_path.join("config.json"), "max_position_embeddings")
            .or_else(|_| {
                crate::file_json_field(
                    &local_path.join("tokenizer_config.json"),
                    "model_max_length",
                )
            })
            .unwrap_or(0);
    let is_mistral_model = is_exclusively_mistral_model(local_path);
    let (model_info, tokenizer, gen_config, prompt_formatter) = if is_mistral_model {
        (None, None, None, None)
    } else {
        (
            Some(ModelInfoType::from_disk(local_path)?),
            Some(TokenizerKind::from_disk(local_path)?),
            // generation_config.json is optional.
            GenerationConfig::from_disk(local_path).ok(),
            PromptFormatterArtifact::from_disk(local_path)?,
        )
    };
    let chat_template_file = if is_mistral_model {
        None
    } else if let Some(template_path) = custom_template_path {
        if !template_path.exists() {
            anyhow::bail!(
                "Custom template file does not exist: {}",
                template_path.display()
            );
        }
        // Read up front so an unreadable template fails loudly here rather
        // than at first render; the contents themselves are not needed yet.
        std::fs::read_to_string(template_path).with_context(|| {
            format!(
                "Failed to read custom template file: {}",
                template_path.display()
            )
        })?;
        Some(PromptFormatterArtifact::HfChatTemplate {
            // We are inside the `Some(template_path)` branch, so this is by
            // definition a user-supplied (custom) template.
            is_custom: true,
            file: CheckedFile::from_disk(template_path)?,
        })
    } else {
        PromptFormatterArtifact::chat_template_from_disk(local_path)?
    };
    let display_name = local_path.display().to_string();
    Ok(Self {
        slug: Slug::from_string(&display_name),
        display_name,
        source_path: None,
        model_info,
        tokenizer,
        gen_config,
        prompt_formatter,
        chat_template_file,
        prompt_context: None,
        context_length,
        kv_cache_block_size: 0,
        migration_limit: 0,
        model_type: Default::default(),
        model_input: Default::default(),
        lora: None,
        user_data: None,
        runtime_config: ModelRuntimeConfig::default(),
        media_decoder: None,
        media_fetcher: None,
        checksum: OnceLock::new(),
    })
}
}
/// Two cards are equal when their content checksums match.
impl PartialEq for ModelDeploymentCard {
    fn eq(&self, other: &ModelDeploymentCard) -> bool {
        let ours = self.mdcsum();
        let theirs = other.mdcsum();
        ours == theirs
    }
}
// Cards are unversioned in KV storage: revision is pinned to 0 and
// set_revision is a no-op.
impl kv::Versioned for ModelDeploymentCard {
    fn revision(&self) -> u64 {
        0
    }
    fn set_revision(&mut self, _revision: u64) {}
}
/// Cards display as their slug.
impl fmt::Display for ModelDeploymentCard {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.slug(), f)
    }
}
/// Read access to a model's structural configuration.
pub trait ModelInfo: Send + Sync {
    /// Model type string, e.g. the `model_type` field of `config.json`.
    fn model_type(&self) -> String;
    /// Beginning-of-sequence token id, if declared.
    fn bos_token_id(&self) -> Option<TokenIdType>;
    /// End-of-sequence token ids (may be several).
    fn eos_token_ids(&self) -> Vec<TokenIdType>;
    /// Maximum sequence length the model was trained for, if declared.
    fn max_position_embeddings(&self) -> Option<usize>;
    /// Vocabulary size, if declared.
    fn vocab_size(&self) -> Option<usize>;
}
impl ModelInfoType {
    /// Parse the backing `config.json` into a [`ModelInfo`].
    ///
    /// Errors when the file is URL-backed (not yet downloaded) or fails to
    /// parse.
    pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
        let Self::HfConfigJson(checked_file) = self;
        match checked_file.path() {
            Some(path) => HFConfig::from_json_file(path),
            None => anyhow::bail!("model info is not a local path: {checked_file:?}"),
        }
    }
}
/// Subset of a HuggingFace `config.json` needed for serving.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
    /// Model architectures as listed in config.json.
    architectures: Vec<String>,
    /// The `model_type` field, e.g. "llama".
    model_type: String,
    /// Multimodal configs nest the text fields under `text_config`; for
    /// plain text models this is filled in from the top level by
    /// `from_json_file`.
    text_config: Option<HFTextConfig>,
    /// Raw value: HF configs store either a single number or an array here.
    eos_token_id: Option<serde_json::Value>,
}

/// Text-model fields of a HuggingFace config (top-level or `text_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
    bos_token_id: Option<TokenIdType>,
    /// Raw value: number or array (see `HFConfig::eos_token_id`).
    eos_token_id: Option<serde_json::Value>,
    // Not read from the file: resolved by `from_json_file` from
    // generation_config.json / config.json and stored here.
    #[serde(default)]
    final_eos_token_ids: Vec<TokenIdType>,
    max_position_embeddings: Option<usize>,
    num_hidden_layers: Option<usize>,
    num_attention_heads: Option<usize>,
    vocab_size: Option<usize>,
}
impl HFConfig {
    /// Parse a `config.json` (as JSON5, since some published configs are
    /// not strict JSON) and resolve BOS/EOS token ids.
    ///
    /// EOS resolution order: sibling `generation_config.json`, then
    /// `config.json` top level, then its text config. Fails if none of
    /// them declares an `eos_token_id`.
    fn from_json_file<P: AsRef<Path>>(file: P) -> Result<Arc<dyn ModelInfo>> {
        let file_path = file.as_ref();
        let contents = std::fs::read_to_string(file_path)?;
        let mut config: Self = json_five::from_str(&contents)
            .inspect_err(|err| {
                tracing::error!(path=%file_path.display(), %err, "Failed to parse config.json as JSON5");
            })?;
        // Plain text models keep the text fields at the top level; re-parse
        // the same contents as a text config so both layouts are supported.
        if config.text_config.is_none() {
            let text_config: HFTextConfig = json_five::from_str(&contents)
                .inspect_err(|err| {
                    tracing::error!(path=%file_path.display(), %err, "Failed to parse text config from config.json as JSON5");
                })?;
            config.text_config = Some(text_config);
        }
        // NOTE(review): unreachable in practice — text_config was just set
        // above; kept as a defensive guard.
        let Some(text_config) = config.text_config.as_mut() else {
            anyhow::bail!(
                "Missing text config fields (model_type, eos_token_ids, etc) in config.json"
            );
        };
        // generation_config.json is looked up next to config.json.
        let gencfg_path = file_path
            .parent()
            .unwrap_or_else(|| Path::new(""))
            .join("generation_config.json");
        if text_config.bos_token_id.is_none() {
            text_config.bos_token_id =
                crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id").ok();
        }
        let final_eos_token_ids: Vec<TokenIdType> = {
            // Primary source: generation_config.json's eos_token_id, which
            // may be a single number or an array of numbers.
            crate::file_json_field::<serde_json::Value>(&gencfg_path, "eos_token_id")
                .inspect_err(
                    |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
                )
                .ok()
                .and_then(|v| {
                    if v.is_number() {
                        v.as_number()
                            .and_then(|n| n.as_u64())
                            .map(|n| vec![n as TokenIdType])
                    } else if v.is_array() {
                        let arr = v.as_array().unwrap();
                        Some(
                            arr.iter()
                                .filter_map(|inner_v| {
                                    inner_v
                                        .as_number()
                                        .and_then(|n| n.as_u64())
                                        .map(|n| n as TokenIdType)
                                })
                                .collect(),
                        )
                    } else {
                        None
                    }
                })
        }
        .or_else(|| {
            // Fallback: config.json's own eos_token_id (top level first,
            // then the text config), again number-or-array.
            config
                .eos_token_id
                .as_ref()
                .or(text_config.eos_token_id.as_ref())
                .and_then(|v| {
                    if v.is_number() {
                        v.as_number()
                            .and_then(|n| n.as_u64())
                            .map(|n| vec![n as TokenIdType])
                    } else {
                        serde_json::from_value(v.clone())
                            .map(Some)
                            .unwrap_or_else(|err| {
                                tracing::error!(
                                    ?v,
                                    path = %file_path.display(),
                                    "eos_token_id is not a number or an array, cannot deserialize: {err}",
                                );
                                None
                            })
                    }
                })
        })
        .ok_or_else(|| {
            anyhow::anyhow!(
                "missing eos_token_id in config.json and generation_config.json, cannot load"
            )
        })?;
        text_config.final_eos_token_ids = final_eos_token_ids;
        Ok(Arc::new(config))
    }
}
impl ModelInfo for HFConfig {
    fn model_type(&self) -> String {
        self.model_type.clone()
    }

    /// BOS id, when the text config declared one.
    fn bos_token_id(&self) -> Option<TokenIdType> {
        self.text_config.as_ref().and_then(|tc| tc.bos_token_id)
    }

    /// EOS ids resolved by `from_json_file`. Empty when no text config is
    /// present — previously this `unwrap()`ed and could panic on an
    /// `HFConfig` deserialized without text fields.
    fn eos_token_ids(&self) -> Vec<TokenIdType> {
        self.text_config
            .as_ref()
            .map(|tc| tc.final_eos_token_ids.clone())
            .unwrap_or_default()
    }

    /// Declared max sequence length; `None` when absent (no panic on a
    /// missing text config, consistent with `bos_token_id`).
    fn max_position_embeddings(&self) -> Option<usize> {
        self.text_config
            .as_ref()
            .and_then(|tc| tc.max_position_embeddings)
    }

    /// Declared vocabulary size; `None` when absent.
    fn vocab_size(&self) -> Option<usize> {
        self.text_config.as_ref().and_then(|tc| tc.vocab_size)
    }
}
impl ModelInfoType {
    /// Wrap `<directory>/config.json` as a checked file.
    pub fn from_disk(directory: &Path) -> Result<Self> {
        CheckedFile::from_disk(directory.join("config.json"))
            .map(Self::HfConfigJson)
            .with_context(|| {
                format!(
                    "unable to extract config.json from directory {}",
                    directory.display()
                )
            })
    }
}
impl GenerationConfig {
    /// Wrap `<directory>/generation_config.json` as a checked file.
    pub fn from_disk(directory: &Path) -> Result<Self> {
        CheckedFile::from_disk(directory.join("generation_config.json"))
            .map(Self::HfGenerationConfigJson)
            .with_context(|| {
                format!(
                    "unable to extract generation_config from directory {}",
                    directory.display()
                )
            })
    }
}
impl PromptFormatterArtifact {
    /// `<directory>/tokenizer_config.json` if present; a missing file is
    /// not an error.
    pub fn from_disk(directory: &Path) -> Result<Option<Self>> {
        let file = CheckedFile::from_disk(directory.join("tokenizer_config.json")).ok();
        Ok(file.map(Self::HfTokenizerConfigJson))
    }

    /// Repo-provided `<directory>/chat_template.jinja` if present; never
    /// marked custom. A missing file is not an error.
    pub fn chat_template_from_disk(directory: &Path) -> Result<Option<Self>> {
        let file = CheckedFile::from_disk(directory.join("chat_template.jinja")).ok();
        Ok(file.map(|f| Self::HfChatTemplate {
            file: f,
            is_custom: false,
        }))
    }
}
impl TokenizerKind {
    /// Locate a tokenizer in `directory`, trying in order:
    /// `tokenizer.json` (HF), `tiktoken.model`, then a unique `*.tiktoken`
    /// file. Errors when nothing is found or several `*.tiktoken` files
    /// make the choice ambiguous.
    pub fn from_disk(directory: &Path) -> Result<Self> {
        if let Ok(f) = CheckedFile::from_disk(directory.join("tokenizer.json")) {
            return Ok(Self::HfTokenizerJson(f));
        }
        if let Ok(f) = CheckedFile::from_disk(directory.join("tiktoken.model")) {
            return Ok(Self::TikTokenModel(f));
        }
        // Fall back to scanning for *.tiktoken; an unreadable directory
        // scans as empty rather than erroring, matching the lenient reads
        // above.
        let mut tiktoken_paths = Vec::new();
        if let Ok(entries) = std::fs::read_dir(directory) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_some_and(|e| e == "tiktoken") {
                    tiktoken_paths.push(path);
                }
            }
        }
        match tiktoken_paths.len() {
            1 => {
                if let Ok(f) = CheckedFile::from_disk(&tiktoken_paths[0]) {
                    return Ok(Self::TikTokenModel(f));
                }
                // Unreadable single candidate falls through to "not found".
            }
            n if n > 1 => {
                let names: Vec<String> = tiktoken_paths
                    .iter()
                    .map(|p| p.display().to_string())
                    .collect();
                anyhow::bail!(
                    "Multiple .tiktoken files found in {}: {:?}. Cannot determine which to use.",
                    directory.display(),
                    names
                );
            }
            _ => {}
        }
        anyhow::bail!(
            "No tokenizer.json or tiktoken model file found in {}",
            directory.display()
        )
    }
}
/// Ensure `path` exists and is a directory before treating it as a model
/// repo checkout.
fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
    let path = path.as_ref();
    anyhow::ensure!(
        path.exists(),
        "Model path does not exist: {}",
        path.display()
    );
    anyhow::ensure!(
        path.is_dir(),
        "Model path is not a directory: {}",
        path.display()
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::collections::HashSet;
    use std::path::Path;

    /// Llama-3.1 style config: BOS plus multiple EOS ids.
    #[test]
    pub fn test_config_json_llama3() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&config_file)?;
        assert_eq!(config.bos_token_id(), Some(128000));
        // Compare as a set: EOS id order is not significant.
        let eos_token_id_set: HashSet<_> = config.eos_token_ids().iter().cloned().collect();
        assert_eq!(eos_token_id_set, vec![128001, 128009].into_iter().collect());
        Ok(())
    }

    /// Llama-4 style config: text fields nested under `text_config`.
    #[test]
    pub fn test_config_json_llama4() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&config_file)?;
        assert_eq!(config.bos_token_id(), Some(200000));
        Ok(())
    }

    /// Config that is invalid strict JSON but parses as JSON5 (matching
    /// what Python's loader accepts).
    #[test]
    fn test_invalid_json_but_py_accepts_it() {
        dynamo_runtime::logging::init();
        let path = "tests/data/sample-models/NVIDIA-Nemotron-Nano-12B-v2-Base/config.json";
        let _ = HFConfig::from_json_file(path).unwrap();
    }
}