mathhook-core 0.2.0

Core mathematical engine for MathHook - expressions, algebra, and solving
Documentation
use crate::core::matrix::numeric_matrix::NumericMatrix;
use crate::error::MathError;
use std::ops::{Add, Mul, Neg, Sub};

impl NumericMatrix {
    pub fn add(&self, other: &Self) -> Result<Self, MathError> {
        if self.rows != other.rows || self.cols != other.cols {
            return Err(MathError::DomainError {
                operation: "NumericMatrix::add".to_string(),
                value: crate::Expression::integer(0),
                reason: format!(
                    "Cannot add matrices with different dimensions: {}x{} and {}x{}",
                    self.rows, self.cols, other.rows, other.cols
                ),
            });
        }

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        Ok(Self {
            rows: self.rows,
            cols: self.cols,
            data,
        })
    }

    pub fn sub(&self, other: &Self) -> Result<Self, MathError> {
        if self.rows != other.rows || self.cols != other.cols {
            return Err(MathError::DomainError {
                operation: "NumericMatrix::sub".to_string(),
                value: crate::Expression::integer(0),
                reason: format!(
                    "Cannot subtract matrices with different dimensions: {}x{} and {}x{}",
                    self.rows, self.cols, other.rows, other.cols
                ),
            });
        }

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a - b)
            .collect();

        Ok(Self {
            rows: self.rows,
            cols: self.cols,
            data,
        })
    }

    pub fn scalar_mul(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|x| x * scalar).collect();
        Self {
            rows: self.rows,
            cols: self.cols,
            data,
        }
    }

    pub fn negate(&self) -> Self {
        let data: Vec<f64> = self.data.iter().map(|x| -x).collect();
        Self {
            rows: self.rows,
            cols: self.cols,
            data,
        }
    }
}

/// `a + b` for owned matrices.
///
/// Both operands are consumed; the result is whatever the inherent
/// `NumericMatrix::add` returns for references to them.
impl Add for NumericMatrix {
    type Output = Result<NumericMatrix, MathError>;

    fn add(self, rhs: NumericMatrix) -> Result<NumericMatrix, MathError> {
        // The path form selects the inherent `add(&self, &Self)`, not this trait method.
        NumericMatrix::add(&self, &rhs)
    }
}

/// `&a + &b` for borrowed matrices.
///
/// Neither operand is consumed; the call is forwarded unchanged to the
/// inherent `NumericMatrix::add`.
impl Add for &NumericMatrix {
    type Output = Result<NumericMatrix, MathError>;

    fn add(self, rhs: Self) -> Result<NumericMatrix, MathError> {
        NumericMatrix::add(self, rhs)
    }
}

/// `a - b` for owned matrices.
///
/// Both operands are consumed; the result is whatever the inherent
/// `NumericMatrix::sub` returns for references to them.
impl Sub for NumericMatrix {
    type Output = Result<NumericMatrix, MathError>;

    fn sub(self, rhs: NumericMatrix) -> Result<NumericMatrix, MathError> {
        // The path form selects the inherent `sub(&self, &Self)`, not this trait method.
        NumericMatrix::sub(&self, &rhs)
    }
}

/// `&a - &b` for borrowed matrices.
///
/// Neither operand is consumed; the call is forwarded unchanged to the
/// inherent `NumericMatrix::sub`.
impl Sub for &NumericMatrix {
    type Output = Result<NumericMatrix, MathError>;

    fn sub(self, rhs: Self) -> Result<NumericMatrix, MathError> {
        NumericMatrix::sub(self, rhs)
    }
}

impl Mul<f64> for NumericMatrix {
    type Output = Self;

    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl Mul<f64> for &NumericMatrix {
    type Output = NumericMatrix;

    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl Mul<NumericMatrix> for f64 {
    type Output = NumericMatrix;

    fn mul(self, matrix: NumericMatrix) -> Self::Output {
        matrix.scalar_mul(self)
    }
}

impl Mul<&NumericMatrix> for f64 {
    type Output = NumericMatrix;

    fn mul(self, matrix: &NumericMatrix) -> Self::Output {
        matrix.scalar_mul(self)
    }
}

impl Neg for NumericMatrix {
    type Output = Self;

    fn neg(self) -> Self::Output {
        self.negate()
    }
}

impl Neg for &NumericMatrix {
    type Output = NumericMatrix;

    fn neg(self) -> Self::Output {
        self.negate()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2x2 matrix by handing the four values to `from_flat`.
    fn square(values: [f64; 4]) -> NumericMatrix {
        NumericMatrix::from_flat(2, 2, values.to_vec()).unwrap()
    }

    // Inherent methods: every cell of the result is checked.

    #[test]
    fn test_add_method() {
        let lhs = square([1.0, 2.0, 3.0, 4.0]);
        let rhs = square([5.0, 6.0, 7.0, 8.0]);
        let sum = NumericMatrix::add(&lhs, &rhs).unwrap();
        for &(row, col, want) in &[(0, 0, 6.0), (0, 1, 8.0), (1, 0, 10.0), (1, 1, 12.0)] {
            assert_eq!(sum.get(row, col).unwrap(), want);
        }
    }

    #[test]
    fn test_sub_method() {
        let lhs = square([5.0, 6.0, 7.0, 8.0]);
        let rhs = square([1.0, 2.0, 3.0, 4.0]);
        let diff = NumericMatrix::sub(&lhs, &rhs).unwrap();
        for &(row, col, want) in &[(0, 0, 4.0), (0, 1, 4.0), (1, 0, 4.0), (1, 1, 4.0)] {
            assert_eq!(diff.get(row, col).unwrap(), want);
        }
    }

    #[test]
    fn test_scalar_mul() {
        let doubled = square([1.0, 2.0, 3.0, 4.0]).scalar_mul(2.0);
        for &(row, col, want) in &[(0, 0, 2.0), (0, 1, 4.0), (1, 0, 6.0), (1, 1, 8.0)] {
            assert_eq!(doubled.get(row, col).unwrap(), want);
        }
    }

    #[test]
    fn test_negate() {
        let flipped = square([1.0, 2.0, 3.0, 4.0]).negate();
        for &(row, col, want) in &[(0, 0, -1.0), (0, 1, -2.0), (1, 0, -3.0), (1, 1, -4.0)] {
            assert_eq!(flipped.get(row, col).unwrap(), want);
        }
    }

    // Operator traits: only the first cell is spot-checked.

    #[test]
    fn test_add_trait() {
        let lhs = square([1.0, 2.0, 3.0, 4.0]);
        let rhs = square([5.0, 6.0, 7.0, 8.0]);
        let sum = (&lhs + &rhs).unwrap();
        assert_eq!(sum.get(0, 0).unwrap(), 6.0);
    }

    #[test]
    fn test_sub_trait() {
        let lhs = square([5.0, 6.0, 7.0, 8.0]);
        let rhs = square([1.0, 2.0, 3.0, 4.0]);
        let diff = (&lhs - &rhs).unwrap();
        assert_eq!(diff.get(0, 0).unwrap(), 4.0);
    }

    #[test]
    fn test_mul_trait() {
        let base = square([1.0, 2.0, 3.0, 4.0]);

        // Scalar on the right.
        let doubled = &base * 2.0;
        assert_eq!(doubled.get(0, 0).unwrap(), 2.0);

        // Scalar on the left.
        let tripled = 3.0 * &base;
        assert_eq!(tripled.get(0, 0).unwrap(), 3.0);
    }

    #[test]
    fn test_neg_trait() {
        let base = square([1.0, 2.0, 3.0, 4.0]);
        let flipped = -&base;
        assert_eq!(flipped.get(0, 0).unwrap(), -1.0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let two_by_two = square([1.0, 2.0, 3.0, 4.0]);
        let two_by_three =
            NumericMatrix::from_flat(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let sum = NumericMatrix::add(&two_by_two, &two_by_three);
        let diff = NumericMatrix::sub(&two_by_two, &two_by_three);
        assert!(sum.is_err());
        assert!(diff.is_err());
    }
}