Skip to main content

entrenar/transformer/weights/
mod.rs

1//! Weight loading module for transformer models
2//!
3//! This module provides functions to load model weights from SafeTensors files
4//! and convert them to the format expected by `Transformer::from_params`.
5//!
6//! Supports:
7//! - Qwen2/Qwen2.5 architecture
8//! - LLaMA architecture
9//! - Mistral architecture
10//!
11//! Weight name mapping follows HuggingFace conventions.
12
13mod convert;
14mod detect;
15pub(crate) mod mapping;
16
17#[cfg(test)]
18mod tests;
19
20use crate::error::{Error, Result};
21use crate::Tensor;
22use std::collections::HashMap;
23use std::path::Path;
24
25pub(crate) use convert::tensor_to_f32_vec;
26pub(crate) use detect::{
27    detect_architecture, find_safetensors_files, parse_checkpoint_step_from_path,
28};
29pub(crate) use mapping::map_weight_name;
30
/// Selects which tensor-naming scheme is translated when loading weights.
///
/// `Auto` defers the choice to name-based detection when the first SafeTensors
/// file is parsed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Architecture {
    /// LLaMA / LLaMA-2 family
    Llama,
    /// Qwen2 / Qwen2.5 family, Qwen2.5-Coder included
    Qwen2,
    /// Mistral family
    Mistral,
    /// RoBERTa and CodeBERT encoder models (ENC-006)
    RoBERTa,
    /// GGUF-style tensor names such as `token_embd` or `blk.N.attn_q`
    Gguf,
    /// Infer the architecture from the tensor names themselves
    Auto,
}
47
48/// Load transformer weights from SafeTensors file(s)
49///
50/// # Arguments
51///
52/// * `model_path` - Path to model directory or single SafeTensors file
53/// * `arch` - Model architecture (use Auto to detect from weight names)
54///
55/// # Returns
56///
57/// HashMap of parameter names mapped to Tensor values.
58/// Names follow the HuggingFace LLaMA convention expected by `Transformer::from_params`.
59pub fn load_safetensors_weights(
60    model_path: &Path,
61    arch: Architecture,
62) -> Result<HashMap<String, Tensor>> {
63    use safetensors::SafeTensors;
64
65    // Find SafeTensors files
66    let st_files = find_safetensors_files(model_path)?;
67    if st_files.is_empty() {
68        return Err(Error::ConfigError(format!(
69            "No SafeTensors files found in {}",
70            model_path.display()
71        )));
72    }
73
74    let mut weights = HashMap::new();
75    let mut detected_arch = arch;
76
77    // Process each SafeTensors file
78    for st_path in &st_files {
79        let data = std::fs::read(st_path).map_err(|e| {
80            Error::ConfigError(format!("Failed to read {}: {}", st_path.display(), e))
81        })?;
82
83        let tensors = SafeTensors::deserialize(&data).map_err(|e| {
84            Error::ConfigError(format!("Failed to parse SafeTensors {}: {}", st_path.display(), e))
85        })?;
86
87        // Auto-detect architecture from first file
88        if detected_arch == Architecture::Auto {
89            detected_arch = detect_architecture(&tensors);
90            eprintln!("  Detected architecture: {detected_arch:?}");
91        }
92
93        // Load and map tensors
94        for name in tensors.names() {
95            if let Ok(tensor_view) = tensors.tensor(name) {
96                // Convert tensor to f32 values
97                if let Some(values) = tensor_to_f32_vec(&tensor_view) {
98                    // Map name to standard LLaMA convention
99                    let mapped_name = map_weight_name(name, detected_arch);
100                    let tensor = Tensor::from_vec(values, true);
101                    weights.insert(mapped_name, tensor);
102                }
103            }
104        }
105    }
106
107    eprintln!("  Loaded {} weight tensors", weights.len());
108    Ok(weights)
109}
110
111/// Get expected weight count for a transformer model (minimum — without attention biases)
112pub fn expected_weight_count(num_layers: usize, has_lm_head: bool) -> usize {
113    // Per layer (minimum):
114    //   input_layernorm.weight (1)
115    //   self_attn: q_proj, k_proj, v_proj, o_proj weights (4)
116    //   post_attention_layernorm.weight (1)
117    //   mlp: gate_proj, up_proj, down_proj (3)
118    // = 9 per layer (without biases)
119    //
120    // Models with attention biases (e.g. Qwen2) have 3 additional per layer:
121    //   self_attn: q_proj.bias, k_proj.bias, v_proj.bias
122    // = 12 per layer (with biases)
123    //
124    // Global:
125    //   embed_tokens.weight (1)
126    //   norm.weight (1)
127    //   lm_head.weight (optional, 1)
128    let base = 2 + (num_layers * 9);
129    if has_lm_head {
130        base + 1
131    } else {
132        base
133    }
134}
135
136/// Get expected weight count including attention biases
137#[allow(dead_code)]
138pub fn expected_weight_count_with_biases(num_layers: usize, has_lm_head: bool) -> usize {
139    let base = 2 + (num_layers * 12); // 9 weights + 3 biases per layer
140    if has_lm_head {
141        base + 1
142    } else {
143        base
144    }
145}
146
147/// Validate that loaded weights match expected architecture
148#[allow(clippy::implicit_hasher)]
149pub fn validate_weights(weights: &HashMap<String, Tensor>, num_layers: usize) -> Result<()> {
150    // Check embedding
151    if !weights.contains_key("model.embed_tokens.weight") {
152        return Err(Error::ConfigError("Missing model.embed_tokens.weight".into()));
153    }
154
155    // Check final norm
156    if !weights.contains_key("model.norm.weight") {
157        return Err(Error::ConfigError("Missing model.norm.weight".into()));
158    }
159
160    // Check each layer
161    for i in 0..num_layers {
162        let layer_prefix = format!("model.layers.{i}");
163
164        // Required layer weights
165        let required = [
166            ".input_layernorm.weight",
167            ".self_attn.q_proj.weight",
168            ".self_attn.k_proj.weight",
169            ".self_attn.v_proj.weight",
170            ".self_attn.o_proj.weight",
171            ".post_attention_layernorm.weight",
172            ".mlp.gate_proj.weight",
173            ".mlp.up_proj.weight",
174            ".mlp.down_proj.weight",
175        ];
176
177        for suffix in required {
178            let key = format!("{layer_prefix}{suffix}");
179            if !weights.contains_key(&key) {
180                return Err(Error::ConfigError(format!("Missing {key}")));
181            }
182        }
183    }
184
185    // Check weight count for informational purposes
186    let has_lm_head = weights.contains_key("lm_head.weight");
187    let expected = expected_weight_count(num_layers, has_lm_head);
188    let actual = weights.len();
189    if actual < expected {
190        // This is a warning, not an error - some models may have extra or fewer weights
191        eprintln!(
192            "Warning: Expected at least {expected} weights, found {actual} (may have extra bias tensors)"
193        );
194    }
195
196    Ok(())
197}