use std::path::Path;
use crate::model::{
ActivationType, Architecture, ModelConfig, RopeConfig, RopeScalingType, RopeType,
};
use super::error::{ModelError, ModelResult};
/// Raw fields of a Hugging Face `config.json`, as deserialized by `serde`.
///
/// Every hyperparameter is optional at this stage; required values and fallback defaults are
/// applied later by `HfConfig::to_model_config`. `HfConfig::from_json` may additionally fill
/// some fields from alternative key names after deserialization.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct HfConfig {
/// `model_type` string; mapped to an `Architecture` by `HfConfig::architecture`.
pub model_type: Option<String>,
/// Vocabulary size; `to_model_config` falls back to 32000 when absent.
pub vocab_size: Option<usize>,
/// Model width; required by `to_model_config`.
pub hidden_size: Option<usize>,
/// FFN width; `to_model_config` falls back to `hidden_size * 4 * 2 / 3` when absent.
pub intermediate_size: Option<usize>,
/// Number of transformer layers; required by `to_model_config`.
pub num_hidden_layers: Option<usize>,
/// Number of query heads; required by `to_model_config`.
pub num_attention_heads: Option<usize>,
/// Number of key/value heads; `to_model_config` falls back to `num_attention_heads`.
pub num_key_value_heads: Option<usize>,
/// Maximum sequence length; `to_model_config` falls back to 2048 when absent.
pub max_position_embeddings: Option<usize>,
/// RMSNorm epsilon; `to_model_config` falls back to 1e-5 when absent.
pub rms_norm_eps: Option<f32>,
/// RoPE base frequency; `to_model_config` falls back to 10000.0. `from_json` overwrites it with
/// `rope_parameters.full_attention.rope_theta` when that key is present and `rope_theta_swa`
/// was not given directly.
pub rope_theta: Option<f32>,
/// Optional `rope_scaling` object; only its `factor` is read by `to_model_config`.
pub rope_scaling: Option<RopeScalingConfig>,
/// Whether the output projection is tied to the embeddings; `to_model_config` uses `false`
/// when absent. When the file has a `text_config`, the top-level value takes precedence.
pub tie_word_embeddings: Option<bool>,
/// Activation function name; consulted only when `hidden_activation` is absent.
pub hidden_act: Option<String>,
/// Whether attention projections carry bias terms; `false` when the key is missing.
#[serde(default)]
pub attention_bias: bool,
/// Whether MLP projections carry bias terms; `false` when the key is missing.
#[serde(default)]
pub mlp_bias: bool,
/// Per-head dimension; `to_model_config` falls back to `hidden_size / num_attention_heads`.
/// When `from_json` finds `global_head_dim`, that value replaces this one and the previous
/// value moves to `head_dim_swa`.
pub head_dim: Option<usize>,
/// Beginning-of-sequence token id (not read by `to_model_config`).
pub bos_token_id: Option<u32>,
/// End-of-sequence token id(s), kept as raw JSON — presumably because configs may give either
/// a single id or a list; confirm with the consumer.
pub eos_token_id: Option<serde_json::Value>,
/// Per-layer flags, `true` for sliding-window attention layers. When absent, `from_json`
/// derives it from `layer_types` (entries equal to `"sliding_attention"`).
#[serde(default)]
pub sliding_window_pattern: Option<Vec<bool>>,
/// Number of layers that reuse another layer's KV cache; also read from `num_kv_shared_layers`.
/// When non-zero, `to_model_config` builds a KV-source layer mapping from it.
#[serde(default)]
pub num_key_value_shared_layers: Option<usize>,
/// Per-layer embedding width; also read from `hidden_size_per_layer_input`. Not used by
/// `to_model_config`.
#[serde(default)]
pub n_embd_per_layer: Option<usize>,
/// Head dimension of sliding-window layers; `to_model_config` falls back to `head_dim`.
#[serde(default)]
pub head_dim_swa: Option<usize>,
/// RoPE base for sliding-window layers; also read from
/// `rope_parameters.sliding_attention.rope_theta`. Falls back to the global RoPE base.
#[serde(default)]
pub rope_theta_swa: Option<f32>,
/// Sliding-window size; `to_model_config` uses 0 when absent.
#[serde(default)]
pub sliding_window: Option<usize>,
/// Soft-cap applied to final logits; `to_model_config` uses 0.0 when absent.
#[serde(default)]
pub final_logit_softcapping: Option<f32>,
/// Fraction of the global head dimension that receives RoPE (1.0 when absent); also read from
/// `rope_parameters.full_attention.partial_rotary_factor`.
#[serde(default)]
pub partial_rotary_factor: Option<f32>,
/// Activation function name; takes precedence over `hidden_act`.
#[serde(default)]
pub hidden_activation: Option<String>,
}
/// The `rope_scaling` object of a Hugging Face `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScalingConfig {
/// Scaling method name, read from the JSON key `"type"`. Not consulted by
/// `HfConfig::to_model_config`, which always sets `RopeScalingType::None`.
#[serde(rename = "type")]
pub scaling_type: Option<String>,
/// Scaling factor; `HfConfig::to_model_config` passes it through as `RopeConfig::freq_scale`
/// (1.0 when absent).
pub factor: Option<f32>,
}
impl HfConfig {
pub fn from_file<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
let path = path.as_ref();
let data = std::fs::read_to_string(path)
.map_err(|e| ModelError::ConfigError(format!("{}: {}", path.display(), e)))?;
Self::from_json(&data)
}
pub fn from_json(json: &str) -> ModelResult<Self> {
let raw: serde_json::Value = serde_json::from_str(json)
.map_err(|e| ModelError::ConfigError(format!("Failed to parse config.json: {e}")))?;
let config_value = if let Some(text_config) = raw.get("text_config") {
let mut merged = text_config.clone();
if let Some(mt) = raw.get("model_type") {
merged.as_object_mut().unwrap().insert("model_type".into(), mt.clone());
}
if let Some(tie) = raw.get("tie_word_embeddings") {
merged.as_object_mut().unwrap().insert("tie_word_embeddings".into(), tie.clone());
}
merged
} else {
raw.clone()
};
let mut config: HfConfig = serde_json::from_value(config_value.clone())
.map_err(|e| ModelError::ConfigError(format!("Failed to parse config.json: {e}")))?;
if let Some(obj) = config_value.as_object() {
if config.sliding_window_pattern.is_none() {
if let Some(layer_types) = obj.get("layer_types").and_then(|v| v.as_array()) {
config.sliding_window_pattern = Some(
layer_types.iter()
.map(|v| v.as_str() == Some("sliding_attention"))
.collect()
);
}
}
if config.num_key_value_shared_layers.is_none() {
if let Some(v) = obj.get("num_kv_shared_layers").and_then(|v| v.as_u64()) {
config.num_key_value_shared_layers = Some(v as usize);
}
}
if config.n_embd_per_layer.is_none() {
if let Some(v) = obj.get("hidden_size_per_layer_input").and_then(|v| v.as_u64()) {
config.n_embd_per_layer = Some(v as usize);
}
}
if config.head_dim_swa.is_none() {
if let Some(global_hd) = obj.get("global_head_dim").and_then(|v| v.as_u64()) {
config.head_dim_swa = config.head_dim;
config.head_dim = Some(global_hd as usize);
}
}
if config.rope_theta_swa.is_none() {
if let Some(rp) = obj.get("rope_parameters").and_then(|v| v.as_object()) {
if let Some(swa) = rp.get("sliding_attention").and_then(|v| v.as_object()) {
if let Some(theta) = swa.get("rope_theta").and_then(|v| v.as_f64()) {
config.rope_theta_swa = Some(theta as f32);
}
}
if let Some(full) = rp.get("full_attention").and_then(|v| v.as_object()) {
if let Some(theta) = full.get("rope_theta").and_then(|v| v.as_f64()) {
config.rope_theta = Some(theta as f32);
}
if config.partial_rotary_factor.is_none() {
if let Some(prf) = full.get("partial_rotary_factor").and_then(|v| v.as_f64()) {
config.partial_rotary_factor = Some(prf as f32);
}
}
}
}
}
}
Ok(config)
}
pub fn architecture(&self) -> Architecture {
match self.model_type.as_deref() {
Some("llama") => Architecture::Llama,
Some("mistral") => Architecture::Mistral,
Some("qwen2") => Architecture::Qwen2,
Some("codellama") => Architecture::CodeLlama,
Some("yi") => Architecture::Yi,
Some("deepseek") | Some("deepseek_v2") => Architecture::DeepSeek,
Some("mixtral") => Architecture::Mixtral,
Some("gemma4") | Some("gemma4_text") => Architecture::Gemma4,
Some("gemma2") | Some("gemma2_text") => Architecture::Gemma2,
Some("gemma") => Architecture::Gemma,
_ => Architecture::Unknown,
}
}
pub fn to_model_config(&self) -> ModelResult<ModelConfig> {
let hidden_size = self
.hidden_size
.ok_or_else(|| ModelError::ConfigError("missing hidden_size in config.json".into()))?;
let num_heads = self.num_attention_heads.ok_or_else(|| {
ModelError::ConfigError("missing num_attention_heads in config.json".into())
})?;
let num_layers = self.num_hidden_layers.ok_or_else(|| {
ModelError::ConfigError("missing num_hidden_layers in config.json".into())
})?;
let num_kv_heads = self.num_key_value_heads.unwrap_or(num_heads);
let head_dim = self.head_dim.unwrap_or(hidden_size / num_heads);
let intermediate_size = self.intermediate_size.unwrap_or(hidden_size * 4 * 2 / 3);
let max_seq_len = self.max_position_embeddings.unwrap_or(2048);
let norm_eps = self.rms_norm_eps.unwrap_or(1e-5);
let vocab_size = self.vocab_size.unwrap_or(32000);
let architecture = self.architecture();
let rope_type = match architecture {
Architecture::Qwen2 => RopeType::NeoX,
_ => RopeType::Normal,
};
let freq_base = self.rope_theta.unwrap_or(10000.0);
let freq_scale = self
.rope_scaling
.as_ref()
.and_then(|s| s.factor)
.unwrap_or(1.0);
let rope_config = RopeConfig {
freq_base,
freq_scale,
n_dims: head_dim,
scaling_type: RopeScalingType::None,
original_max_position_embeddings: max_seq_len,
rope_type,
mrope_sections: None,
};
let act_str = self.hidden_activation.as_deref()
.or(self.hidden_act.as_deref());
let hidden_act = match act_str {
Some("gelu") | Some("gelu_new") | Some("gelu_pytorch_tanh") => ActivationType::GELU,
_ => ActivationType::SiLU,
};
let uses_gelu = matches!(hidden_act, ActivationType::GELU);
let sliding_window = self.sliding_window.unwrap_or(0);
let mut config = ModelConfig {
vocab_size,
hidden_size,
intermediate_size,
num_layers,
num_heads,
num_kv_heads,
head_dim,
max_seq_len,
norm_eps,
rope_config,
use_parallel_residual: false,
hidden_act,
attention_bias: self.attention_bias,
mlp_bias: self.mlp_bias,
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
num_experts: 0,
num_experts_per_token: 0,
expert_intermediate_size: 0,
key_length: head_dim,
value_length: head_dim,
ssm_d_inner: 0,
ssm_d_state: 0,
ssm_n_group: 0,
ssm_dt_rank: 0,
ssm_conv_kernel: 0,
attn_logit_softcap: 0.0,
final_logit_softcap: self.final_logit_softcapping.unwrap_or(0.0),
sliding_window,
has_combined_qkv: false,
uses_layer_norm: false,
uses_gelu,
has_ffn_gate: true,
attention_layer_configs: None,
kv_source_layer: None,
};
if let Some(ref pattern) = self.sliding_window_pattern {
if architecture.has_heterogeneous_attention() {
let swa_head_dim = self.head_dim_swa.unwrap_or(config.head_dim);
let swa_kv_heads = config.num_kv_heads;
let swa_rope_freq_base = self.rope_theta_swa.unwrap_or(config.rope_config.freq_base);
let swa_rope_dims = swa_head_dim;
let global_head_dim = config.head_dim;
let global_kv_heads = config.num_kv_heads;
let global_rope_freq_base = config.rope_config.freq_base;
let prf = self.partial_rotary_factor.unwrap_or(1.0);
let global_rope_dims = (global_head_dim as f32 * prf) as usize;
config.attention_layer_configs =
Some(ModelConfig::build_attention_layer_configs_from_pattern(
pattern,
swa_head_dim,
swa_kv_heads,
swa_rope_freq_base,
swa_rope_dims,
sliding_window,
global_head_dim,
global_kv_heads,
global_rope_freq_base,
global_rope_dims,
));
if let Some(shared_layers) = self.num_key_value_shared_layers {
if shared_layers > 0 {
config.kv_source_layer = Some(ModelConfig::build_kv_source_mapping(
config.num_layers,
shared_layers,
config.attention_layer_configs.as_ref().unwrap(),
));
}
}
}
}
Ok(config)
}
}