Skip to main content

oxicuda_lm/
weights.rs

//! Model weight storage.
//!
//! `WeightTensor` is a flat `f32` buffer with shape metadata.
//! `ModelWeights` is a dictionary mapping string names to weight tensors,
//! mirroring the PyTorch / HuggingFace naming convention so that callers
//! can address weights with familiar names like `"model.layers.0.self_attn.q_proj.weight"`.
7
8use std::collections::HashMap;
9
10use crate::error::{LmError, LmResult};
11
12// ─── WeightTensor ────────────────────────────────────────────────────────────
13
14/// A named flat `f32` weight tensor.
15///
16/// Weights are stored in **row-major** order.  For a 2-D matrix `W[m, n]`,
17/// the element at row `i` column `j` is at `data[i * n + j]`.
18#[derive(Debug, Clone)]
19pub struct WeightTensor {
20    /// Raw data in row-major order.
21    pub data: Vec<f32>,
22    /// Shape of the tensor (product equals `data.len()`).
23    pub shape: Vec<usize>,
24}
25
26impl WeightTensor {
27    // ── Constructors ─────────────────────────────────────────────────────
28
29    /// Construct from existing data, validating that `data.len()` equals the
30    /// product of `shape`.
31    pub fn from_data(data: Vec<f32>, shape: Vec<usize>) -> LmResult<Self> {
32        let expected: usize = shape.iter().product();
33        if data.len() != expected {
34            return Err(LmError::WeightDataLengthMismatch {
35                data_len: data.len(),
36                shape: shape.clone(),
37                expected,
38            });
39        }
40        Ok(Self { data, shape })
41    }
42
43    /// Tensor filled with zeros.
44    pub fn zeros(shape: &[usize]) -> Self {
45        let n: usize = shape.iter().product();
46        Self {
47            data: vec![0.0_f32; n],
48            shape: shape.to_vec(),
49        }
50    }
51
52    /// Tensor filled with ones.
53    pub fn ones(shape: &[usize]) -> Self {
54        let n: usize = shape.iter().product();
55        Self {
56            data: vec![1.0_f32; n],
57            shape: shape.to_vec(),
58        }
59    }
60
61    /// Identity-like weight: ones on the "diagonal" of a 2-D matrix.
62    ///
63    /// For a non-square matrix the identity is placed in the top-left corner.
64    pub fn eye(rows: usize, cols: usize) -> Self {
65        let mut data = vec![0.0_f32; rows * cols];
66        for i in 0..rows.min(cols) {
67            data[i * cols + i] = 1.0;
68        }
69        Self {
70            data,
71            shape: vec![rows, cols],
72        }
73    }
74
75    // ── Accessors ────────────────────────────────────────────────────────
76
77    /// Total number of elements.
78    pub fn n_elements(&self) -> usize {
79        self.data.len()
80    }
81
82    /// Number of tensor dimensions.
83    pub fn ndim(&self) -> usize {
84        self.shape.len()
85    }
86
87    /// Borrow the underlying data.
88    pub fn as_slice(&self) -> &[f32] {
89        &self.data
90    }
91
92    /// Mutable borrow of the underlying data.
93    pub fn as_mut_slice(&mut self) -> &mut [f32] {
94        &mut self.data
95    }
96
97    /// For a 2-D weight matrix `[rows × cols]`, return the `row`-th row slice.
98    pub fn row_slice(&self, row: usize) -> LmResult<&[f32]> {
99        if self.shape.len() != 2 {
100            return Err(LmError::DimensionMismatch {
101                expected: 2,
102                got: self.shape.len(),
103            });
104        }
105        let cols = self.shape[1];
106        let start = row * cols;
107        if start + cols > self.data.len() {
108            return Err(LmError::DimensionMismatch {
109                expected: row,
110                got: self.shape[0],
111            });
112        }
113        Ok(&self.data[start..start + cols])
114    }
115
116    /// Validate that this tensor has the given shape.
117    pub fn validate_shape(&self, expected: &[usize]) -> LmResult<()> {
118        if self.shape != expected {
119            return Err(LmError::WeightShapeMismatch {
120                name: String::new(),
121                expected: expected.to_vec(),
122                got: self.shape.clone(),
123            });
124        }
125        Ok(())
126    }
127}
128
129// ─── ModelWeights ────────────────────────────────────────────────────────────
130
131/// A dictionary of named weight tensors.
132///
133/// Weight names follow the HuggingFace naming convention, e.g.:
134/// - `"model.embed_tokens.weight"`
135/// - `"model.layers.0.self_attn.q_proj.weight"`
136/// - `"model.layers.0.mlp.gate_proj.weight"`
137/// - `"lm_head.weight"`
138#[derive(Debug, Clone, Default)]
139pub struct ModelWeights {
140    weights: HashMap<String, WeightTensor>,
141}
142
143impl ModelWeights {
144    /// Create an empty weight store.
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// Insert (or overwrite) a tensor under `name`.
150    pub fn insert(&mut self, name: impl Into<String>, tensor: WeightTensor) {
151        self.weights.insert(name.into(), tensor);
152    }
153
154    /// Retrieve a tensor by name.
155    pub fn get(&self, name: &str) -> LmResult<&WeightTensor> {
156        self.weights
157            .get(name)
158            .ok_or_else(|| LmError::WeightNotFound { name: name.into() })
159    }
160
161    /// Retrieve a tensor by name and validate its shape.
162    pub fn get_checked(&self, name: &str, expected_shape: &[usize]) -> LmResult<&WeightTensor> {
163        let t = self.get(name)?;
164        if t.shape != expected_shape {
165            return Err(LmError::WeightShapeMismatch {
166                name: name.into(),
167                expected: expected_shape.to_vec(),
168                got: t.shape.clone(),
169            });
170        }
171        Ok(t)
172    }
173
174    /// Whether `name` is present in the store.
175    pub fn contains(&self, name: &str) -> bool {
176        self.weights.contains_key(name)
177    }
178
179    /// Iterator over all `(name, tensor)` pairs.
180    pub fn iter(&self) -> impl Iterator<Item = (&str, &WeightTensor)> {
181        self.weights.iter().map(|(k, v)| (k.as_str(), v))
182    }
183
184    /// Total number of scalar parameters across all tensors.
185    pub fn n_params(&self) -> usize {
186        self.weights.values().map(|t| t.n_elements()).sum()
187    }
188
189    /// Number of named weight entries.
190    pub fn len(&self) -> usize {
191        self.weights.len()
192    }
193
194    /// Whether the store is empty.
195    pub fn is_empty(&self) -> bool {
196        self.weights.is_empty()
197    }
198}
199
200// ─── Tests ───────────────────────────────────────────────────────────────────
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn weight_tensor_zeros() {
        let w = WeightTensor::zeros(&[4, 8]);
        assert_eq!(w.n_elements(), 32);
        assert!(w.data.iter().all(|&x| x == 0.0));
        assert_eq!(w.shape, vec![4, 8]);
    }

    #[test]
    fn weight_tensor_ones() {
        let w = WeightTensor::ones(&[3, 3]);
        assert_eq!(w.n_elements(), 9);
        assert!(w.data.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn weight_tensor_eye() {
        let w = WeightTensor::eye(3, 3);
        assert_eq!(w.data[0], 1.0);
        assert_eq!(w.data[1], 0.0);
        assert_eq!(w.data[4], 1.0); // [1,1]
        assert_eq!(w.data[8], 1.0); // [2,2]
    }

    #[test]
    fn weight_tensor_eye_non_square() {
        // Identity occupies the top-left corner of a non-square matrix.
        let w = WeightTensor::eye(2, 3);
        assert_eq!(w.shape, vec![2, 3]);
        assert_eq!(w.data, vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn weight_tensor_from_data_ok() {
        let d = vec![1.0_f32, 2.0, 3.0, 4.0];
        let w = WeightTensor::from_data(d.clone(), vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert_eq!(w.data, d);
    }

    #[test]
    fn weight_tensor_from_data_shape_mismatch() {
        let d = vec![1.0_f32; 5];
        let err = WeightTensor::from_data(d, vec![2, 2]).unwrap_err();
        assert!(matches!(err, LmError::WeightDataLengthMismatch { .. }));
    }

    #[test]
    fn weight_tensor_from_data_mismatch_reports_lengths() {
        let err = WeightTensor::from_data(vec![1.0_f32; 5], vec![2, 2]).unwrap_err();
        assert!(matches!(
            err,
            LmError::WeightDataLengthMismatch {
                data_len: 5,
                expected: 4,
                ..
            }
        ));
    }

    #[test]
    fn weight_tensor_accessors() {
        let mut w = WeightTensor::zeros(&[2, 3, 4]);
        assert_eq!(w.ndim(), 3);
        assert_eq!(w.n_elements(), 24);
        w.as_mut_slice()[5] = 7.0;
        assert_eq!(w.as_slice()[5], 7.0);
    }

    #[test]
    fn weight_tensor_row_slice() {
        let w = WeightTensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert_eq!(
            w.row_slice(0).expect("row 0 of 2x2 tensor should exist"),
            &[1.0_f32, 2.0]
        );
        assert_eq!(
            w.row_slice(1).expect("row 1 of 2x2 tensor should exist"),
            &[3.0_f32, 4.0]
        );
    }

    #[test]
    fn weight_tensor_row_slice_out_of_range_errors() {
        let w = WeightTensor::zeros(&[2, 3]);
        assert!(w.row_slice(2).is_err());
    }

    #[test]
    fn weight_tensor_row_slice_non_2d_errors() {
        let w = WeightTensor::zeros(&[8]);
        assert!(w.row_slice(0).is_err());
    }

    #[test]
    fn weight_tensor_validate_shape_ok() {
        let w = WeightTensor::zeros(&[4, 8]);
        w.validate_shape(&[4, 8])
            .expect("validate_shape should succeed when shape matches");
    }

    #[test]
    fn weight_tensor_validate_shape_fail() {
        let w = WeightTensor::zeros(&[4, 8]);
        assert!(w.validate_shape(&[8, 4]).is_err());
    }

    #[test]
    fn model_weights_insert_and_get() {
        let mut mw = ModelWeights::new();
        mw.insert("embed", WeightTensor::zeros(&[10, 4]));
        let t = mw
            .get("embed")
            .expect("'embed' key should exist after insertion");
        assert_eq!(t.shape, vec![10, 4]);
    }

    #[test]
    fn model_weights_insert_overwrites() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[2]));
        mw.insert("w", WeightTensor::ones(&[3]));
        assert_eq!(mw.len(), 1);
        assert_eq!(mw.get("w").expect("'w' should exist").shape, vec![3]);
    }

    #[test]
    fn model_weights_get_missing_errors() {
        let mw = ModelWeights::new();
        assert!(matches!(
            mw.get("missing"),
            Err(LmError::WeightNotFound { .. })
        ));
    }

    #[test]
    fn model_weights_get_checked_ok() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[4, 8]));
        let t = mw
            .get_checked("w", &[4, 8])
            .expect("shape [4, 8] should match");
        assert_eq!(t.shape, vec![4, 8]);
    }

    #[test]
    fn model_weights_get_checked_shape_error() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[4, 8]));
        assert!(matches!(
            mw.get_checked("w", &[8, 4]),
            Err(LmError::WeightShapeMismatch { .. })
        ));
    }

    #[test]
    fn model_weights_contains_and_iter() {
        let mut mw = ModelWeights::new();
        mw.insert("a", WeightTensor::zeros(&[1]));
        mw.insert("b", WeightTensor::zeros(&[2]));
        assert!(mw.contains("a"));
        assert!(!mw.contains("c"));
        let mut names: Vec<&str> = mw.iter().map(|(name, _)| name).collect();
        // HashMap iteration order is unspecified, so sort before comparing.
        names.sort_unstable();
        assert_eq!(names, ["a", "b"]);
    }

    #[test]
    fn model_weights_n_params() {
        let mut mw = ModelWeights::new();
        mw.insert("a", WeightTensor::zeros(&[4, 4])); // 16
        mw.insert("b", WeightTensor::zeros(&[3, 3])); // 9
        assert_eq!(mw.n_params(), 25);
    }

    #[test]
    fn model_weights_len_and_empty() {
        let mut mw = ModelWeights::new();
        assert!(mw.is_empty());
        mw.insert("x", WeightTensor::zeros(&[1]));
        assert_eq!(mw.len(), 1);
        assert!(!mw.is_empty());
    }
}