use std::ops::{Add, Mul, Sub};
/// The entries of a single matrix row, indexed by column.
pub type Row = Vec<f32>;
/// Cached shape of a [`Matrix`], filled in from the row data when the matrix is built.
#[derive(Debug, Clone)]
struct MatrixMeta {
    /// Number of rows.
    row_count: usize,
    /// Number of entries in every row.
    column_count: usize,
}
/// A dense matrix of `f32` values, stored as a vector of rows.
///
/// `Clone` is derived because the arithmetic operators take their operands by value;
/// clone a matrix when it must still be used after taking part in an operation.
#[derive(Debug, Clone)]
pub struct Matrix {
    meta: MatrixMeta,
    /// The rows of the matrix; each row holds `meta.column_count` entries.
    data: Vec<Row>,
}
impl Matrix {
    /// Builds a matrix from a vector of rows.
    ///
    /// The column count is taken from the first row. An empty `data` vector yields a
    /// 0×0 matrix rather than panicking.
    ///
    /// # Panics
    ///
    /// Panics if any row's length differs from the first row's length.
    pub fn new(data: Vec<Row>) -> Matrix {
        // With no rows there is no first row to read the width from; use 0 instead of
        // unwrapping `None`, so results of operations on empty matrices can be built.
        let column_count = data.first().map_or(0, Vec::len);
        for (row_index, row) in data.iter().enumerate() {
            if row.len() != column_count {
                panic!(
                    "Column counts not match in row {}. (Found {}, Expected {})",
                    row_index,
                    row.len(),
                    column_count,
                );
            }
        }
        let meta = MatrixMeta {
            row_count: data.len(),
            column_count,
        };
        Matrix { meta, data }
    }
}
/// Matrices are equal when their recorded shapes match and every entry compares equal.
///
/// Entries are compared as `f32`, so a `NaN` entry makes two matrices unequal.
impl PartialEq<Matrix> for Matrix {
    fn eq(&self, other: &Matrix) -> bool {
        let lhs_shape = (self.meta.row_count, self.meta.column_count);
        let rhs_shape = (other.meta.row_count, other.meta.column_count);
        lhs_shape == rhs_shape && self.data == other.data
    }
}
impl Add<Matrix> for Matrix {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
if self.meta.row_count != rhs.meta.row_count
|| self.meta.column_count != rhs.meta.column_count
{
panic!("These two matrices cannot be added.");
} else {
let mut matrix_vec = vec![];
for row_index in 0..self.meta.row_count {
let mut matrix_column_vec = vec![];
for column_index in 0..self.meta.column_count {
matrix_column_vec.push(
self.data.get(row_index).unwrap().get(column_index).unwrap()
+ rhs.data.get(row_index).unwrap().get(column_index).unwrap(),
);
}
matrix_vec.push(matrix_column_vec);
}
Matrix::new(matrix_vec)
}
}
}
impl Sub<Matrix> for Matrix {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
if self.meta.row_count != rhs.meta.row_count
|| self.meta.column_count != rhs.meta.column_count
{
panic!("These two matrices cannot be substracted.");
} else {
let mut matrix_vec = vec![];
for row_index in 0..self.meta.row_count {
let mut matrix_column_vec = vec![];
for column_index in 0..self.meta.column_count {
matrix_column_vec.push(
self.data.get(row_index).unwrap().get(column_index).unwrap()
- rhs.data.get(row_index).unwrap().get(column_index).unwrap(),
);
}
matrix_vec.push(matrix_column_vec);
}
Matrix::new(matrix_vec)
}
}
}
impl Mul<Matrix> for Matrix {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
if self.meta.column_count != rhs.meta.row_count {
panic!("These two matrices cannot be multiplied.");
} else {
let mut matrix_vec = vec![];
for row_index in 0..self.meta.row_count {
let mut matrix_column_vec = vec![];
for column_index in 0..rhs.meta.column_count {
let mut result: f32 = 0.0;
for (index, value) in self.data.get(row_index).unwrap().iter().enumerate() {
result += value * rhs.data.get(index).unwrap().get(column_index).unwrap();
}
matrix_column_vec.push(result);
}
matrix_vec.push(matrix_column_vec);
}
Matrix::new(matrix_vec)
}
}
}
impl Mul<Matrix> for f32 {
type Output = Matrix;
fn mul(self, matrix: Matrix) -> Self::Output {
let mut matrix_vec = vec![];
for row_index in 0..matrix.meta.row_count {
let mut matrix_column_vec = vec![];
for column_index in 0..matrix.meta.column_count {
matrix_column_vec.push(
self * matrix
.data
.get(row_index)
.unwrap()
.get(column_index)
.unwrap(),
);
}
matrix_vec.push(matrix_column_vec);
}
Matrix::new(matrix_vec)
}
}
/// Builds a `Matrix` from a comma-separated list of row expressions.
///
/// Each argument must evaluate to a `Row`. The rows are passed, in order, to
/// `Matrix::new`, which panics if their lengths differ. A trailing comma is accepted.
#[macro_export]
macro_rules! create_matrix {
    ($($x:expr),* $(,)?) => {
        // `$crate` resolves to the defining crate regardless of how the caller named or
        // imported it, unlike a hard-coded crate name.
        $crate::matrix::Matrix::new(::std::vec![$($x),*])
    };
}
/// Builds a `Row` from a comma-separated list of `f32` entries.
///
/// Entries are evaluated left to right. A trailing comma is accepted, and an empty
/// invocation yields an empty row.
#[macro_export]
macro_rules! create_matrix_row {
    ($($x:expr),* $(,)?) => {{
        // The annotation pins the element type to the crate's `Row` (so float literals
        // become `f32`); `$crate` keeps the path valid however the crate was imported.
        let row: $crate::matrix::Row = ::std::vec![$($x),*];
        row
    }};
}