mod convert;
mod detect;
pub(crate) mod mapping;
#[cfg(test)]
mod tests;
use crate::error::{Error, Result};
use crate::Tensor;
use std::collections::HashMap;
use std::path::Path;
pub(crate) use convert::tensor_to_f32_vec;
pub(crate) use detect::{
detect_architecture, find_safetensors_files, parse_checkpoint_step_from_path,
};
pub(crate) use mapping::map_weight_name;
/// Model architectures whose checkpoints this loader understands.
///
/// Passed to `load_safetensors_weights`; `Auto` defers the choice to
/// `detect_architecture`, which inspects tensor names in the first
/// SafeTensors file read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Architecture {
    /// Llama-family checkpoints.
    Llama,
    /// Qwen2-family checkpoints.
    Qwen2,
    /// Mistral-family checkpoints.
    Mistral,
    /// RoBERTa checkpoints.
    RoBERTa,
    /// Weights sourced from a GGUF file.
    Gguf,
    /// Not specified; detect from tensor names at load time.
    Auto,
}
/// Loads model weights from the SafeTensors files under `model_path`.
///
/// Every tensor found is converted to `f32`, its name is mapped to the
/// canonical naming scheme for `arch` (auto-detected from the first file
/// when `Architecture::Auto`), and the resulting name → `Tensor` map is
/// returned. Progress and skipped tensors are reported on stderr.
///
/// # Errors
/// Returns `Error::ConfigError` when no SafeTensors files are found, a
/// file cannot be read, or a file's header cannot be parsed.
pub fn load_safetensors_weights(
    model_path: &Path,
    arch: Architecture,
) -> Result<HashMap<String, Tensor>> {
    use safetensors::SafeTensors;
    let st_files = find_safetensors_files(model_path)?;
    if st_files.is_empty() {
        return Err(Error::ConfigError(format!(
            "No SafeTensors files found in {}",
            model_path.display()
        )));
    }
    let mut weights = HashMap::new();
    let mut detected_arch = arch;
    for st_path in &st_files {
        let data = std::fs::read(st_path).map_err(|e| {
            Error::ConfigError(format!("Failed to read {}: {}", st_path.display(), e))
        })?;
        let tensors = SafeTensors::deserialize(&data).map_err(|e| {
            Error::ConfigError(format!("Failed to parse SafeTensors {}: {}", st_path.display(), e))
        })?;
        // Detect once, from the first shard; later shards reuse the result.
        if detected_arch == Architecture::Auto {
            detected_arch = detect_architecture(&tensors);
            eprintln!(" Detected architecture: {detected_arch:?}");
        }
        for name in tensors.names() {
            match tensors.tensor(name) {
                Ok(tensor_view) => {
                    if let Some(values) = tensor_to_f32_vec(&tensor_view) {
                        let mapped_name = map_weight_name(name, detected_arch);
                        weights.insert(mapped_name, Tensor::from_vec(values, true));
                    } else {
                        // Previously dropped silently, which made later
                        // "Missing <weight>" validation errors inexplicable.
                        eprintln!(" Warning: skipping tensor {name}: unconvertible dtype");
                    }
                }
                Err(e) => {
                    eprintln!(" Warning: skipping tensor {name}: {e}");
                }
            }
        }
    }
    eprintln!(" Loaded {} weight tensors", weights.len());
    Ok(weights)
}
/// Number of weight tensors expected for a `num_layers`-layer model:
/// two top-level tensors (embedding and final norm), nine per layer,
/// plus one more when the model has a separate LM head.
pub fn expected_weight_count(num_layers: usize, has_lm_head: bool) -> usize {
    // `usize::from(bool)` yields 1 for true, 0 for false.
    2 + num_layers * 9 + usize::from(has_lm_head)
}
/// Like `expected_weight_count`, but with twelve tensors per layer —
/// presumably the nine base tensors plus three bias tensors per layer;
/// confirm against the mapping module.
#[allow(dead_code)]
pub fn expected_weight_count_with_biases(num_layers: usize, has_lm_head: bool) -> usize {
    let without_head = 2 + num_layers * 12;
    match has_lm_head {
        true => without_head + 1,
        false => without_head,
    }
}
/// Checks that `weights` contains every tensor required for a
/// `num_layers`-layer model under the canonical naming scheme
/// (embedding, final norm, and nine tensors per layer). Extra tensors
/// (e.g. biases, a separate LM head) are permitted.
///
/// # Errors
/// Returns `Error::ConfigError` naming the first missing tensor.
#[allow(clippy::implicit_hasher)]
pub fn validate_weights(weights: &HashMap<String, Tensor>, num_layers: usize) -> Result<()> {
    // Top-level tensors; the format! below yields the same error strings
    // as the previous per-key checks.
    for key in ["model.embed_tokens.weight", "model.norm.weight"] {
        if !weights.contains_key(key) {
            return Err(Error::ConfigError(format!("Missing {key}")));
        }
    }
    for i in 0..num_layers {
        let layer_prefix = format!("model.layers.{i}");
        let required = [
            ".input_layernorm.weight",
            ".self_attn.q_proj.weight",
            ".self_attn.k_proj.weight",
            ".self_attn.v_proj.weight",
            ".self_attn.o_proj.weight",
            ".post_attention_layernorm.weight",
            ".mlp.gate_proj.weight",
            ".mlp.up_proj.weight",
            ".mlp.down_proj.weight",
        ];
        for suffix in required {
            let key = format!("{layer_prefix}{suffix}");
            if !weights.contains_key(&key) {
                return Err(Error::ConfigError(format!("Missing {key}")));
            }
        }
    }
    // No count check: every tensor counted by `expected_weight_count`
    // (2 top-level, 9 per layer, optional lm_head) has been individually
    // verified present above, and HashMap keys are distinct, so
    // `weights.len()` is always at least that count. The old
    // `actual < expected` warning here could never fire, and its message
    // ("may have extra bias tensors") contradicted its own condition.
    Ok(())
}