use crate::{error::MatrixError, matrix::Matrix, vector::VectorN};
use super::mtuple::Direction;
impl<const R: usize, const C: usize> Matrix<R, C> {
pub fn scalar_multiply(&mut self, a: f64) {
let matrix_prime: Vec<f64> = self.matrix.iter_mut().map(|f| *f * a).collect();
self.matrix = matrix_prime
}
pub fn multiply<const K: usize>(&self, other: &Matrix<C, K>) -> Matrix<R, K> {
let horizontal_tuples: Vec<VectorN<C>> = (0..R)
.map(|i| VectorN::new(self.get_tuple(i, Direction::Row).try_into().unwrap()))
.collect();
let vertical_tuples: Vec<VectorN<C>> = (0..C)
.map(|i| VectorN::new(other.get_tuple(i, Direction::Column).try_into().unwrap()))
.collect();
let dotproducts: Vec<f64> = (0..(R * K))
.map(|i| {
let x = i % K;
let y = i / C;
horizontal_tuples[y]
.dotproduct(&vertical_tuples[x])
.unwrap()
})
.collect();
Matrix::new(dotproducts).unwrap()
}
}
impl<const R: usize, const C: usize> std::ops::Add for Matrix<R, C> {
type Output = Result<Matrix<R, C>, MatrixError>;
fn add(self, rhs: Matrix<R, C>) -> Self::Output {
let this_data = self.matrix;
let other_data = rhs.matrix;
let new_matrix_data: Vec<f64> = (0..this_data.len())
.map(|i| this_data[i] + other_data[i])
.collect();
match Matrix::new(new_matrix_data) {
Ok(matrix) => Ok(matrix),
Err(error) => Err(error),
}
}
}
impl<const R: usize, const C: usize> std::ops::Sub for Matrix<R, C> {
type Output = Result<Matrix<R, C>, MatrixError>;
fn sub(self, rhs: Matrix<R, C>) -> Self::Output {
let this_data = self.matrix;
let other_data = rhs.matrix;
let new_matrix_data: Vec<f64> = (0..this_data.len())
.map(|i| this_data[i] - other_data[i])
.collect();
match Matrix::new(new_matrix_data) {
Ok(matrix) => Ok(matrix),
Err(error) => Err(error),
}
}
}
impl<const R: usize, const C: usize> std::cmp::PartialEq for Matrix<R, C> {
fn eq(&self, other: &Self) -> bool {
self.matrix == other.matrix
}
}
#[cfg(test)]
mod matrixop_tests {
use super::*;
#[test]
fn test_scalar_multiply() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let mut matrix: Matrix<2, 2> = Matrix::new(data).unwrap();
matrix.scalar_multiply(3.0);
assert_eq!(matrix.data(), &[3.0, 6.0, 9.0, 12.0]);
matrix.scalar_multiply(std::f64::consts::PI);
assert_eq!(
matrix.data(),
&[
3.0 * std::f64::consts::PI,
6.0 * std::f64::consts::PI,
9.0 * std::f64::consts::PI,
12.0 * std::f64::consts::PI
]
);
}
#[test]
fn test_matrix_addition() {
let entries_a = vec![0.0, 1.0, 2.0, 3.0];
let a: Matrix<2, 2> = Matrix::new(entries_a).unwrap();
let entries_b = vec![6.9, 4.2, 2.3, 5.7];
let b: Matrix<2, 2> = Matrix::new(entries_b).unwrap();
let c = (a + b).unwrap();
assert_eq!(c.data(), &[6.9, 5.2, 4.3, 8.7]);
}
#[test]
fn test_matrix_subtraction() {
let entries_a = vec![0.0, 1.0, 2.0, 3.0];
let a: Matrix<2, 2> = Matrix::new(entries_a).unwrap();
let entries_b = vec![6.9, 4.2, 2.3, 5.7];
let b: Matrix<2, 2> = Matrix::new(entries_b).unwrap();
let c = (a - b).unwrap();
assert_eq!(c.data(), &[-6.9, -3.2, -0.2999999999999998, -2.7]);
}
#[test]
fn test_matrix_multiplication_2x2() {
let data_a = vec![1.0, 2.0, 3.0, 4.0];
let data_b = vec![5.0, 2.0, 3.0, 4.0];
let a: Matrix<2, 2> = Matrix::new(data_a).unwrap();
let b: Matrix<2, 2> = Matrix::new(data_b).unwrap();
let c = a.multiply(&b);
assert_eq!(c.data(), &[11.0, 10.0, 27.0, 22.0]);
}
#[test]
fn test_matrix_multiplication_mxn() {
let data_a = vec![1.0, 2.0, 3.0, 4.0, 7.0, 9.0];
let data_b = vec![5.0, 2.0, 6.0, 5.0];
let a: Matrix<3, 2> = Matrix::new(data_a).unwrap();
let b: Matrix<2, 2> = Matrix::new(data_b).unwrap();
let c: Matrix<3, 2> = a.multiply(&b);
assert_eq!(c.data(), &[17.0, 12.0, 39.0, 26.0, 89.0, 59.0]);
}
#[test]
fn test_true_matrix_equivalence() {
let data_a = vec![1.0, 2.0, 3.0, 4.0];
let data_b = vec![1.0, 2.0, 3.0, 4.0];
let a: Matrix<2, 2> = Matrix::new(data_a).unwrap();
let b: Matrix<2, 2> = Matrix::new(data_b).unwrap();
assert_eq!(a, b);
}
#[test]
fn test_false_matrix_equivalence() {
let data_a = vec![1.0, 2.0, 3.0, 4.0];
let data_b = vec![5.0, 6.0, 7.0, 8.0];
let a: Matrix<2, 2> = Matrix::new(data_a).unwrap();
let b: Matrix<2, 2> = Matrix::new(data_b).unwrap();
assert_ne!(a, b);
}
}