use serde::{Deserialize, Serialize};
use serde_json::Value;
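
/// Normalized view of a Hugging Face `config.json`: the core shape parameters
/// plus a derived per-token KV-cache footprint.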
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
pub vocab_size: u64,
pub hidden_size: u64,
pub num_layers: u64,
pub num_attention_heads: u64,
pub head_dim: u64,
pub num_key_value_heads: u64,
pub max_position_embeddings: u64,
pub torch_dtype: String,
pub activation_function: String,
pub layer_norm_epsilon: f64,
pub initializer_range: Option<f64>,
pub attention_dropout: Option<f64>,
pub embedding_dropout: Option<f64>,
pub residual_dropout: Option<f64>,
pub bytes_per_token: u64,
pub kib_per_token: f64,
pub model_type: String,
pub architectures: Vec<String>,
}
impl ModelConfig {
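    /// Builds a `ModelConfig` from a parsed Hugging Face `config.json`, trying
    /// the key aliases used across common architectures (GPT-2, Llama, BERT, T5, ...).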
pub fn from_hf_config(config: &Value) -> Result<Self, Box<dyn std::error::Error>> {
let model_type = detect_model_architecture(config);
let architectures = extract_architectures(config);
let vocab_size = get_config_value_u64(config, &["vocab_size", "vocabulary_size"])
.ok_or("vocab_size not found")?;
        let hidden_size =
            get_config_value_u64(config, &["hidden_size", "n_embd", "d_model", "model_dim"])
                .ok_or("hidden_size not found")?;
        let num_attention_heads = get_config_value_u64(
            config,
            &["num_attention_heads", "n_head", "num_heads", "attention_heads"],
        )
        .ok_or("num_attention_heads not found")?;
        let num_layers = get_config_value_u64(
            config,
            &["num_hidden_layers", "n_layer", "num_layers", "n_layers", "depth"],
        )
        .ok_or("num_layers not found")?;
        // Some configs set head_dim explicitly (it is not always hidden_size / heads,
        // e.g. for some Gemma variants); fall back to the quotient otherwise.
        let head_dim = get_config_value_u64(config, &["head_dim"])
            .unwrap_or(hidden_size / num_attention_heads);
        // A missing num_key_value_heads means plain multi-head attention, i.e. one
        // KV head per attention head; defaulting to 1 would wrongly assume MQA.
        let num_key_value_heads =
            get_config_value_u64(config, &["num_key_value_heads", "num_kv_heads"])
                .unwrap_or(num_attention_heads);
        let max_position_embeddings = get_config_value_u64(
            config,
            &["max_position_embeddings", "n_positions", "max_seq_length", "seq_length"],
        )
        .ok_or("max_position_embeddings not found")?;
let torch_dtype = get_config_value_str(config, &["torch_dtype", "dtype", "precision"])
.unwrap_or("float32")
.to_string();
        let activation_function = get_config_value_str(
            config,
            &["activation_function", "hidden_act", "feed_forward_proj", "activation"],
        )
        .unwrap_or("gelu")
        .to_string();
        let layer_norm_epsilon = get_config_value_f64(
            config,
            &["layer_norm_epsilon", "layer_norm_eps", "rms_norm_eps", "norm_epsilon"],
        )
        .unwrap_or(1e-5);
        let initializer_range =
            get_config_value_f64(config, &["initializer_range", "init_std", "weight_init_std"]);
        let attention_dropout = get_config_value_f64(
            config,
            &["attn_pdrop", "attention_dropout", "attention_probs_dropout_prob"],
        );
        let embedding_dropout =
            get_config_value_f64(config, &["embd_pdrop", "embed_dropout", "hidden_dropout_prob"]);
        let residual_dropout = get_config_value_f64(
            config,
            &["resid_pdrop", "residual_dropout", "hidden_dropout_prob"],
        );
        // Per-token KV-cache footprint: 2 tensors (K and V) per layer, each of
        // shape [num_key_value_heads, head_dim], at the model's dtype width.
        let size_of_dtype = get_bytes_per_element(&torch_dtype);
        let bytes_per_token = 2 * num_layers * num_key_value_heads * head_dim * size_of_dtype;
        // KiB is 1024 bytes; dividing by 1000 would yield kB, not KiB.
        let kib_per_token = bytes_per_token as f64 / 1024.0;
Ok(ModelConfig {
vocab_size,
hidden_size,
num_layers,
num_attention_heads,
head_dim,
num_key_value_heads,
max_position_embeddings,
torch_dtype,
activation_function,
layer_norm_epsilon,
initializer_range,
attention_dropout,
embedding_dropout,
residual_dropout,
bytes_per_token,
kib_per_token,
model_type,
architectures,
})
}
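
    /// Serializes this config as pretty-printed JSON to `path`.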
pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
let json = serde_json::to_string_pretty(self)?;
std::fs::write(path, json)?;
Ok(())
}
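
    /// Loads a previously saved config from a JSON file at `path`.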
pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let json = std::fs::read_to_string(path)?;
let config = serde_json::from_str(&json)?;
Ok(config)
}
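
    /// Prints a human-readable summary grouped by how essential each field is.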
pub fn print_summary(&self) {
println!("\n=== Model Configuration Summary ===");
println!("Model Type: {} {:?}", self.model_type, self.architectures);
println!("\n🎯 ESSENTIAL:");
println!(" Vocab Size: {}", self.vocab_size);
println!(" Hidden Size: {}", self.hidden_size);
println!(" Number of Layers: {}", self.num_layers);
println!(" Attention Heads: {}", self.num_attention_heads);
        println!(
            " Key-Value Heads: {} {}",
            self.num_key_value_heads,
            if self.num_key_value_heads == self.num_attention_heads {
                "(MHA, same as attn)"
            } else if self.num_key_value_heads == 1 {
                "(MQA)"
            } else {
                "(GQA)"
            }
        );
println!(" Head Dimension: {}", self.head_dim);
println!(
" Max Position Embeddings: {}",
self.max_position_embeddings
);
println!(" Torch Dtype: {}", self.torch_dtype);
println!("\n🔧 IMPORTANT:");
println!(" Activation Function: {}", self.activation_function);
println!("\n⚙️ OPTIONAL:");
println!(" Layer Norm Epsilon: {}", self.layer_norm_epsilon);
if let Some(range) = self.initializer_range {
println!(" Initializer Range: {}", range);
}
if let Some(dropout) = self.attention_dropout {
println!(" Attention Dropout: {}", dropout);
}
if let Some(dropout) = self.embedding_dropout {
println!(" Embedding Dropout: {}", dropout);
}
if let Some(dropout) = self.residual_dropout {
println!(" Residual Dropout: {}", dropout);
}
println!("\n💾 MEMORY USAGE:");
println!(" Size per Token: {:.2} KiB", self.kib_per_token);
println!(" Bytes per Token: {} bytes", self.bytes_per_token);
println!("=====================================\n");
}
}
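
// Typed lookup helpers: each returns the value of the first key in `keys` that
// exists in `config` with the requested JSON type.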
fn get_config_value_u64(config: &Value, keys: &[&str]) -> Option<u64> {
for key in keys {
if let Some(value) = config.get(key).and_then(|v| v.as_u64()) {
return Some(value);
}
}
None
}
fn get_config_value_f64(config: &Value, keys: &[&str]) -> Option<f64> {
for key in keys {
if let Some(value) = config.get(key).and_then(|v| v.as_f64()) {
return Some(value);
}
}
None
}
fn get_config_value_str<'a>(config: &'a Value, keys: &[&str]) -> Option<&'a str> {
for key in keys {
if let Some(value) = config.get(key).and_then(|v| v.as_str()) {
return Some(value);
}
}
None
}
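
/// Best-effort architecture detection: prefer `model_type`, then the first
/// `architectures` entry, then fall back to key-shape heuristics.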
fn detect_model_architecture(config: &Value) -> String {
if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
return model_type.to_string();
}
if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
return arch.to_lowercase();
}
}
if config.get("n_embd").is_some() && config.get("n_head").is_some() {
return "gpt2".to_string();
} else if config.get("hidden_size").is_some() && config.get("num_attention_heads").is_some() {
return "bert".to_string();
}
"unknown".to_string()
}
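
/// Returns the config's `architectures` list, or an empty vector if absent.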
fn extract_architectures(config: &Value) -> Vec<String> {
if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
return architectures
.iter()
.filter_map(|v| v.as_str())
.map(|s| s.to_string())
.collect();
}
vec![]
}

fn get_bytes_per_element(torch_dtype: &str) -> u64 {
    // Match known torch dtype names first; digit-scraping alone misbehaves on
    // names like "float8_e4m3fn" (digits "843" would parse as 105 bytes).
    match torch_dtype {
        "float64" | "double" | "int64" => 8,
        "float32" | "float" | "int32" => 4,
        "float16" | "half" | "bfloat16" => 2,
        "int8" | "uint8" | "float8_e4m3fn" | "float8_e5m2" => 1,
        other => {
            // Fall back to parsing a bit width out of the name (e.g. "fp16" -> 2),
            // clamping to at least one byte; assume float32 (4 bytes) otherwise.
            let bits: String = other.chars().filter(|c| c.is_ascii_digit()).collect();
            bits.parse::<u64>().map(|b| (b / 8).max(1)).unwrap_or(4)
        }
    }
}
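
// A minimal sketch of the parser on a GPT-2-style config. The field values
// mirror GPT-2 small's published hyperparameters; the assertions are
// illustrative, not a spec for any particular checkpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_gpt2_style_config() {
        let config = json!({
            "model_type": "gpt2",
            "vocab_size": 50257,
            "n_embd": 768,
            "n_head": 12,
            "n_layer": 12,
            "n_positions": 1024,
            "torch_dtype": "float16"
        });
        let parsed = ModelConfig::from_hf_config(&config).unwrap();
        assert_eq!(parsed.head_dim, 64);
        // No num_key_value_heads key: treated as plain MHA (12 KV heads).
        assert_eq!(parsed.num_key_value_heads, 12);
        // KV cache per token: 2 (K and V) * 12 layers * 12 heads * 64 dims * 2 bytes (fp16).
        assert_eq!(parsed.bytes_per_token, 36864);
        assert!((parsed.kib_per_token - 36.0).abs() < 1e-9);
    }
}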