entrenar/hf_pipeline/loader/
safetensors.rs1use crate::hf_pipeline::error::{FetchError, Result};
4use ndarray::Array2;
5use std::path::Path;
6
7use super::{MemoryEstimate, TeacherModel};
8
9pub struct SafeTensorsTeacher {
11 weights: std::collections::HashMap<String, Array2<f32>>,
13 tensor_names: Vec<String>,
15 num_layers: usize,
17 hidden_size: usize,
19 param_count: u64,
21}
22
23impl SafeTensorsTeacher {
24 pub fn load(path: &Path) -> Result<Self> {
34 use safetensors::SafeTensors;
35
36 let model_path = path.join("model.safetensors");
37 if !model_path.exists() {
38 return Err(FetchError::FileNotFound {
39 repo: path.display().to_string(),
40 file: "model.safetensors".into(),
41 });
42 }
43
44 let data = std::fs::read(&model_path)?;
46
47 let tensors = SafeTensors::deserialize(&data)
49 .map_err(|e| FetchError::SafeTensorsParseError { message: e.to_string() })?;
50
51 let tensor_names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
53
54 let mut param_count: u64 = 0;
56 for name in &tensor_names {
57 if let Ok(info) = tensors.tensor(name) {
58 let numel: u64 = info.shape().iter().map(|&x| x as u64).product();
59 param_count += numel;
60 }
61 }
62
63 let num_layers = detect_layer_count(&tensor_names);
66
67 let hidden_size = detect_hidden_size(&tensors, &tensor_names);
69
70 Ok(Self {
71 weights: std::collections::HashMap::new(), tensor_names,
73 num_layers,
74 hidden_size,
75 param_count,
76 })
77 }
78
79 #[must_use]
81 pub fn tensor_names(&self) -> &[String] {
82 &self.tensor_names
83 }
84
85 #[must_use]
90 pub fn weights(&self) -> &std::collections::HashMap<String, Array2<f32>> {
91 &self.weights
92 }
93
94 #[cfg(test)]
96 pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
97 let param_count = (num_layers as u64) * (hidden_size as u64).pow(2) * 4;
98 Self {
99 weights: std::collections::HashMap::new(),
100 tensor_names: Vec::new(),
101 num_layers,
102 hidden_size,
103 param_count,
104 }
105 }
106}
107
/// Counts the transformer layers referenced by a checkpoint's tensor names.
///
/// Each name is scanned with [`extract_layer_index`], and the result is the
/// number of *distinct* indices found. When no name carries a recognisable
/// layer index, returns a fallback of 12.
///
/// NOTE(review): checkpoints that number two stacks independently (e.g.
/// separate encoder/decoder `layers.N`) collapse onto shared indices, so the
/// result is the larger stack's count rather than the sum.
fn detect_layer_count(names: &[String]) -> usize {
    use std::collections::HashSet;

    /// Returned when no layer index can be recovered from the tensor names.
    const DEFAULT_LAYER_COUNT: usize = 12;

    let layer_indices: HashSet<usize> = names
        .iter()
        .map(String::as_str)
        .filter_map(extract_layer_index)
        .collect();

    if layer_indices.is_empty() {
        DEFAULT_LAYER_COUNT
    } else {
        layer_indices.len()
    }
}

/// Parses the text before the first `.` in `s` (or all of `s` if it has no
/// `.`) as a `usize`.
///
/// For example `"3.attn"` gives `Some(3)`, while `"attn.3"` and `""` give
/// `None`.
fn parse_leading_index(s: &str) -> Option<usize> {
    let head = s.split_once('.').map_or(s, |(head, _)| head);
    head.parse::<usize>().ok()
}

/// Extracts the layer index from a tensor name, e.g.
/// `model.layers.7.self_attn.q_proj.weight` gives `Some(7)`.
///
/// The container segments `layer`, `layers` and `h` are tried in that order.
/// A segment counts when it either starts the name (`h.0.attn…`) or follows a
/// `.` (`transformer.h.0.attn…`). For the dotted form, only the first
/// occurrence of each segment is examined. The leading form covers checkpoints
/// stored without a model prefix, such as GPT-2's `h.N.*` tensors. Without it,
/// those checkpoints fell through to the default layer count.
fn extract_layer_index(name: &str) -> Option<usize> {
    const PATTERNS: &[&str] = &[".layer.", ".layers.", ".h."];

    PATTERNS.iter().find_map(|pattern| {
        let leading = pattern.trim_start_matches('.');
        name.strip_prefix(leading)
            .and_then(parse_leading_index)
            .or_else(|| {
                let pos = name.find(pattern)?;
                parse_leading_index(&name[pos + pattern.len()..])
            })
    })
}
149
150fn square_dim(
152 tensors: &safetensors::SafeTensors<'_>,
153 name: &str,
154 min_size: usize,
155) -> Option<usize> {
156 let shape = tensors.tensor(name).ok()?.shape().to_vec();
157 if shape.len() == 2 && shape[0] == shape[1] && shape[0] >= min_size {
158 Some(shape[0])
159 } else {
160 None
161 }
162}
163
164fn detect_from_query_weights(
169 tensors: &safetensors::SafeTensors<'_>,
170 names: &[String],
171) -> Option<usize> {
172 const QUERY_PATTERNS: &[&str] =
173 &[".query.weight", ".q_proj.weight", ".self_attn.q_proj.weight"];
174
175 names.iter().find_map(|name| {
176 let matches_pattern = QUERY_PATTERNS.iter().any(|p| name.ends_with(p));
177 matches_pattern.then(|| square_dim(tensors, name, 1)).flatten()
178 })
179}
180
181fn detect_from_square_weights(
183 tensors: &safetensors::SafeTensors<'_>,
184 names: &[String],
185) -> Option<usize> {
186 names
187 .iter()
188 .filter(|name| name.contains("weight"))
189 .find_map(|name| square_dim(tensors, name, 256))
190}
191
192fn detect_hidden_size(tensors: &safetensors::SafeTensors<'_>, names: &[String]) -> usize {
194 detect_from_query_weights(tensors, names)
195 .or_else(|| detect_from_square_weights(tensors, names))
196 .unwrap_or(0)
198}
199
200impl TeacherModel for SafeTensorsTeacher {
201 fn forward(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
202 Ok(input.clone())
204 }
205
206 fn hidden_states(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
207 Ok(vec![input.clone(); self.num_layers])
209 }
210
211 fn attention_weights(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
212 let (batch, _seq) = input.dim();
214 let attn = Array2::<f32>::ones((batch, batch));
215 Ok(vec![attn; self.num_layers])
216 }
217
218 fn estimate_memory(&self, batch_size: usize, seq_len: usize) -> MemoryEstimate {
219 MemoryEstimate::fp16(self.param_count, batch_size, seq_len, self.hidden_size)
220 }
221
222 fn param_count(&self) -> u64 {
223 self.param_count
224 }
225
226 fn num_layers(&self) -> usize {
227 self.num_layers
228 }
229
230 fn hidden_size(&self) -> usize {
231 self.hidden_size
232 }
233}