1use std::collections::HashMap;
9
10use crate::error::{LmError, LmResult};
11
12#[derive(Debug, Clone)]
19pub struct WeightTensor {
20 pub data: Vec<f32>,
22 pub shape: Vec<usize>,
24}
25
26impl WeightTensor {
27 pub fn from_data(data: Vec<f32>, shape: Vec<usize>) -> LmResult<Self> {
32 let expected: usize = shape.iter().product();
33 if data.len() != expected {
34 return Err(LmError::WeightDataLengthMismatch {
35 data_len: data.len(),
36 shape: shape.clone(),
37 expected,
38 });
39 }
40 Ok(Self { data, shape })
41 }
42
43 pub fn zeros(shape: &[usize]) -> Self {
45 let n: usize = shape.iter().product();
46 Self {
47 data: vec![0.0_f32; n],
48 shape: shape.to_vec(),
49 }
50 }
51
52 pub fn ones(shape: &[usize]) -> Self {
54 let n: usize = shape.iter().product();
55 Self {
56 data: vec![1.0_f32; n],
57 shape: shape.to_vec(),
58 }
59 }
60
61 pub fn eye(rows: usize, cols: usize) -> Self {
65 let mut data = vec![0.0_f32; rows * cols];
66 for i in 0..rows.min(cols) {
67 data[i * cols + i] = 1.0;
68 }
69 Self {
70 data,
71 shape: vec![rows, cols],
72 }
73 }
74
75 pub fn n_elements(&self) -> usize {
79 self.data.len()
80 }
81
82 pub fn ndim(&self) -> usize {
84 self.shape.len()
85 }
86
87 pub fn as_slice(&self) -> &[f32] {
89 &self.data
90 }
91
92 pub fn as_mut_slice(&mut self) -> &mut [f32] {
94 &mut self.data
95 }
96
97 pub fn row_slice(&self, row: usize) -> LmResult<&[f32]> {
99 if self.shape.len() != 2 {
100 return Err(LmError::DimensionMismatch {
101 expected: 2,
102 got: self.shape.len(),
103 });
104 }
105 let cols = self.shape[1];
106 let start = row * cols;
107 if start + cols > self.data.len() {
108 return Err(LmError::DimensionMismatch {
109 expected: row,
110 got: self.shape[0],
111 });
112 }
113 Ok(&self.data[start..start + cols])
114 }
115
116 pub fn validate_shape(&self, expected: &[usize]) -> LmResult<()> {
118 if self.shape != expected {
119 return Err(LmError::WeightShapeMismatch {
120 name: String::new(),
121 expected: expected.to_vec(),
122 got: self.shape.clone(),
123 });
124 }
125 Ok(())
126 }
127}
128
129#[derive(Debug, Clone, Default)]
139pub struct ModelWeights {
140 weights: HashMap<String, WeightTensor>,
141}
142
143impl ModelWeights {
144 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn insert(&mut self, name: impl Into<String>, tensor: WeightTensor) {
151 self.weights.insert(name.into(), tensor);
152 }
153
154 pub fn get(&self, name: &str) -> LmResult<&WeightTensor> {
156 self.weights
157 .get(name)
158 .ok_or_else(|| LmError::WeightNotFound { name: name.into() })
159 }
160
161 pub fn get_checked(&self, name: &str, expected_shape: &[usize]) -> LmResult<&WeightTensor> {
163 let t = self.get(name)?;
164 if t.shape != expected_shape {
165 return Err(LmError::WeightShapeMismatch {
166 name: name.into(),
167 expected: expected_shape.to_vec(),
168 got: t.shape.clone(),
169 });
170 }
171 Ok(t)
172 }
173
174 pub fn contains(&self, name: &str) -> bool {
176 self.weights.contains_key(name)
177 }
178
179 pub fn iter(&self) -> impl Iterator<Item = (&str, &WeightTensor)> {
181 self.weights.iter().map(|(k, v)| (k.as_str(), v))
182 }
183
184 pub fn n_params(&self) -> usize {
186 self.weights.values().map(|t| t.n_elements()).sum()
187 }
188
189 pub fn len(&self) -> usize {
191 self.weights.len()
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.weights.is_empty()
197 }
198}
199
#[cfg(test)]
mod tests {
    //! Unit tests for `WeightTensor` constructors/accessors and the
    //! `ModelWeights` registry.

    use super::*;

    #[test]
    fn weight_tensor_zeros() {
        let w = WeightTensor::zeros(&[4, 8]);
        assert_eq!(w.n_elements(), 32);
        assert!(w.data.iter().all(|&x| x == 0.0));
        assert_eq!(w.shape, vec![4, 8]);
    }

    #[test]
    fn weight_tensor_ones() {
        let w = WeightTensor::ones(&[3, 3]);
        assert_eq!(w.n_elements(), 9);
        assert!(w.data.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn weight_tensor_eye() {
        let w = WeightTensor::eye(3, 3);
        assert_eq!(w.data[0], 1.0);
        assert_eq!(w.data[1], 0.0);
        // Diagonal entries of a 3x3 matrix sit at flat indices 0, 4 and 8.
        assert_eq!(w.data[4], 1.0);
        assert_eq!(w.data[8], 1.0);
    }

    #[test]
    fn weight_tensor_eye_rectangular() {
        let w = WeightTensor::eye(2, 3);
        assert_eq!(w.shape, vec![2, 3]);
        assert_eq!(w.data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn weight_tensor_from_data_ok() {
        let d = vec![1.0_f32, 2.0, 3.0, 4.0];
        let w = WeightTensor::from_data(d.clone(), vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert_eq!(w.data, d);
    }

    #[test]
    fn weight_tensor_from_data_shape_mismatch() {
        let d = vec![1.0_f32; 5];
        let err = WeightTensor::from_data(d, vec![2, 2]).unwrap_err();
        assert!(matches!(err, LmError::WeightDataLengthMismatch { .. }));
    }

    #[test]
    fn weight_tensor_row_slice() {
        let w = WeightTensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert_eq!(
            w.row_slice(0).expect("row 0 of 2x2 tensor should exist"),
            &[1.0_f32, 2.0]
        );
        assert_eq!(
            w.row_slice(1).expect("row 1 of 2x2 tensor should exist"),
            &[3.0_f32, 4.0]
        );
    }

    #[test]
    fn weight_tensor_row_slice_out_of_range_errors() {
        let w = WeightTensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert!(matches!(
            w.row_slice(2),
            Err(LmError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn weight_tensor_row_slice_non_2d_errors() {
        let w = WeightTensor::zeros(&[8]);
        assert!(w.row_slice(0).is_err());
    }

    #[test]
    fn weight_tensor_validate_shape_ok() {
        let w = WeightTensor::zeros(&[4, 8]);
        w.validate_shape(&[4, 8])
            .expect("validate_shape should succeed when shape matches");
    }

    #[test]
    fn weight_tensor_validate_shape_fail() {
        let w = WeightTensor::zeros(&[4, 8]);
        assert!(w.validate_shape(&[8, 4]).is_err());
    }

    #[test]
    fn model_weights_insert_and_get() {
        let mut mw = ModelWeights::new();
        mw.insert("embed", WeightTensor::zeros(&[10, 4]));
        let t = mw
            .get("embed")
            .expect("'embed' key should exist after insertion");
        assert_eq!(t.shape, vec![10, 4]);
    }

    #[test]
    fn model_weights_insert_replaces_existing() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[2, 2]));
        mw.insert("w", WeightTensor::zeros(&[3]));
        assert_eq!(mw.len(), 1);
        let t = mw.get("w").expect("'w' key should exist after insertion");
        assert_eq!(t.shape, vec![3]);
    }

    #[test]
    fn model_weights_get_missing_errors() {
        let mw = ModelWeights::new();
        assert!(matches!(
            mw.get("missing"),
            Err(LmError::WeightNotFound { .. })
        ));
    }

    #[test]
    fn model_weights_get_checked_ok() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[4, 8]));
        let t = mw
            .get_checked("w", &[4, 8])
            .expect("get_checked should succeed when shape matches");
        assert_eq!(t.n_elements(), 32);
    }

    #[test]
    fn model_weights_get_checked_shape_error() {
        let mut mw = ModelWeights::new();
        mw.insert("w", WeightTensor::zeros(&[4, 8]));
        assert!(matches!(
            mw.get_checked("w", &[8, 4]),
            Err(LmError::WeightShapeMismatch { .. })
        ));
    }

    #[test]
    fn model_weights_contains_and_iter() {
        let mut mw = ModelWeights::new();
        mw.insert("a", WeightTensor::zeros(&[1]));
        mw.insert("b", WeightTensor::zeros(&[2]));
        assert!(mw.contains("a"));
        assert!(!mw.contains("c"));
        // HashMap iteration order is unspecified, so sort before comparing.
        let mut names: Vec<&str> = mw.iter().map(|(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[test]
    fn model_weights_n_params() {
        let mut mw = ModelWeights::new();
        mw.insert("a", WeightTensor::zeros(&[4, 4]));
        mw.insert("b", WeightTensor::zeros(&[3, 3]));
        // 16 + 9 elements.
        assert_eq!(mw.n_params(), 25);
    }

    #[test]
    fn model_weights_len_and_empty() {
        let mut mw = ModelWeights::new();
        assert!(mw.is_empty());
        mw.insert("x", WeightTensor::zeros(&[1]));
        assert_eq!(mw.len(), 1);
        assert!(!mw.is_empty());
    }
}