lmm 0.1.6

A language-agnostic framework for emulating reality.
Documentation
use crate::error::LmmError::Simulation;
use crate::error::Result;
use rand::RngExt;
use std::ops::{Add, Mul, Sub};

/// A dense, n-dimensional array of `f64` values.
///
/// `shape` lists the extent of each dimension and `data` stores every element
/// in one flat buffer. Because both fields are public, nothing on the type
/// itself guarantees that `data.len()` equals the product of `shape`; use
/// `Tensor::new` when that check is wanted.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    /// Size of each dimension, outermost first.
    pub shape: Vec<usize>,
    /// Flattened element storage.
    pub data: Vec<f64>,
}

impl Tensor {
    /// Builds a tensor from an explicit shape and a flat data buffer.
    ///
    /// The buffer is interpreted row-major by [`Tensor::transpose`] and
    /// [`Tensor::matmul`]. An empty `shape` has a product of 1, so it
    /// describes a single-element (scalar) tensor.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error when `data.len()` differs from the product
    /// of the dimensions in `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> Result<Self> {
        let expected_len: usize = shape.iter().product();
        if data.len() != expected_len {
            return Err(Simulation(format!(
                "Tensor shape mismatch: expected {} elements but got {}",
                expected_len,
                data.len()
            )));
        }
        Ok(Self { shape, data })
    }

    /// Creates a tensor of the given shape with every element set to `0.0`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        Self::fill(shape, 0.0)
    }

    /// Creates a tensor of the given shape with every element set to `1.0`.
    pub fn ones(shape: Vec<usize>) -> Self {
        Self::fill(shape, 1.0)
    }

    /// Creates a tensor of the given shape with every element set to `value`.
    pub fn fill(shape: Vec<usize>, value: f64) -> Self {
        let len: usize = shape.iter().product();
        Self {
            shape,
            data: vec![value; len],
        }
    }

    /// Creates a tensor of normally distributed samples with the given `mean`
    /// and standard deviation `std`.
    ///
    /// Each element is drawn independently using the cosine branch of the
    /// Box–Muller transform from the thread-local RNG (`rand::rng()`).
    pub fn randn(shape: Vec<usize>, mean: f64, std: f64) -> Self {
        let len: usize = shape.iter().product();
        let mut rng = rand::rng();
        let data: Vec<f64> = (0..len)
            .map(|_| {
                // Clamp away from zero so `ln` below stays finite.
                let u1: f64 = rng.random::<f64>().max(1e-10);
                let u2: f64 = rng.random();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                mean + std * z
            })
            .collect();
        Self { shape, data }
    }

    /// Wraps `data` as a 1-D tensor of shape `[data.len()]`.
    pub fn from_vec(data: Vec<f64>) -> Self {
        let len = data.len();
        Self {
            shape: vec![len],
            data,
        }
    }

    /// Returns a new tensor with every element multiplied by `factor`.
    pub fn scale(&self, factor: f64) -> Self {
        Self {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&x| x * factor).collect(),
        }
    }

    /// Returns a new tensor with `f` applied to every element; the shape is kept.
    pub fn map<F: Fn(f64) -> f64>(&self, f: F) -> Self {
        Self {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&x| f(x)).collect(),
        }
    }

    /// Combines `self` and `other` element by element with `f(self_elem, other_elem)`.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error if the two shapes differ.
    pub fn zip_map<F: Fn(f64, f64) -> f64>(&self, other: &Self, f: F) -> Result<Self> {
        if self.shape != other.shape {
            return Err(Simulation("zip_map shape mismatch".into()));
        }
        let data = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| f(a, b))
            .collect();
        Ok(Self {
            shape: self.shape.clone(),
            data,
        })
    }

    /// Sum of element-wise products of `self` and `other`.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error if the two shapes differ.
    pub fn dot(&self, other: &Self) -> Result<f64> {
        if self.shape != other.shape {
            return Err(Simulation("dot product shape mismatch".into()));
        }
        Ok(self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .sum())
    }

    /// Euclidean (L2) norm: the square root of the sum of squared elements.
    pub fn norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Arithmetic mean of all elements, or `0.0` for an empty tensor.
    pub fn mean(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.data.iter().sum::<f64>() / self.data.len() as f64
    }

    /// Population variance (divides by `n`, not `n - 1`) of all elements.
    ///
    /// Returns `0.0` for an empty tensor, consistent with [`Tensor::mean`].
    /// Without that guard, the division below would be 0/0 and return NaN.
    pub fn variance(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        let m = self.mean();
        self.data.iter().map(|x| (x - m).powi(2)).sum::<f64>() / self.data.len() as f64
    }

    /// Index (into the flat `data`) of the largest element.
    ///
    /// Ties resolve to the earliest index because the comparison is strict.
    /// Returns `0` for an empty tensor. Because `>` is always false when either
    /// side is NaN, a NaN at index 0 is never replaced, and a NaN elsewhere is
    /// never selected.
    pub fn argmax(&self) -> usize {
        self.data.iter().enumerate().fold(
            0,
            |best, (i, &v)| {
                if v > self.data[best] { i } else { best }
            },
        )
    }

    /// Returns a copy of this tensor viewed with `new_shape`; the data order is unchanged.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error if `new_shape` does not describe exactly
    /// `self.len()` elements.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self> {
        let expected: usize = new_shape.iter().product();
        if expected != self.data.len() {
            return Err(Simulation(format!(
                "Cannot reshape {} elements into shape {:?}",
                self.data.len(),
                new_shape
            )));
        }
        Ok(Self {
            shape: new_shape,
            data: self.data.clone(),
        })
    }

    /// Transposes a 2-D (row-major) tensor from `[rows, cols]` to `[cols, rows]`.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error if the tensor is not 2-D.
    ///
    /// # Panics
    ///
    /// Panics on out-of-bounds indexing if `data` holds fewer than
    /// `rows * cols` elements. This can only happen when the public fields
    /// were set inconsistently.
    pub fn transpose(&self) -> Result<Self> {
        if self.shape.len() != 2 {
            return Err(Simulation("transpose requires a 2-D tensor".into()));
        }
        let rows = self.shape[0];
        let cols = self.shape[1];
        let mut data = vec![0.0; rows * cols];
        for r in 0..rows {
            for c in 0..cols {
                data[c * rows + r] = self.data[r * cols + c];
            }
        }
        Ok(Self {
            shape: vec![cols, rows],
            data,
        })
    }

    /// Matrix product of a `[m, k]` tensor with a `[k, n]` tensor, yielding `[m, n]`.
    ///
    /// Both operands are read as row-major.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error if either operand is not 2-D, or if the
    /// inner dimensions differ.
    ///
    /// # Panics
    ///
    /// Panics on out-of-bounds indexing if an operand's `data` is shorter than
    /// its `shape` implies.
    pub fn matmul(&self, other: &Self) -> Result<Self> {
        if self.shape.len() != 2 || other.shape.len() != 2 {
            return Err(Simulation("matmul requires 2-D tensors".into()));
        }
        let (m, k) = (self.shape[0], self.shape[1]);
        let (k2, n) = (other.shape[0], other.shape[1]);
        if k != k2 {
            return Err(Simulation(format!(
                "matmul dimension mismatch: [{m}x{k}] vs [{k2}x{n}]"
            )));
        }
        let mut data = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for l in 0..k {
                    sum += self.data[i * k + l] * other.data[l * n + j];
                }
                data[i * n + j] = sum;
            }
        }
        Ok(Self {
            shape: vec![m, n],
            data,
        })
    }

    /// Number of stored elements (`data.len()`).
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the tensor stores no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}

impl Add for &Tensor {
    type Output = Result<Tensor>;

    /// Adds two tensors element by element.
    ///
    /// Both operands must have exactly the same `shape`; no broadcasting is
    /// done. The result takes the left operand's shape.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error naming both shapes when they differ.
    fn add(self, rhs: Self) -> Self::Output {
        if self.shape == rhs.shape {
            let summed: Vec<f64> = self
                .data
                .iter()
                .zip(&rhs.data)
                .map(|(&lhs_val, &rhs_val)| lhs_val + rhs_val)
                .collect();
            Ok(Tensor {
                shape: self.shape.clone(),
                data: summed,
            })
        } else {
            Err(Simulation(format!(
                "Tensor add shape mismatch: {:?} vs {:?}",
                self.shape, rhs.shape
            )))
        }
    }
}

impl Sub for &Tensor {
    type Output = Result<Tensor>;

    /// Subtracts `rhs` from `self` element by element.
    ///
    /// Both operands must have exactly the same `shape`; no broadcasting is
    /// done. The result takes the left operand's shape.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error naming both shapes when they differ.
    fn sub(self, rhs: Self) -> Self::Output {
        if self.shape == rhs.shape {
            let difference: Vec<f64> = self
                .data
                .iter()
                .zip(&rhs.data)
                .map(|(&lhs_val, &rhs_val)| lhs_val - rhs_val)
                .collect();
            Ok(Tensor {
                shape: self.shape.clone(),
                data: difference,
            })
        } else {
            Err(Simulation(format!(
                "Tensor sub shape mismatch: {:?} vs {:?}",
                self.shape, rhs.shape
            )))
        }
    }
}

impl Mul for &Tensor {
    type Output = Result<Tensor>;

    /// Multiplies two tensors element by element (Hadamard product).
    ///
    /// This is not a matrix product (`Tensor::matmul` provides that). Both
    /// operands must have exactly the same `shape`, and the result takes the
    /// left operand's shape.
    ///
    /// # Errors
    ///
    /// Returns a `Simulation` error naming both shapes when they differ.
    fn mul(self, rhs: Self) -> Self::Output {
        if self.shape == rhs.shape {
            let product: Vec<f64> = self
                .data
                .iter()
                .zip(&rhs.data)
                .map(|(&lhs_val, &rhs_val)| lhs_val * rhs_val)
                .collect();
            Ok(Tensor {
                shape: self.shape.clone(),
                data: product,
            })
        } else {
            Err(Simulation(format!(
                "Tensor mul shape mismatch: {:?} vs {:?}",
                self.shape, rhs.shape
            )))
        }
    }
}