// entrenar/hf_pipeline/loader/safetensors.rs

1//! SafeTensors-based teacher model implementation
2
3use crate::hf_pipeline::error::{FetchError, Result};
4use ndarray::Array2;
5use std::path::Path;
6
7use super::{MemoryEstimate, TeacherModel};
8
9/// SafeTensors-based teacher model
10pub struct SafeTensorsTeacher {
11    /// Model weights by tensor name
12    weights: std::collections::HashMap<String, Array2<f32>>,
13    /// Tensor names (in order)
14    tensor_names: Vec<String>,
15    /// Number of layers
16    num_layers: usize,
17    /// Hidden dimension
18    hidden_size: usize,
19    /// Total parameter count
20    param_count: u64,
21}
22
23impl SafeTensorsTeacher {
24    /// Load model from SafeTensors file
25    ///
26    /// # Arguments
27    ///
28    /// * `path` - Path to model directory containing model.safetensors
29    ///
30    /// # Errors
31    ///
32    /// Returns error if file not found or parsing fails.
33    pub fn load(path: &Path) -> Result<Self> {
34        use safetensors::SafeTensors;
35
36        let model_path = path.join("model.safetensors");
37        if !model_path.exists() {
38            return Err(FetchError::FileNotFound {
39                repo: path.display().to_string(),
40                file: "model.safetensors".into(),
41            });
42        }
43
44        // Read the file into memory (safe approach for models up to ~10GB)
45        let data = std::fs::read(&model_path)?;
46
47        // Parse SafeTensors
48        let tensors = SafeTensors::deserialize(&data)
49            .map_err(|e| FetchError::SafeTensorsParseError { message: e.to_string() })?;
50
51        // Extract tensor names and compute statistics
52        let tensor_names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
53
54        // Calculate total parameter count
55        let mut param_count: u64 = 0;
56        for name in &tensor_names {
57            if let Ok(info) = tensors.tensor(name) {
58                let numel: u64 = info.shape().iter().map(|&x| x as u64).product();
59                param_count += numel;
60            }
61        }
62
63        // Detect number of layers from tensor naming convention
64        // Common patterns: "encoder.layer.N.", "layers.N.", "h.N."
65        let num_layers = detect_layer_count(&tensor_names);
66
67        // Detect hidden size from weight tensor shapes
68        let hidden_size = detect_hidden_size(&tensors, &tensor_names);
69
70        Ok(Self {
71            weights: std::collections::HashMap::new(), // Lazy load on demand
72            tensor_names,
73            num_layers,
74            hidden_size,
75            param_count,
76        })
77    }
78
79    /// Get list of tensor names in the model
80    #[must_use]
81    pub fn tensor_names(&self) -> &[String] {
82        &self.tensor_names
83    }
84
85    /// Get model weights by tensor name
86    ///
87    /// Note: Currently returns an empty map as weights are loaded on-demand
88    /// for memory efficiency. Future versions will support lazy loading.
89    #[must_use]
90    pub fn weights(&self) -> &std::collections::HashMap<String, Array2<f32>> {
91        &self.weights
92    }
93
94    /// Create mock teacher for testing
95    #[cfg(test)]
96    pub fn mock(num_layers: usize, hidden_size: usize) -> Self {
97        let param_count = (num_layers as u64) * (hidden_size as u64).pow(2) * 4;
98        Self {
99            weights: std::collections::HashMap::new(),
100            tensor_names: Vec::new(),
101            num_layers,
102            hidden_size,
103            param_count,
104        }
105    }
106}
107
108/// Detect number of layers from tensor naming patterns
109fn detect_layer_count(names: &[String]) -> usize {
110    use std::collections::HashSet;
111
112    let mut layer_indices: HashSet<usize> = HashSet::new();
113
114    for name in names {
115        // Match patterns like "encoder.layer.0.", "layers.0.", "h.0."
116        if let Some(idx) = extract_layer_index(name) {
117            layer_indices.insert(idx);
118        }
119    }
120
121    if layer_indices.is_empty() {
122        // Default to 12 if can't detect (BERT-base assumption)
123        12
124    } else {
125        layer_indices.len()
126    }
127}
128
/// Parse a leading integer from `s`, stopping at the first `.` or end of string.
///
/// Returns `None` when that leading segment is not a valid `usize`.
fn parse_leading_index(s: &str) -> Option<usize> {
    let numeric_part = match s.find('.') {
        Some(end) => &s[..end],
        None => s,
    };
    numeric_part.parse::<usize>().ok()
}

/// Extract the layer index from a tensor name.
///
/// Recognises a `layer`, `layers` or `h` path segment followed by an integer
/// segment. Examples:
/// * `bert.encoder.layer.3.output.dense.weight` -> 3
/// * `model.layers.7.mlp.up_proj.weight` -> 7
/// * `h.11.attn.c_attn.weight` -> 11
///
/// The marker may be the first segment of the name or follow any `.`. Every
/// occurrence is tried, not only the first. Markers are tried in the order
/// `layer`, `layers`, `h`.
fn extract_layer_index(name: &str) -> Option<usize> {
    // Marker segments including their trailing dot, in priority order.
    const PATTERNS: &[&str] = &["layer.", "layers.", "h."];

    // Byte offsets where a dot-separated segment begins: the start of the
    // name plus the position just after every `.`. Requiring a segment
    // boundary keeps e.g. "lm_head." from matching "h.".
    let segment_starts =
        || std::iter::once(0).chain(name.match_indices('.').map(|(i, _)| i + 1));

    PATTERNS.iter().find_map(|pattern| {
        segment_starts().find_map(|start| {
            name[start..]
                .strip_prefix(*pattern)
                .and_then(parse_leading_index)
        })
    })
}
149
150/// Extract the dimension of a square 2D tensor, optionally requiring a minimum size.
151fn square_dim(
152    tensors: &safetensors::SafeTensors<'_>,
153    name: &str,
154    min_size: usize,
155) -> Option<usize> {
156    let shape = tensors.tensor(name).ok()?.shape().to_vec();
157    if shape.len() == 2 && shape[0] == shape[1] && shape[0] >= min_size {
158        Some(shape[0])
159    } else {
160        None
161    }
162}
163
164/// Detect hidden size from attention query weight tensors.
165///
166/// Looks for tensors matching known query-projection naming patterns
167/// (e.g. `.query.weight`, `.q_proj.weight`) that are square matrices.
168fn detect_from_query_weights(
169    tensors: &safetensors::SafeTensors<'_>,
170    names: &[String],
171) -> Option<usize> {
172    const QUERY_PATTERNS: &[&str] =
173        &[".query.weight", ".q_proj.weight", ".self_attn.q_proj.weight"];
174
175    names.iter().find_map(|name| {
176        let matches_pattern = QUERY_PATTERNS.iter().any(|p| name.ends_with(p));
177        matches_pattern.then(|| square_dim(tensors, name, 1)).flatten()
178    })
179}
180
181/// Detect hidden size from any large square weight matrix (fallback heuristic).
182fn detect_from_square_weights(
183    tensors: &safetensors::SafeTensors<'_>,
184    names: &[String],
185) -> Option<usize> {
186    names
187        .iter()
188        .filter(|name| name.contains("weight"))
189        .find_map(|name| square_dim(tensors, name, 256))
190}
191
192/// Detect hidden size from tensor shapes
193fn detect_hidden_size(tensors: &safetensors::SafeTensors<'_>, names: &[String]) -> usize {
194    detect_from_query_weights(tensors, names)
195        .or_else(|| detect_from_square_weights(tensors, names))
196        // C-15 (Meyer DbC): 0 = unknown, no architecture-specific magic number.
197        .unwrap_or(0)
198}
199
200impl TeacherModel for SafeTensorsTeacher {
201    fn forward(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
202        // Mock implementation - just pass through
203        Ok(input.clone())
204    }
205
206    fn hidden_states(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
207        // Return one hidden state per layer
208        Ok(vec![input.clone(); self.num_layers])
209    }
210
211    fn attention_weights(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>> {
212        // Return attention weights per layer
213        let (batch, _seq) = input.dim();
214        let attn = Array2::<f32>::ones((batch, batch));
215        Ok(vec![attn; self.num_layers])
216    }
217
218    fn estimate_memory(&self, batch_size: usize, seq_len: usize) -> MemoryEstimate {
219        MemoryEstimate::fp16(self.param_count, batch_size, seq_len, self.hidden_size)
220    }
221
222    fn param_count(&self) -> u64 {
223        self.param_count
224    }
225
226    fn num_layers(&self) -> usize {
227        self.num_layers
228    }
229
230    fn hidden_size(&self) -> usize {
231        self.hidden_size
232    }
233}