entrenar/transformer/weights/
mod.rs1mod convert;
14mod detect;
15pub(crate) mod mapping;
16
17#[cfg(test)]
18mod tests;
19
20use crate::error::{Error, Result};
21use crate::Tensor;
22use std::collections::HashMap;
23use std::path::Path;
24
25pub(crate) use convert::tensor_to_f32_vec;
26pub(crate) use detect::{
27 detect_architecture, find_safetensors_files, parse_checkpoint_step_from_path,
28};
29pub(crate) use mapping::map_weight_name;
30
/// Model family selector used when loading weights.
///
/// The chosen variant is forwarded to `map_weight_name` so tensor names can be
/// translated per family; the exact translation rules live in the `mapping`
/// submodule and are not visible here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Architecture {
    /// Llama-family checkpoint (presumably HuggingFace-style tensor names — confirm in `mapping`).
    Llama,
    /// Qwen2-family checkpoint.
    Qwen2,
    /// Mistral-family checkpoint.
    Mistral,
    /// RoBERTa-family checkpoint.
    RoBERTa,
    /// GGUF-sourced weights. NOTE(review): the SafeTensors loader in this file
    /// does not treat this variant specially — confirm where it is consumed.
    Gguf,
    /// Placeholder asking the loader to pick a concrete variant via
    /// `detect_architecture` once the first SafeTensors file has been parsed.
    Auto,
}
47
48pub fn load_safetensors_weights(
60 model_path: &Path,
61 arch: Architecture,
62) -> Result<HashMap<String, Tensor>> {
63 use safetensors::SafeTensors;
64
65 let st_files = find_safetensors_files(model_path)?;
67 if st_files.is_empty() {
68 return Err(Error::ConfigError(format!(
69 "No SafeTensors files found in {}",
70 model_path.display()
71 )));
72 }
73
74 let mut weights = HashMap::new();
75 let mut detected_arch = arch;
76
77 for st_path in &st_files {
79 let data = std::fs::read(st_path).map_err(|e| {
80 Error::ConfigError(format!("Failed to read {}: {}", st_path.display(), e))
81 })?;
82
83 let tensors = SafeTensors::deserialize(&data).map_err(|e| {
84 Error::ConfigError(format!("Failed to parse SafeTensors {}: {}", st_path.display(), e))
85 })?;
86
87 if detected_arch == Architecture::Auto {
89 detected_arch = detect_architecture(&tensors);
90 eprintln!(" Detected architecture: {detected_arch:?}");
91 }
92
93 for name in tensors.names() {
95 if let Ok(tensor_view) = tensors.tensor(name) {
96 if let Some(values) = tensor_to_f32_vec(&tensor_view) {
98 let mapped_name = map_weight_name(name, detected_arch);
100 let tensor = Tensor::from_vec(values, true);
101 weights.insert(mapped_name, tensor);
102 }
103 }
104 }
105 }
106
107 eprintln!(" Loaded {} weight tensors", weights.len());
108 Ok(weights)
109}
110
/// Returns the minimum number of weight tensors for a model with `num_layers` layers.
///
/// The total is two model-level tensors plus nine tensors per layer — the same
/// keys `validate_weights` requires (`model.embed_tokens.weight`,
/// `model.norm.weight`, and nine `model.layers.{i}.*` weights) — plus one more
/// when `has_lm_head` is `true`.
pub fn expected_weight_count(num_layers: usize, has_lm_head: bool) -> usize {
    const MODEL_LEVEL_TENSORS: usize = 2;
    const TENSORS_PER_LAYER: usize = 9;

    let lm_head_tensors = usize::from(has_lm_head);
    MODEL_LEVEL_TENSORS + (num_layers * TENSORS_PER_LAYER) + lm_head_tensors
}
135
/// Returns the expected tensor count for a model whose layers also carry bias tensors.
///
/// The total is two model-level tensors plus twelve tensors per layer, plus one
/// more when `has_lm_head` is `true`. Twelve is presumably the nine per-layer
/// weights plus three bias tensors — TODO confirm which biases are meant.
// NOTE(review): `dead_code` is silenced although the function is `pub`;
// presumably it has no caller outside tests yet — confirm before removing.
#[allow(dead_code)]
pub fn expected_weight_count_with_biases(num_layers: usize, has_lm_head: bool) -> usize {
    const MODEL_LEVEL_TENSORS: usize = 2;
    const TENSORS_PER_LAYER: usize = 12;

    let lm_head_tensors = usize::from(has_lm_head);
    MODEL_LEVEL_TENSORS + (num_layers * TENSORS_PER_LAYER) + lm_head_tensors
}
146
147#[allow(clippy::implicit_hasher)]
149pub fn validate_weights(weights: &HashMap<String, Tensor>, num_layers: usize) -> Result<()> {
150 if !weights.contains_key("model.embed_tokens.weight") {
152 return Err(Error::ConfigError("Missing model.embed_tokens.weight".into()));
153 }
154
155 if !weights.contains_key("model.norm.weight") {
157 return Err(Error::ConfigError("Missing model.norm.weight".into()));
158 }
159
160 for i in 0..num_layers {
162 let layer_prefix = format!("model.layers.{i}");
163
164 let required = [
166 ".input_layernorm.weight",
167 ".self_attn.q_proj.weight",
168 ".self_attn.k_proj.weight",
169 ".self_attn.v_proj.weight",
170 ".self_attn.o_proj.weight",
171 ".post_attention_layernorm.weight",
172 ".mlp.gate_proj.weight",
173 ".mlp.up_proj.weight",
174 ".mlp.down_proj.weight",
175 ];
176
177 for suffix in required {
178 let key = format!("{layer_prefix}{suffix}");
179 if !weights.contains_key(&key) {
180 return Err(Error::ConfigError(format!("Missing {key}")));
181 }
182 }
183 }
184
185 let has_lm_head = weights.contains_key("lm_head.weight");
187 let expected = expected_weight_count(num_layers, has_lm_head);
188 let actual = weights.len();
189 if actual < expected {
190 eprintln!(
192 "Warning: Expected at least {expected} weights, found {actual} (may have extra bias tensors)"
193 );
194 }
195
196 Ok(())
197}