use serde::Deserialize;
use crate::ir::{Architecture, DType, ModelConfig};
#[derive(Debug, Deserialize)]
pub struct HFConfig {
pub model_type: Option<String>,
pub architectures: Option<Vec<String>>,
pub hidden_size: Option<usize>,
pub intermediate_size: Option<usize>,
pub num_hidden_layers: Option<usize>,
pub num_attention_heads: Option<usize>,
pub num_key_value_heads: Option<usize>,
pub head_dim: Option<usize>,
pub vocab_size: Option<usize>,
pub max_position_embeddings: Option<usize>,
pub rms_norm_eps: Option<f64>,
pub layer_norm_eps: Option<f64>,
pub layer_norm_epsilon: Option<f64>,
pub rope_theta: Option<f64>,
pub torch_dtype: Option<String>,
pub sliding_window: Option<usize>,
}
impl HFConfig {
pub fn from_json(json: &[u8]) -> Result<Self, serde_json::Error> {
serde_json::from_slice(json)
}
pub fn detect_architecture(&self) -> Option<Architecture> {
if let Some(ref model_type) = self.model_type {
match model_type.as_str() {
"llama" => return Some(Architecture::Llama),
"qwen2" => return Some(Architecture::Qwen2),
"mistral" => return Some(Architecture::Mistral),
"phi3" | "phi" => return Some(Architecture::Phi3),
"gemma" | "gemma2" => return Some(Architecture::Gemma),
"stablelm" | "stablelm_epoch" => return Some(Architecture::StableLM),
_ => {}
}
}
if let Some(ref archs) = self.architectures {
for arch in archs {
if arch.contains("Llama") {
return Some(Architecture::Llama);
}
if arch.contains("Qwen2") {
return Some(Architecture::Qwen2);
}
if arch.contains("Mistral") {
return Some(Architecture::Mistral);
}
if arch.contains("Phi3") || arch.contains("Phi") {
return Some(Architecture::Phi3);
}
if arch.contains("Gemma") {
return Some(Architecture::Gemma);
}
if arch.contains("StableLm") || arch.contains("StableLM") {
return Some(Architecture::StableLM);
}
}
}
None
}
pub fn to_model_config(&self) -> Option<ModelConfig> {
let architecture = self.detect_architecture()?;
let hidden_size = self.hidden_size?;
let num_attention_heads = self.num_attention_heads?;
let head_dim = self.head_dim.unwrap_or(hidden_size / num_attention_heads);
let num_kv_heads = self.num_key_value_heads.unwrap_or(num_attention_heads);
let norm_eps = self
.rms_norm_eps
.or(self.layer_norm_eps)
.or(self.layer_norm_epsilon)
.unwrap_or(1e-5);
let dtype = self
.torch_dtype
.as_deref()
.map(parse_torch_dtype)
.unwrap_or(DType::F16);
let qkv_bias = matches!(architecture, Architecture::Qwen2);
let sliding_window_size = self.sliding_window;
let hidden_activation = match architecture {
Architecture::Gemma => crate::ir::HiddenActivation::GeluApprox,
_ => crate::ir::HiddenActivation::SiLU,
};
Some(ModelConfig {
architecture,
hidden_size,
intermediate_size: self.intermediate_size.unwrap_or(hidden_size * 4),
num_layers: self.num_hidden_layers?,
num_attention_heads,
num_kv_heads,
head_dim,
vocab_size: self.vocab_size.unwrap_or(32000),
max_seq_len: self.max_position_embeddings.unwrap_or(2048),
rms_norm_eps: norm_eps as f32,
rope_theta: self.rope_theta.unwrap_or(10000.0) as f32,
dtype,
sliding_window_size,
qkv_bias,
hidden_activation,
})
}
}
fn parse_torch_dtype(s: &str) -> DType {
match s {
"float32" | "fp32" => DType::F32,
"float16" | "fp16" => DType::F16,
"bfloat16" | "bf16" => DType::BF16,
_ => DType::F16,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_llama_config() {
let json = r#"{
"model_type": "llama",
"architectures": ["LlamaForCausalLM"],
"hidden_size": 2048,
"intermediate_size": 5632,
"num_hidden_layers": 16,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"vocab_size": 32000,
"max_position_embeddings": 2048,
"rms_norm_eps": 1e-5,
"rope_theta": 10000.0,
"torch_dtype": "float16"
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
assert_eq!(config.detect_architecture(), Some(Architecture::Llama));
let mc = config.to_model_config().unwrap();
assert_eq!(mc.hidden_size, 2048);
assert_eq!(mc.intermediate_size, 5632);
assert_eq!(mc.num_layers, 16);
assert_eq!(mc.num_attention_heads, 32);
assert_eq!(mc.num_kv_heads, 8);
assert_eq!(mc.head_dim, 64);
assert_eq!(mc.vocab_size, 32000);
assert_eq!(mc.dtype, DType::F16);
assert!(!mc.qkv_bias);
assert_eq!(mc.sliding_window_size, None);
}
#[test]
fn parse_qwen2_config() {
let json = r#"{
"model_type": "qwen2",
"architectures": ["Qwen2ForCausalLM"],
"hidden_size": 1536,
"intermediate_size": 8960,
"num_hidden_layers": 28,
"num_attention_heads": 12,
"num_key_value_heads": 2,
"vocab_size": 151936,
"max_position_embeddings": 32768,
"rms_norm_eps": 1e-6,
"rope_theta": 1000000.0,
"torch_dtype": "bfloat16"
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
assert_eq!(config.detect_architecture(), Some(Architecture::Qwen2));
let mc = config.to_model_config().unwrap();
assert_eq!(mc.hidden_size, 1536);
assert_eq!(mc.num_kv_heads, 2);
assert_eq!(mc.dtype, DType::BF16);
assert_eq!(mc.head_dim, 128);
assert!(mc.qkv_bias);
assert_eq!(mc.sliding_window_size, None);
}
#[test]
fn parse_mistral_config_with_sliding_window() {
let json = r#"{
"model_type": "mistral",
"architectures": ["MistralForCausalLM"],
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"vocab_size": 32000,
"max_position_embeddings": 4096,
"rms_norm_eps": 1e-5,
"rope_theta": 10000.0,
"sliding_window": 4096,
"torch_dtype": "float16"
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
assert_eq!(config.detect_architecture(), Some(Architecture::Mistral));
let mc = config.to_model_config().unwrap();
assert_eq!(mc.architecture, Architecture::Mistral);
assert_eq!(mc.sliding_window_size, Some(4096));
assert!(!mc.qkv_bias);
}
#[test]
fn parse_smollm_config() {
let json = r#"{
"model_type": "llama",
"architectures": ["LlamaForCausalLM"],
"hidden_size": 576,
"intermediate_size": 1536,
"num_hidden_layers": 30,
"num_attention_heads": 9,
"num_key_value_heads": 3,
"vocab_size": 49152,
"max_position_embeddings": 2048,
"rms_norm_eps": 1e-5,
"rope_theta": 10000.0,
"torch_dtype": "bfloat16"
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
let mc = config.to_model_config().unwrap();
assert_eq!(mc.architecture, Architecture::Llama);
assert_eq!(mc.hidden_size, 576);
assert_eq!(mc.head_dim, 64);
}
#[test]
fn detect_architecture_from_architectures_field() {
let json = r#"{
"architectures": ["MistralForCausalLM"],
"hidden_size": 4096,
"num_hidden_layers": 32,
"num_attention_heads": 32
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
assert_eq!(config.detect_architecture(), Some(Architecture::Mistral));
}
#[test]
fn defaults_for_missing_fields() {
let json = r#"{
"model_type": "llama",
"hidden_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 8
}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
let mc = config.to_model_config().unwrap();
assert_eq!(mc.num_kv_heads, 8);
assert_eq!(mc.intermediate_size, 2048);
assert_eq!(mc.vocab_size, 32000);
assert_eq!(mc.max_seq_len, 2048);
assert_eq!(mc.dtype, DType::F16);
}
#[test]
fn unknown_architecture_returns_none() {
let json = r#"{"model_type": "unknown_arch"}"#;
let config = HFConfig::from_json(json.as_bytes()).unwrap();
assert_eq!(config.detect_architecture(), None);
}
}