use axonml_tensor::Tensor;
use std::collections::HashMap;
use crate::error::LLMResult;
/// Modules that can load their parameters from a flat name→tensor map
/// (a "state dict") keyed by AxonML-style parameter names.
pub trait LoadStateDict {
    /// Load parameters from `state_dict` into `self`, returning a
    /// [`LoadResult`] describing which keys were loaded, missing, or
    /// unexpected.
    ///
    /// `strict` is forwarded to the implementation; presumably it controls
    /// whether key mismatches are fatal — confirm against implementors.
    fn load_state_dict(
        &mut self,
        state_dict: &HashMap<String, Tensor<f32>>,
        strict: bool,
    ) -> LLMResult<LoadResult>;

    /// Names of all parameters this module expects to find in a state dict.
    fn state_dict_keys(&self) -> Vec<String>;
}
/// Outcome of a state-dict load: how many parameters matched, which keys the
/// model expected but did not find, and which keys were present but unused.
#[derive(Debug, Default, Clone)]
pub struct LoadResult {
    /// Keys present in the state dict but not expected by the model.
    pub unexpected_keys: Vec<String>,
    /// Keys the model expects that were absent from the state dict.
    pub missing_keys: Vec<String>,
    /// Number of parameters successfully loaded.
    pub loaded_count: usize,
}

impl LoadResult {
    /// Whether the load should be considered successful.
    ///
    /// Non-strict loads always succeed. A strict load succeeds only when
    /// there are no missing *and* no unexpected keys, matching conventional
    /// `strict=True` checkpoint-loading semantics where extra keys are as
    /// much an error as absent ones. (Previously, unexpected keys were
    /// silently ignored even under `strict`.)
    pub fn is_success(&self, strict: bool) -> bool {
        !strict || (self.missing_keys.is_empty() && self.unexpected_keys.is_empty())
    }

    /// Print a human-readable summary of the load to stdout: the loaded
    /// count, then any missing and unexpected keys, one per line.
    pub fn print_summary(&self) {
        println!("Loaded {} parameters", self.loaded_count);
        if !self.missing_keys.is_empty() {
            println!("Missing keys ({}):", self.missing_keys.len());
            for key in &self.missing_keys {
                println!(" - {}", key);
            }
        }
        if !self.unexpected_keys.is_empty() {
            println!("Unexpected keys ({}):", self.unexpected_keys.len());
            for key in &self.unexpected_keys {
                println!(" - {}", key);
            }
        }
    }
}
/// Translate a Hugging Face parameter name into the AxonML naming scheme.
///
/// A leading `model.` or `transformer.` prefix is removed first, then the
/// architecture-specific renames are applied. Names for unrecognized
/// architectures pass through with only the prefix stripped.
pub fn map_hf_to_axonml(hf_name: &str, arch: &str) -> String {
    let stripped = hf_name
        .strip_prefix("model.")
        .or_else(|| hf_name.strip_prefix("transformer."))
        .unwrap_or(hf_name);
    if matches!(arch, "llama" | "mistral") {
        map_llama_weights(stripped)
    } else if arch == "phi" {
        map_phi_weights(stripped)
    } else {
        stripped.to_string()
    }
}
/// Apply the HF→AxonML substring renames used by LLaMA-family checkpoints.
fn map_llama_weights(name: &str) -> String {
    // Table-driven so the forward and reverse mappings are easy to compare.
    const RENAMES: [(&str, &str); 3] = [
        ("self_attn.", "attention."),
        ("input_layernorm", "input_norm"),
        ("post_attention_layernorm", "post_attn_norm"),
    ];
    RENAMES
        .iter()
        .fold(name.to_string(), |acc, (from, to)| acc.replace(from, to))
}
/// Apply the HF→AxonML substring renames used by Phi checkpoints.
fn map_phi_weights(name: &str) -> String {
    const RENAMES: [(&str, &str); 3] = [
        ("self_attn.", "attention."),
        ("fc1", "up_proj"),
        ("fc2", "down_proj"),
    ];
    RENAMES
        .iter()
        .fold(name.to_string(), |acc, (from, to)| acc.replace(from, to))
}
/// Translate an AxonML parameter name back into Hugging Face convention.
///
/// Inverse of [`map_hf_to_axonml`] for the known architectures: the
/// substring renames are undone and a `model.` prefix is prepended. Names
/// for unrecognized architectures are returned unchanged (no prefix).
pub fn map_axonml_to_hf(axonml_name: &str, arch: &str) -> String {
    // Undo each (axonml, hf) rename pair, then restore the HF prefix.
    let restore = |pairs: &[(&str, &str)]| {
        let renamed = pairs
            .iter()
            .fold(axonml_name.to_string(), |acc, (from, to)| acc.replace(from, to));
        format!("model.{}", renamed)
    };
    match arch {
        "llama" | "mistral" => restore(&[
            ("attention.", "self_attn."),
            ("input_norm", "input_layernorm"),
            ("post_attn_norm", "post_attention_layernorm"),
        ]),
        "phi" => restore(&[
            ("attention.", "self_attn."),
            ("up_proj", "fc1"),
            ("down_proj", "fc2"),
        ]),
        _ => axonml_name.to_string(),
    }
}
/// Remap HF-named `weights` to AxonML names for `arch`, then load them into
/// `model` via [`LoadStateDict::load_state_dict`].
///
/// NOTE(review): if two HF keys map to the same AxonML key, the later insert
/// silently wins — confirm that cannot happen for the supported archs.
pub fn load_with_mapping<M: LoadStateDict>(
    model: &mut M,
    weights: &HashMap<String, Tensor<f32>>,
    arch: &str,
    strict: bool,
) -> LLMResult<LoadResult> {
    // Pre-size the target map; tensor clone cost is unavoidable here since
    // the mapped dict must own its values.
    let mut mapped: HashMap<String, Tensor<f32>> = HashMap::with_capacity(weights.len());
    for (hf_key, tensor) in weights {
        mapped.insert(map_hf_to_axonml(hf_key, arch), tensor.clone());
    }
    model.load_state_dict(&mapped, strict)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// HF→AxonML mapping strips the `model.` prefix and renames attention.
    #[test]
    fn test_llama_mapping() {
        let hf_name = "model.layers.0.self_attn.q_proj.weight";
        let expected = "layers.0.attention.q_proj.weight";
        assert_eq!(map_hf_to_axonml(hf_name, "llama"), expected);
    }

    /// Strict success requires no missing keys; lenient always succeeds.
    #[test]
    fn test_load_result() {
        let mut result = LoadResult {
            loaded_count: 10,
            ..Default::default()
        };
        assert!(result.is_success(true));

        result.missing_keys.push("test".to_string());
        assert!(!result.is_success(true));
        assert!(result.is_success(false));
    }
}