use std::ops::{Add, Sub, Mul, Div};
use std::cmp::{PartialEq, Eq};
/// A dense matrix of `f64` values stored as a vector of rows.
///
/// `rows` and `cols` record the declared shape, while `data` holds one inner
/// vector per row. Nothing in the type itself forces `data` to agree with the
/// declared shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Declared number of rows.
    rows: usize,
    /// Declared number of columns.
    cols: usize,
    /// Element storage: the outer vector holds rows, each inner vector holds
    /// the entries of one row.
    data: Vec<Vec<f64>>,
}
impl Matrix {
pub fn new(rows: usize, cols: usize, data: Vec<Vec<f64>>) -> Self {
Matrix { rows, cols, data }
}
pub fn check(&self) -> Option<usize> {
if self.data.len() != self.rows {
return None; }
if self.data.iter().any(|row| row.len() != self.cols) {
return None; }
Some(1)
}
}
impl Add for Matrix {
    type Output = Matrix;

    /// Adds two matrices element by element.
    ///
    /// Rows and entries are paired positionally. The result carries the
    /// declared shape of `self`.
    ///
    /// # Panics
    ///
    /// Panics if the declared `rows` or `cols` of the operands differ.
    fn add(self, other: Matrix) -> Matrix {
        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);

        // Pairing stops at the shorter side, so data that is shorter than the
        // declared shape yields a correspondingly shorter result.
        let paired_rows = self.data.len().min(other.data.len());
        let mut sums: Vec<Vec<f64>> = Vec::with_capacity(paired_rows);

        for (lhs_row, rhs_row) in self.data.iter().zip(&other.data) {
            let paired_cells = lhs_row.len().min(rhs_row.len());
            let mut sum_row = Vec::with_capacity(paired_cells);
            for (lhs, rhs) in lhs_row.iter().zip(rhs_row) {
                sum_row.push(lhs + rhs);
            }
            sums.push(sum_row);
        }

        Matrix::new(self.rows, self.cols, sums)
    }
}
impl Sub for Matrix {
    type Output = Matrix;

    /// Subtracts `other` from `self` element by element.
    ///
    /// Rows and entries are paired positionally. The result carries the
    /// declared shape of `self`.
    ///
    /// # Panics
    ///
    /// Panics if the declared `rows` or `cols` of the operands differ.
    fn sub(self, other: Matrix) -> Matrix {
        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);

        // Pairing stops at the shorter side, so data that is shorter than the
        // declared shape yields a correspondingly shorter result.
        let paired_rows = self.data.len().min(other.data.len());
        let mut differences: Vec<Vec<f64>> = Vec::with_capacity(paired_rows);

        for (minuend_row, subtrahend_row) in self.data.iter().zip(&other.data) {
            let paired_cells = minuend_row.len().min(subtrahend_row.len());
            let mut diff_row = Vec::with_capacity(paired_cells);
            for (minuend, subtrahend) in minuend_row.iter().zip(subtrahend_row) {
                diff_row.push(minuend - subtrahend);
            }
            differences.push(diff_row);
        }

        Matrix::new(self.rows, self.cols, differences)
    }
}
impl Mul for Matrix {
    type Output = Matrix;

    /// Computes the matrix product `self * other`.
    ///
    /// The result has `self.rows` rows and `other.cols` columns. Each output
    /// entry accumulates its products in ascending order of the shared
    /// dimension, starting from `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`. Also panics with an out-of-range
    /// index if a row that the product needs is missing from an operand's
    /// data or is shorter than that operand's declared shape.
    fn mul(self, other: Matrix) -> Matrix {
        assert_eq!(
            self.cols, other.rows,
            "matrix product requires left.cols == right.rows"
        );

        let inner = self.cols;
        let width = other.cols;
        let mut result = vec![vec![0.0_f64; width]; self.rows];

        // Loop order is i-k-j rather than i-j-k: the innermost loop then walks
        // one row of `other` and one row of the result contiguously instead of
        // hopping between separate row allocations for every product.
        for (i, out_row) in result.iter_mut().enumerate() {
            for k in 0..inner {
                let lhs = self.data[i][k];
                // Slicing to the declared width checks the row length once, so
                // the zipped loop below needs no per-element bounds checks.
                let rhs_row = &other.data[k][..width];
                for (out, &rhs) in out_row.iter_mut().zip(rhs_row) {
                    *out += lhs * rhs;
                }
            }
        }

        Matrix::new(self.rows, width, result)
    }
}