mathrc 0.2.5

Rust Mathematics Library
Documentation
use std::ops::{Add, AddAssign, Index, IndexMut, Mul};

use crate::err::MatrixError;
use num_traits::Float;

/// Dense two-dimensional matrix of floating-point elements.
///
/// Elements live in a single contiguous `Vec` in row-major order: the element at
/// row `r`, column `c` is stored at index `r * cols + c`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T>
where
    T: Float,
{
    /// Row-major element storage; the constructors keep its length equal to `rows * cols`.
    data: Vec<T>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}

impl<T: Float + AddAssign> Matrix<T> {
    /// Creates a `rows × cols` matrix from `data` laid out in row-major order.
    ///
    /// # Errors
    ///
    /// Returns [`MatrixError::InvalidSize`] when `data.len() != rows * cols`.
    /// If `rows * cols` overflows `usize`, `expected` is reported as `usize::MAX`.
    /// A `Vec` of any non-zero-sized element type cannot reach that length, so such a
    /// request is always rejected.
    pub fn new(data: Vec<T>, rows: usize, cols: usize) -> Result<Self, MatrixError> {
        // `saturating_mul` instead of `*`: a wrapped product could otherwise coincide with
        // `data.len()` and yield a matrix whose dimensions disagree with its storage.
        let expected = rows.saturating_mul(cols);
        if data.len() != expected {
            return Err(MatrixError::InvalidSize { expected, got: data.len() });
        }
        Ok(Self { data, rows, cols })
    }

    /// Returns the `n × n` identity matrix (ones on the diagonal, zeros elsewhere).
    ///
    /// # Panics
    ///
    /// Panics if `n * n` overflows `usize`.
    pub fn identity(n: usize) -> Self {
        let len = n
            .checked_mul(n)
            .expect("identity matrix size n * n overflows usize");
        let mut data = vec![T::zero(); len];
        for i in 0..n {
            data[i * n + i] = T::one();
        }
        Self { data, rows: n, cols: n }
    }

    /// Number of rows.
    pub fn rows(&self) -> usize { self.rows }
    /// Number of columns.
    pub fn cols(&self) -> usize { self.cols }

    /// Returns the element at row `r`, column `c`, or `None` if either index is out of range.
    pub fn get(&self, r: usize, c: usize) -> Option<T> {
        if r >= self.rows || c >= self.cols {
            return None;
        }
        Some(self.data[r * self.cols + c])
    }

    /// Overwrites the element at row `r`, column `c` with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `r >= self.rows()` or `c >= self.cols()`. The check is explicit because
    /// the storage is a flat row-major `Vec`: without it, an out-of-range column would
    /// silently overwrite an element in a neighbouring row.
    pub fn set(&mut self, r: usize, c: usize, value: T) {
        assert!(
            r < self.rows && c < self.cols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            r,
            c,
            self.rows,
            self.cols
        );
        self.data[r * self.cols + c] = value;
    }

    /// Returns the transpose: a `cols × rows` matrix whose `(c, r)` element is `self[(r, c)]`.
    pub fn transpose(&self) -> Self {
        let mut data = vec![T::zero(); self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                // In the transposed (cols × rows) layout, element (c, r) sits at c * rows + r.
                data[c * self.rows + r] = self[(r, c)];
            }
        }
        Self { data, rows: self.cols, cols: self.rows }
    }
}

impl<T: Float> Index<(usize, usize)> for Matrix<T> {
    type Output = T;

    /// Returns a reference to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`. Both dimensions are checked explicitly:
    /// the flat `Vec` bounds check alone would accept an out-of-range column and return
    /// an element from a neighbouring row.
    fn index(&self, (r, c): (usize, usize)) -> &Self::Output {
        assert!(
            r < self.rows && c < self.cols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            r,
            c,
            self.rows,
            self.cols
        );
        &self.data[r * self.cols + c]
    }
}

impl<T: Float> IndexMut<(usize, usize)> for Matrix<T> {
    /// Returns a mutable reference to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`. Both dimensions are checked explicitly:
    /// the flat `Vec` bounds check alone would let an out-of-range column silently write
    /// into a neighbouring row.
    fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut Self::Output {
        assert!(
            r < self.rows && c < self.cols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            r,
            c,
            self.rows,
            self.cols
        );
        &mut self.data[r * self.cols + c]
    }
}

impl<T: Float + AddAssign> Add for Matrix<T> {
    type Output = Result<Self, MatrixError>;

    fn add(self, rhs: Self) -> Self::Output {
        if self.rows != rhs.rows || self.cols != rhs.cols {
            return Err(MatrixError::DimensionMismatch {
                lhs: (self.rows, self.cols),
                rhs: (rhs.rows, rhs.cols),
            });
        }
        let data = self.data.iter().zip(rhs.data.iter()).map(|(a, b)| *a + *b).collect();
        Ok(Self { data, rows: self.rows, cols: self.cols })
    }
}

impl<T: Float + AddAssign> Mul for Matrix<T> {
    type Output = Result<Self, MatrixError>;

    /// Standard matrix product of an `m × n` matrix with an `n × p` matrix, giving `m × p`.
    ///
    /// # Errors
    ///
    /// Returns [`MatrixError::DimensionMismatch`] carrying both shapes when the inner
    /// dimensions (`self.cols` and `rhs.rows`) differ.
    fn mul(self, rhs: Self) -> Self::Output {
        let (m, n, p) = (self.rows, self.cols, rhs.cols);
        if n != rhs.rows {
            return Err(MatrixError::DimensionMismatch {
                lhs: (m, n),
                rhs: (rhs.rows, p),
            });
        }
        // Walk the output in row-major order; each flat position maps back to (row, col).
        // When p == 0 the range is empty, so the division below never runs.
        let data = (0..m * p)
            .map(|flat| {
                let (row, col) = (flat / p, flat % p);
                let mut dot = T::zero();
                for k in 0..n {
                    dot += self[(row, k)] * rhs[(k, col)];
                }
                dot
            })
            .collect();
        Ok(Self { data, rows: m, cols: p })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn get_element() {
        let m = Matrix::new(
            vec![0f64, 1f64, 0f64, 1f64, 0f64, 1f64, 1f64, 0f64, 1f64, 0f64, 1f64, 0f64],
            3, 4,
        ).unwrap();
        assert_eq!(m.get(1, 2), Some(1f64));
    }

    #[test]
    fn get_out_of_bounds_returns_none() {
        let m = Matrix::new(vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64], 2, 3).unwrap();
        assert_eq!(m.get(2, 0), None);
        assert_eq!(m.get(0, 3), None);
    }

    #[test]
    fn new_rejects_wrong_length() {
        let result = Matrix::new(vec![1f64, 2f64, 3f64], 2, 2);
        assert!(matches!(
            result,
            Err(MatrixError::InvalidSize { expected: 4, got: 3 })
        ));
    }

    #[test]
    fn identity_matrix() {
        let m = Matrix::<f64>::identity(2);
        assert_eq!(m.get(0, 0), Some(1f64));
        assert_eq!(m.get(0, 1), Some(0f64));
        assert_eq!(m.get(1, 1), Some(1f64));
    }

    #[test]
    fn transpose_swaps_rows_and_cols() {
        let m = Matrix::new(vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64], 2, 3).unwrap();
        let t = m.transpose();
        assert_eq!((t.rows(), t.cols()), (3, 2));
        assert_eq!(t.get(0, 1), Some(4f64));
        assert_eq!(t.get(2, 0), Some(3f64));
    }

    #[test]
    fn add_matrices() {
        let a = Matrix::new(vec![1f64, 2f64, 3f64, 4f64], 2, 2).unwrap();
        let b = Matrix::new(vec![5f64, 6f64, 7f64, 8f64], 2, 2).unwrap();
        let c = (a + b).unwrap();
        assert_eq!(c.get(0, 0), Some(6f64));
        assert_eq!(c.get(1, 1), Some(12f64));
    }

    #[test]
    fn add_rejects_mismatched_shapes() {
        let a = Matrix::new(vec![1f64, 2f64, 3f64, 4f64], 2, 2).unwrap();
        let b = Matrix::new(vec![1f64, 2f64], 1, 2).unwrap();
        assert!(matches!(
            a + b,
            Err(MatrixError::DimensionMismatch { lhs: (2, 2), rhs: (1, 2) })
        ));
    }

    #[test]
    fn mul_matrices() {
        let a = Matrix::new(vec![1f64, 2f64, 3f64, 4f64], 2, 2).unwrap();
        let b = Matrix::new(vec![5f64, 6f64, 7f64, 8f64], 2, 2).unwrap();
        let c = (a * b).unwrap();
        assert_eq!(c.get(0, 0), Some(19f64));
        assert_eq!(c.get(1, 1), Some(50f64));
    }

    #[test]
    fn mul_non_square() {
        // (2x3) * (3x1) -> (2x1): each output is the row sum of `a`.
        let a = Matrix::new(vec![1f64, 2f64, 3f64, 4f64, 5f64, 6f64], 2, 3).unwrap();
        let b = Matrix::new(vec![1f64, 1f64, 1f64], 3, 1).unwrap();
        let c = (a * b).unwrap();
        assert_eq!((c.rows(), c.cols()), (2, 1));
        assert_eq!(c.get(0, 0), Some(6f64));
        assert_eq!(c.get(1, 0), Some(15f64));
    }

    #[test]
    fn mul_rejects_incompatible_inner_dims() {
        let a = Matrix::new(vec![1f64, 2f64, 3f64, 4f64], 2, 2).unwrap();
        let b = Matrix::new(vec![1f64, 2f64], 1, 2).unwrap();
        assert!(matches!(
            a * b,
            Err(MatrixError::DimensionMismatch { lhs: (2, 2), rhs: (1, 2) })
        ));
    }
}