use std::path::Path;
use crate::model::{
ActivationType, Architecture, ModelConfig, RopeConfig, RopeScalingType, RopeType,
};
use super::error::{OnnxError, OnnxResult};
/// Deserialized subset of a HuggingFace `config.json`.
///
/// Every field is optional because different model families omit different
/// keys; fallback defaults are applied in `to_model_config`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct HfConfig {
/// Model family string (e.g. "llama", "qwen2"); drives `architecture()`.
pub model_type: Option<String>,
/// Token vocabulary size.
pub vocab_size: Option<usize>,
/// Transformer embedding width.
pub hidden_size: Option<usize>,
/// Feed-forward inner dimension.
pub intermediate_size: Option<usize>,
/// Number of transformer layers.
pub num_hidden_layers: Option<usize>,
/// Number of attention (query) heads.
pub num_attention_heads: Option<usize>,
/// Number of key/value heads; absent for plain MHA models.
pub num_key_value_heads: Option<usize>,
/// Maximum context length the model was trained with.
pub max_position_embeddings: Option<usize>,
/// Epsilon for RMS normalization.
pub rms_norm_eps: Option<f32>,
/// RoPE frequency base (theta).
pub rope_theta: Option<f32>,
/// Optional RoPE context-extension settings (`rope_scaling` object).
pub rope_scaling: Option<RopeScalingConfig>,
/// Whether input embeddings and the LM head share weights.
pub tie_word_embeddings: Option<bool>,
/// Activation function name for the MLP (e.g. "silu", "gelu").
pub hidden_act: Option<String>,
// Missing key deserializes as `false` rather than failing.
#[serde(default)]
pub attention_bias: bool,
// Missing key deserializes as `false` rather than failing.
#[serde(default)]
pub mlp_bias: bool,
/// Per-head dimension; derived from hidden_size/num_heads when absent.
pub head_dim: Option<usize>,
/// Beginning-of-sequence token id.
pub bos_token_id: Option<u32>,
// Kept as raw JSON: presumably some configs store a single id and others
// a list of ids — TODO confirm against the callers that consume this.
pub eos_token_id: Option<serde_json::Value>,
}
/// The optional `rope_scaling` object inside `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScalingConfig {
// JSON key is "type" (a Rust keyword), hence the rename.
#[serde(rename = "type")]
pub scaling_type: Option<String>,
/// Context-extension factor applied to RoPE frequencies.
pub factor: Option<f32>,
}
impl HfConfig {
    /// Reads and parses a HuggingFace `config.json` from `path`.
    ///
    /// # Errors
    /// * [`OnnxError::MissingConfig`] if the file cannot be read.
    /// * [`OnnxError::ConfigParse`] if the contents are not valid JSON.
    pub fn from_file<P: AsRef<Path>>(path: P) -> OnnxResult<Self> {
        let path = path.as_ref();
        let data = std::fs::read_to_string(path)
            .map_err(|e| OnnxError::MissingConfig(format!("{}: {}", path.display(), e)))?;
        Self::from_json(&data)
    }

    /// Parses a HuggingFace config from a JSON string.
    ///
    /// # Errors
    /// Returns [`OnnxError::ConfigParse`] when `json` is not valid JSON for
    /// this config shape.
    pub fn from_json(json: &str) -> OnnxResult<Self> {
        serde_json::from_str(json)
            .map_err(|e| OnnxError::ConfigParse(format!("Failed to parse config.json: {}", e)))
    }

    /// Maps the `model_type` string to a known [`Architecture`], falling back
    /// to [`Architecture::Unknown`] for unrecognized or absent values.
    pub fn architecture(&self) -> Architecture {
        match self.model_type.as_deref() {
            Some("llama") => Architecture::Llama,
            Some("mistral") => Architecture::Mistral,
            Some("qwen2") => Architecture::Qwen2,
            Some("codellama") => Architecture::CodeLlama,
            Some("yi") => Architecture::Yi,
            Some("deepseek") | Some("deepseek_v2") => Architecture::DeepSeek,
            Some("mixtral") => Architecture::Mixtral,
            _ => Architecture::Unknown,
        }
    }

    /// Converts the raw HF config into the runtime [`ModelConfig`], filling
    /// absent optional fields with Llama-style defaults.
    ///
    /// # Errors
    /// Returns [`OnnxError::ConfigParse`] when a required field
    /// (`hidden_size`, `num_attention_heads`, `num_hidden_layers`) is
    /// missing, or when `num_attention_heads` is zero.
    pub fn to_model_config(&self) -> OnnxResult<ModelConfig> {
        let hidden_size = self
            .hidden_size
            .ok_or_else(|| OnnxError::ConfigParse("missing hidden_size in config.json".into()))?;
        let num_heads = self.num_attention_heads.ok_or_else(|| {
            OnnxError::ConfigParse("missing num_attention_heads in config.json".into())
        })?;
        // Guard against a zero head count: `hidden_size / num_heads` below
        // would otherwise panic with an integer divide-by-zero instead of
        // surfacing a recoverable config error.
        if num_heads == 0 {
            return Err(OnnxError::ConfigParse(
                "num_attention_heads must be non-zero in config.json".into(),
            ));
        }
        let num_layers = self.num_hidden_layers.ok_or_else(|| {
            OnnxError::ConfigParse("missing num_hidden_layers in config.json".into())
        })?;
        // Plain MHA configs omit num_key_value_heads; GQA/MQA set it lower.
        let num_kv_heads = self.num_key_value_heads.unwrap_or(num_heads);
        let head_dim = self.head_dim.unwrap_or(hidden_size / num_heads);
        // Llama-style SwiGLU sizing heuristic (~8/3 * hidden) when the
        // config does not state intermediate_size explicitly.
        let intermediate_size = self.intermediate_size.unwrap_or(hidden_size * 4 * 2 / 3);
        let max_seq_len = self.max_position_embeddings.unwrap_or(2048);
        let norm_eps = self.rms_norm_eps.unwrap_or(1e-5);
        let vocab_size = self.vocab_size.unwrap_or(32000);
        let architecture = self.architecture();
        // Qwen2 gets the NeoX RoPE layout; all other recognized families use
        // the normal layout.
        let rope_type = match architecture {
            Architecture::Qwen2 => RopeType::NeoX,
            _ => RopeType::Normal,
        };
        let freq_base = self.rope_theta.unwrap_or(10000.0);
        let freq_scale = self
            .rope_scaling
            .as_ref()
            .and_then(|s| s.factor)
            .unwrap_or(1.0);
        // NOTE(review): rope_scaling.scaling_type ("linear"/"dynamic"/...)
        // is deserialized but never consulted — scaling_type below is
        // hard-coded to None even when the config requests scaling. Confirm
        // whether RopeScalingType should be derived from that string.
        let rope_config = RopeConfig {
            freq_base,
            freq_scale,
            n_dims: head_dim,
            scaling_type: RopeScalingType::None,
            original_max_position_embeddings: max_seq_len,
            rope_type,
        };
        // Only explicit GELU variants override the SiLU default.
        let hidden_act = match self.hidden_act.as_deref() {
            Some("gelu") | Some("gelu_new") => ActivationType::GELU,
            _ => ActivationType::SiLU,
        };
        Ok(ModelConfig {
            vocab_size,
            hidden_size,
            intermediate_size,
            num_layers,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
            norm_eps,
            rope_config,
            use_parallel_residual: false,
            hidden_act,
            attention_bias: self.attention_bias,
            mlp_bias: self.mlp_bias,
            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
            // MoE, SSM, logit-softcap, and sliding-window features are not
            // expressed by this config subset; zeroed values mean "unused".
            num_experts: 0,
            num_experts_per_token: 0,
            expert_intermediate_size: 0,
            key_length: head_dim,
            value_length: head_dim,
            ssm_d_inner: 0,
            ssm_d_state: 0,
            ssm_n_group: 0,
            ssm_dt_rank: 0,
            ssm_conv_kernel: 0,
            attn_logit_softcap: 0.0,
            final_logit_softcap: 0.0,
            sliding_window: 0,
            has_combined_qkv: false,
            uses_layer_norm: false,
            uses_gelu: false,
            has_ffn_gate: true,
        })
    }
}