use std::collections::HashMap;
use crate::error::{LmError, LmResult};
/// A dense `f32` tensor stored as a flat buffer together with its shape.
///
/// The constructors on this type keep `data.len()` equal to the product of
/// `shape`. Both fields are public, so code that edits them directly is
/// responsible for keeping the two consistent.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Flat element buffer; 2-D tensors are addressed row by row
    /// (element `(r, c)` lives at index `r * cols + c`).
    pub data: Vec<f32>,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Returns the number of elements described by `shape`, or `None` when
    /// that count does not fit in a `usize`.
    ///
    /// A shape containing a zero extent always describes zero elements, even
    /// if multiplying the remaining extents would overflow. The empty shape
    /// describes one element (the empty product).
    fn checked_element_count(shape: &[usize]) -> Option<usize> {
        if shape.contains(&0) {
            return Some(0);
        }
        shape
            .iter()
            .try_fold(1_usize, |count, &dim| count.checked_mul(dim))
    }

    /// Builds a tensor of the given shape with every element set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` overflows `usize`.
    fn filled(shape: &[usize], value: f32) -> Self {
        let count = Self::checked_element_count(shape)
            .expect("tensor shape element count overflows usize");
        Self {
            data: vec![value; count],
            shape: shape.to_vec(),
        }
    }

    /// Wraps an existing buffer in a tensor of the given shape.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::WeightDataLengthMismatch`] when `data.len()` differs
    /// from the product of `shape`. If that product does not fit in a `usize`
    /// no buffer can match it; the error then reports `expected` saturated to
    /// `usize::MAX` instead of overflowing.
    pub fn from_data(data: Vec<f32>, shape: Vec<usize>) -> LmResult<Self> {
        let expected = Self::checked_element_count(&shape);
        if expected == Some(data.len()) {
            return Ok(Self { data, shape });
        }
        Err(LmError::WeightDataLengthMismatch {
            data_len: data.len(),
            shape,
            expected: expected.unwrap_or(usize::MAX),
        })
    }

    /// Creates a tensor of the given shape filled with `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` overflows `usize`.
    pub fn zeros(shape: &[usize]) -> Self {
        Self::filled(shape, 0.0)
    }

    /// Creates a tensor of the given shape filled with `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` overflows `usize`.
    pub fn ones(shape: &[usize]) -> Self {
        Self::filled(shape, 1.0)
    }

    /// Creates a `rows x cols` matrix with `1.0` on the main diagonal and
    /// `0.0` everywhere else. The matrix does not have to be square.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn eye(rows: usize, cols: usize) -> Self {
        let count = rows
            .checked_mul(cols)
            .expect("tensor shape element count overflows usize");
        let mut data = vec![0.0_f32; count];
        // The diagonal has min(rows, cols) entries; entry i sits at (i, i).
        for i in 0..rows.min(cols) {
            data[i * cols + i] = 1.0;
        }
        Self {
            data,
            shape: vec![rows, cols],
        }
    }

    /// Number of elements held in the data buffer.
    pub fn n_elements(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions recorded in `shape`.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Borrows the flat data buffer.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// Mutably borrows the flat data buffer.
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Borrows row `row` of a 2-D tensor.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::DimensionMismatch`] with `expected: 2` when the
    /// tensor is not 2-D. Returns [`LmError::DimensionMismatch`] carrying the
    /// requested row and the row count when `row` is out of range or the data
    /// buffer is too short to contain that row.
    pub fn row_slice(&self, row: usize) -> LmResult<&[f32]> {
        let (rows, cols) = match self.shape.as_slice() {
            &[rows, cols] => (rows, cols),
            _ => {
                return Err(LmError::DimensionMismatch {
                    expected: 2,
                    got: self.shape.len(),
                })
            }
        };
        // Compare against the row count explicitly: for zero-width rows the
        // buffer-length check alone would accept every row index.
        if row < rows {
            // Checked arithmetic keeps a huge `row` from overflowing into a
            // panic (debug) or a wrapped, wrong offset (release).
            let span = row
                .checked_mul(cols)
                .and_then(|start| Some(start..start.checked_add(cols)?));
            // `get` also covers a buffer that was shortened through the
            // public `data` field.
            if let Some(values) = span.and_then(|range| self.data.get(range)) {
                return Ok(values);
            }
        }
        Err(LmError::DimensionMismatch {
            expected: row,
            got: rows,
        })
    }

    /// Checks that this tensor's shape equals `expected`.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::WeightShapeMismatch`] on a mismatch. The tensor does
    /// not know its own name, so the error's `name` field is left empty.
    pub fn validate_shape(&self, expected: &[usize]) -> LmResult<()> {
        if self.shape.as_slice() == expected {
            return Ok(());
        }
        Err(LmError::WeightShapeMismatch {
            name: String::new(),
            expected: expected.to_vec(),
            got: self.shape.clone(),
        })
    }
}
#[derive(Debug, Clone, Default)]
pub struct ModelWeights {
weights: HashMap<String, WeightTensor>,
}
impl ModelWeights {
pub fn new() -> Self {
Self::default()
}
pub fn insert(&mut self, name: impl Into<String>, tensor: WeightTensor) {
self.weights.insert(name.into(), tensor);
}
pub fn get(&self, name: &str) -> LmResult<&WeightTensor> {
self.weights
.get(name)
.ok_or_else(|| LmError::WeightNotFound { name: name.into() })
}
pub fn get_checked(&self, name: &str, expected_shape: &[usize]) -> LmResult<&WeightTensor> {
let t = self.get(name)?;
if t.shape != expected_shape {
return Err(LmError::WeightShapeMismatch {
name: name.into(),
expected: expected_shape.to_vec(),
got: t.shape.clone(),
});
}
Ok(t)
}
pub fn contains(&self, name: &str) -> bool {
self.weights.contains_key(name)
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &WeightTensor)> {
self.weights.iter().map(|(k, v)| (k.as_str(), v))
}
pub fn n_params(&self) -> usize {
self.weights.values().map(|t| t.n_elements()).sum()
}
pub fn len(&self) -> usize {
self.weights.len()
}
pub fn is_empty(&self) -> bool {
self.weights.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- WeightTensor constructors ----

    /// `zeros` yields the requested shape with every element equal to 0.0.
    #[test]
    fn weight_tensor_zeros() {
        let tensor = WeightTensor::zeros(&[4, 8]);
        assert_eq!(tensor.n_elements(), 32);
        for value in &tensor.data {
            assert_eq!(*value, 0.0);
        }
        assert_eq!(tensor.shape, vec![4, 8]);
    }

    /// `ones` yields the requested element count with every element equal to 1.0.
    #[test]
    fn weight_tensor_ones() {
        let tensor = WeightTensor::ones(&[3, 3]);
        assert_eq!(tensor.n_elements(), 9);
        for value in &tensor.data {
            assert_eq!(*value, 1.0);
        }
    }

    /// `eye` puts 1.0 on the diagonal (flat indices 0, 4, 8 for 3x3) and 0.0 off it.
    #[test]
    fn weight_tensor_eye() {
        let identity = WeightTensor::eye(3, 3);
        assert_eq!(identity.data[0], 1.0);
        assert_eq!(identity.data[1], 0.0);
        for &diagonal in &[4_usize, 8] {
            assert_eq!(identity.data[diagonal], 1.0);
        }
    }

    /// `from_data` accepts a buffer whose length matches the shape.
    #[test]
    fn weight_tensor_from_data_ok() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0];
        let tensor = WeightTensor::from_data(values.clone(), vec![2, 2])
            .expect("4 elements with shape [2,2] should match");
        assert_eq!(tensor.data, values);
    }

    /// `from_data` rejects a buffer whose length disagrees with the shape.
    #[test]
    fn weight_tensor_from_data_shape_mismatch() {
        let five_values = vec![1.0_f32; 5];
        let error = WeightTensor::from_data(five_values, vec![2, 2]).unwrap_err();
        assert!(matches!(error, LmError::WeightDataLengthMismatch { .. }));
    }

    // ---- WeightTensor accessors ----

    /// `row_slice` returns each row of a 2x2 tensor.
    #[test]
    fn weight_tensor_row_slice() {
        let tensor = WeightTensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("4 elements with shape [2,2] should match");

        let first = tensor
            .row_slice(0)
            .expect("row 0 of 2x2 tensor should exist");
        assert_eq!(first, &[1.0_f32, 2.0]);

        let second = tensor
            .row_slice(1)
            .expect("row 1 of 2x2 tensor should exist");
        assert_eq!(second, &[3.0_f32, 4.0]);
    }

    /// `row_slice` refuses tensors that are not 2-D.
    #[test]
    fn weight_tensor_row_slice_non_2d_errors() {
        let vector = WeightTensor::zeros(&[8]);
        assert!(vector.row_slice(0).is_err());
    }

    /// `validate_shape` succeeds on an exact match.
    #[test]
    fn weight_tensor_validate_shape_ok() {
        let tensor = WeightTensor::zeros(&[4, 8]);
        tensor
            .validate_shape(&[4, 8])
            .expect("validate_shape should succeed when shape matches");
    }

    /// `validate_shape` fails when the dimensions are swapped.
    #[test]
    fn weight_tensor_validate_shape_fail() {
        let tensor = WeightTensor::zeros(&[4, 8]);
        let outcome = tensor.validate_shape(&[8, 4]);
        assert!(outcome.is_err());
    }

    // ---- ModelWeights ----

    /// A tensor can be fetched back under the name it was inserted with.
    #[test]
    fn model_weights_insert_and_get() {
        let mut weights = ModelWeights::new();
        weights.insert("embed", WeightTensor::zeros(&[10, 4]));
        let embed = weights
            .get("embed")
            .expect("'embed' key should exist after insertion");
        assert_eq!(embed.shape, vec![10, 4]);
    }

    /// Looking up an unknown name reports `WeightNotFound`.
    #[test]
    fn model_weights_get_missing_errors() {
        let weights = ModelWeights::new();
        let lookup = weights.get("missing");
        assert!(matches!(lookup, Err(LmError::WeightNotFound { .. })));
    }

    /// `get_checked` reports `WeightShapeMismatch` for a wrong shape.
    #[test]
    fn model_weights_get_checked_shape_error() {
        let mut weights = ModelWeights::new();
        weights.insert("w", WeightTensor::zeros(&[4, 8]));
        let lookup = weights.get_checked("w", &[8, 4]);
        assert!(matches!(lookup, Err(LmError::WeightShapeMismatch { .. })));
    }

    /// `n_params` sums the element counts of all tensors (16 + 9).
    #[test]
    fn model_weights_n_params() {
        let mut weights = ModelWeights::new();
        weights.insert("a", WeightTensor::zeros(&[4, 4]));
        weights.insert("b", WeightTensor::zeros(&[3, 3]));
        assert_eq!(weights.n_params(), 25);
    }

    /// `len` and `is_empty` track insertions.
    #[test]
    fn model_weights_len_and_empty() {
        let mut weights = ModelWeights::new();
        assert!(weights.is_empty());
        weights.insert("x", WeightTensor::zeros(&[1]));
        assert_eq!(weights.len(), 1);
        assert!(!weights.is_empty());
    }
}