#![allow(unused_assignments)]
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub};
use crate::tensor_library::errors::MatrixError;
use crate::tensor_library::layout::Layout;
use crate::tensor_library::utils::{calc_concat_shape, calc_strides_from_shape, check_concat_dims};
/// Strided N-dimensional array of `T` backed by a flat `Vec<T>`.
///
/// The element at multi-index `idx` is stored at `data[sum(idx[k] * strides[k])]`
/// (see `get_physical_idx`).
#[derive(Debug, Clone)]
pub struct Matrix<T> where T: Clone + Default + 'static{
/// Extent of each axis. `from_iter` asserts it is non-empty.
pub shape: Vec<usize>,
/// Per-axis step through `data`. A stride of 0 (as written by `broadcast`) makes every
/// coordinate along that axis address the same stored element.
pub strides: Vec<usize>,
/// Flat element storage.
pub data: Vec<T>,
/// Memory order used when (re)computing canonical strides from `shape`
/// (`from_iter`, `reshape`); flipped by `transpose`.
pub layout: Layout,
/// Number of stored elements (the shape's product at construction). `set_shape` does not
/// update it, and `reshape` validates new shapes against it.
pub size: usize
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Creates a matrix of the given shape and layout with every element set to `T::default()`.
    ///
    /// # Panics
    /// Panics if `shape` is empty (see [`Matrix::from_iter`]).
    pub fn new(shape: Vec<usize>, layout: Layout) -> Matrix<T> {
        // `repeat_with` is an unbounded generator. The previous `(0..)` counter was an `i32`
        // range that overflows (and panics in debug builds) past `i32::MAX` elements.
        Matrix::from_iter(shape, std::iter::repeat_with(T::default), layout)
    }

    /// Builds a matrix from the first `shape.iter().product()` items of `data`.
    ///
    /// Any further items are not consumed. Strides are canonical for `layout`:
    /// with `Layout::RowMajor` the last axis is contiguous; with `Layout::ColumnMajor`
    /// the first axis is.
    ///
    /// # Panics
    /// Panics if `shape` is empty, or if `data` yields fewer items than the shape requires.
    pub fn from_iter(
        shape: Vec<usize>,
        data: impl IntoIterator<Item = T>,
        layout: Layout,
    ) -> Matrix<T> {
        assert!(!shape.is_empty());
        // Non-empty was just asserted, so this equals the product of all extents.
        let size: usize = shape.iter().product();

        let mut strides = vec![0; shape.len()];
        let mut step: usize = 1;
        match layout {
            Layout::RowMajor => {
                // Last axis is contiguous: walk axes from the back.
                for (stride, &dim) in strides.iter_mut().zip(&shape).rev() {
                    *stride = step;
                    step *= dim;
                }
            }
            Layout::ColumnMajor => {
                // First axis is contiguous: walk axes from the front.
                for (stride, &dim) in strides.iter_mut().zip(&shape) {
                    *stride = step;
                    step *= dim;
                }
            }
        }

        let data: Vec<T> = data.into_iter().take(size).collect();
        assert_eq!(data.len(), size);

        Matrix {
            shape,
            strides,
            data,
            layout,
            size,
        }
    }
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Number of logical elements: the product of `shape` (1 for an empty shape).
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Current extent of each axis.
    pub fn shape(&self) -> &Vec<usize> {
        &self.shape
    }

    /// Reinterprets the stored elements under `new_shape`.
    ///
    /// Canonical strides are recomputed for the matrix's current layout. `data` itself is
    /// not touched.
    ///
    /// # Errors
    /// `MatrixError::ReshapeError` if `new_shape` is empty, or if its product differs from
    /// `self.size`.
    pub fn reshape(&mut self, new_shape: &Vec<usize>) -> Result<(), MatrixError> {
        // An empty shape previously panicked on `reduce(..).unwrap()`. The rest of the type
        // assumes at least one axis, so it is now rejected like any other invalid shape.
        if new_shape.is_empty() || new_shape.iter().product::<usize>() != self.size {
            return Err(MatrixError::ReshapeError);
        }
        self.shape = new_shape.clone();
        self.strides = calc_strides_from_shape(new_shape, self.layout);
        Ok(())
    }

    /// Current per-axis strides into `data`.
    pub fn strides(&self) -> &Vec<usize> {
        &self.strides
    }

    /// Overwrites the strides verbatim; no consistency check against `shape` is made.
    pub fn set_strides(&mut self, strides: &[usize]) {
        self.strides = strides.to_vec();
    }

    /// Overwrites the shape verbatim. Neither `strides` nor `size` is updated.
    pub fn set_shape(&mut self, shape: &[usize]) {
        self.shape = shape.to_vec();
    }
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Validates a multi-index against the matrix's shape.
    ///
    /// Returns `Ok(true)` when `idx` has exactly one coordinate per axis and every coordinate
    /// is strictly below that axis's extent.
    ///
    /// # Errors
    /// * `MatrixError::DimError` if `idx.len()` differs from the number of axes.
    /// * `MatrixError::OutOfBounds` if some coordinate is `>=` its axis extent.
    pub fn check_bounds(&self, idx: &Vec<usize>) -> Result<bool, MatrixError> {
        if idx.len() != self.shape.len() {
            Err(MatrixError::DimError)
        } else if idx
            .iter()
            .zip(self.shape.iter())
            .all(|(coord, extent)| coord < extent)
        {
            Ok(true)
        } else {
            Err(MatrixError::OutOfBounds)
        }
    }

    /// Maps a multi-index to an offset into `data`, computed as `sum(idx[k] * strides[k])`.
    ///
    /// # Errors
    /// Propagates any error from [`Matrix::check_bounds`].
    pub fn get_physical_idx(&self, idx: &Vec<usize>) -> Result<usize, MatrixError> {
        self.check_bounds(idx)?;
        let offset: usize = idx
            .iter()
            .enumerate()
            .map(|(axis, &coord)| coord * self.strides[axis])
            .sum();
        Ok(offset)
    }
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Borrows the element at multi-index `idx`.
    ///
    /// # Errors
    /// `DimError` / `OutOfBounds` from [`Matrix::check_bounds`].
    pub fn get(&self, idx: &Vec<usize>) -> Result<&T, MatrixError> {
        let offset = self.get_physical_idx(idx)?;
        Ok(&self.data[offset])
    }

    /// Returns a clone of the element at multi-index `idx`.
    ///
    /// # Errors
    /// `DimError` / `OutOfBounds` from [`Matrix::check_bounds`].
    pub fn get_copy(&self, idx: &Vec<usize>) -> Result<T, MatrixError> {
        let offset = self.get_physical_idx(idx)?;
        Ok(self.data[offset].clone())
    }

    /// Mutably borrows the element at multi-index `idx`.
    ///
    /// # Errors
    /// `DimError` / `OutOfBounds` from [`Matrix::check_bounds`].
    pub fn get_mut(&mut self, idx: &Vec<usize>) -> Result<&mut T, MatrixError> {
        let offset = self.get_physical_idx(idx)?;
        Ok(&mut self.data[offset])
    }

    /// Stores `value` at multi-index `idx`.
    ///
    /// # Errors
    /// `DimError` / `OutOfBounds` from [`Matrix::check_bounds`]. On error `value` is dropped.
    pub fn set(&mut self, idx: &Vec<usize>, value: T) -> Result<(), MatrixError> {
        *self.get_mut(idx)? = value;
        Ok(())
    }

    /// Copies the run of elements along the last axis selected by `idx`.
    ///
    /// `idx` supplies the coordinates of every axis except the last, so it must hold
    /// `rank - 1` entries. It is used as scratch space and is returned to its original
    /// contents before this method returns, on success and on error alike.
    ///
    /// # Errors
    /// * `MatrixError::DimError` if `idx.len() != rank - 1`.
    /// * `MatrixError::OutOfBounds` if a coordinate in `idx` lies outside its axis.
    pub fn get_copy_row(&self, idx: &mut Vec<usize>) -> Result<Vec<T>, MatrixError> {
        if idx.len() != self.shape().len() - 1 {
            return Err(MatrixError::DimError);
        }
        let row_len = *self.shape().last().unwrap();

        // Append one slot for the last-axis coordinate and sweep it across the row.
        idx.push(0);
        let last = idx.len() - 1;
        let row: Result<Vec<T>, MatrixError> = (0..row_len)
            .map(|col| {
                idx[last] = col;
                self.get_copy(idx)
            })
            .collect();
        // Always drop the extra slot. The old code returned early on error and left the
        // caller's `idx` one element longer.
        idx.pop();
        row
    }
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Invokes `func` once per stored element, in `data` order.
    pub fn apply<F: FnMut(&T)>(&self, mut func: F) {
        for element in &self.data {
            func(element);
        }
    }

    /// Invokes `func` with a mutable reference to each stored element, in `data` order.
    pub fn apply_mut<F: FnMut(&mut T)>(&mut self, mut func: F) {
        for element in &mut self.data {
            func(element);
        }
    }
}
impl<T> Matrix<T> where T: Clone + Default {
    /// Reverses the order of all axes in place without moving any data.
    ///
    /// `shape` and `strides` are both reversed and the layout flag is flipped. If the
    /// strides were canonical for the old shape/layout, they are therefore canonical for
    /// the new ones as well.
    pub fn transpose(&mut self) {
        self.shape.reverse();
        self.strides.reverse();
        self.layout = match self.layout {
            Layout::RowMajor => Layout::ColumnMajor,
            Layout::ColumnMajor => Layout::RowMajor,
        };
    }

    /// Reshapes the matrix to a single axis of length `self.size()`.
    ///
    /// # Panics
    /// Panics with the error's `Display` text if `reshape` rejects the new shape.
    pub fn flatten(&mut self) {
        let len = self.size();
        if let Err(err) = self.reshape(&vec![len]) {
            panic!("{}", err);
        }
    }
}
type BroadcastRetType = Result<(Vec<usize>, Vec<usize>, Vec<usize>), MatrixError>;
/// Computes the NumPy-style broadcast of two shapes.
///
/// The shorter shape is left-padded with 1s to the rank of the longer one. Along each axis
/// the two extents must either be equal, or one of them must be 1. An extent-1 axis is
/// stretched by giving it stride 0.
///
/// Returns `(broadcast_shape, lhs_strides, rhs_strides)`. Each stride vector is
/// `calc_strides_from_shape` of the padded operand shape and its layout, with the
/// stretched axes zeroed.
///
/// # Errors
/// `MatrixError::BroadcastError` if an axis has two different extents, neither of them 1.
pub fn broadcast(
    lhs_shape: &Vec<usize>,
    lhs_layout: Layout,
    rhs_shape: &Vec<usize>,
    rhs_layout: Layout,
) -> BroadcastRetType {
    /// Prepends 1s to `shape` until it has `rank` axes.
    fn left_pad(shape: &[usize], rank: usize) -> Vec<usize> {
        let mut padded = vec![1; rank - shape.len()];
        padded.extend_from_slice(shape);
        padded
    }

    let rank = lhs_shape.len().max(rhs_shape.len());
    let lhs_shape = left_pad(lhs_shape, rank);
    let rhs_shape = left_pad(rhs_shape, rank);

    let mut lhs_strides: Vec<usize> = calc_strides_from_shape(&lhs_shape, lhs_layout);
    let mut rhs_strides: Vec<usize> = calc_strides_from_shape(&rhs_shape, rhs_layout);
    let mut shape: Vec<usize> = Vec::with_capacity(rank);

    for (axis, (&l, &r)) in lhs_shape.iter().zip(rhs_shape.iter()).enumerate() {
        let extent = match (l, r) {
            _ if l == r => l,
            (1, _) => {
                lhs_strides[axis] = 0;
                r
            }
            (_, 1) => {
                rhs_strides[axis] = 0;
                l
            }
            _ => return Err(MatrixError::BroadcastError),
        };
        shape.push(extent);
    }

    Ok((shape, lhs_strides, rhs_strides))
}
type ConcatRetType<T> = Result<(Matrix<T>, Matrix<T>, Matrix<T>), MatrixError>;
/// Joins `lhs` and `rhs` along `axis` into a new row-major matrix.
///
/// Every element of `lhs` keeps its multi-index. Every element of `rhs` is shifted by
/// `lhs.shape()[axis]` along `axis`. Both operands are handed back unchanged as the second
/// and third tuple entries.
///
/// # Errors
/// * `MatrixError::DimError` if `check_concat_dims` rejects the shapes/axis.
/// * Any error raised while writing into the result matrix.
///
/// # Panics
/// Panics if `calc_concat_shape` fails after `check_concat_dims` accepted the inputs.
pub fn concat<T>(lhs: Matrix<T>, rhs: Matrix<T>, axis: usize) -> ConcatRetType<T>
where
    T: Clone + Default + Debug,
{
    if !check_concat_dims(lhs.shape(), rhs.shape(), axis) {
        return Err(MatrixError::DimError);
    }
    let joined_shape = calc_concat_shape(lhs.shape(), rhs.shape(), axis).unwrap();
    let mut joined: Matrix<T> = Matrix::new(joined_shape, Layout::RowMajor);

    let lhs_cells: MatrixIter<T> = MatrixIter {
        mat: &lhs,
        index: vec![0; lhs.shape().len()],
        current_el: None,
        empty: false,
    };
    for (value, position) in lhs_cells {
        joined.set(&position, value)?;
    }

    let rhs_cells: MatrixIter<T> = MatrixIter {
        mat: &rhs,
        index: vec![0; rhs.shape().len()],
        current_el: None,
        empty: false,
    };
    for (value, mut position) in rhs_cells {
        // Move past the slab along `axis` that `lhs` already occupies.
        position[axis] += lhs.shape()[axis];
        joined.set(&position, value)?;
    }

    Ok((joined, lhs, rhs))
}
type SubRetType<T> = Result<(Matrix<T>, Matrix<T>, Matrix<T>), MatrixError>;
/// Element-wise difference `lhs - rhs` with NumPy-style broadcasting (see [`broadcast`]).
///
/// Returns `(difference, lhs, rhs)`. `difference` is a new row-major matrix with the
/// broadcast shape. `lhs` and `rhs` come back with their original shape and strides.
///
/// # Errors
/// * `MatrixError::BroadcastError` if the shapes are not broadcast-compatible.
/// * Any error raised while writing into the result matrix.
pub fn subtract<T>(mut lhs: Matrix<T>, mut rhs: Matrix<T>) -> SubRetType<T>
where
    T: Clone + Default + Sub + Sub<Output = T>,
    <T as Sub>::Output: Clone + Default,
{
    let (out_shape, lhs_view_strides, rhs_view_strides) =
        broadcast(lhs.shape(), lhs.layout, rhs.shape(), rhs.layout)?;

    // Temporarily turn both operands into broadcast views of `out_shape`. Their own geometry
    // is restored below. Previously the broadcast shape/strides leaked back to the caller
    // while `size` stayed stale, so reusing a returned operand gave wrong shapes in later
    // operations (and `flatten` on it panicked).
    let lhs_own_shape = std::mem::replace(&mut lhs.shape, out_shape.clone());
    let lhs_own_strides = std::mem::replace(&mut lhs.strides, lhs_view_strides);
    let rhs_own_shape = std::mem::replace(&mut rhs.shape, out_shape.clone());
    let rhs_own_strides = std::mem::replace(&mut rhs.strides, rhs_view_strides);

    let mut difference: Matrix<T> = Matrix::new(out_shape, Layout::RowMajor);
    {
        let lhs_iter: MatrixIter<T> = MatrixIter {
            mat: &lhs,
            index: vec![0; lhs.shape().len()],
            current_el: None,
            empty: false,
        };
        let rhs_iter: MatrixIter<T> = MatrixIter {
            mat: &rhs,
            index: vec![0; rhs.shape().len()],
            current_el: None,
            empty: false,
        };
        for ((lhs_item, lhs_index), (rhs_item, rhs_index)) in lhs_iter.zip(rhs_iter) {
            // Both views share `out_shape`, so they walk identical index sequences.
            assert_eq!(lhs_index, rhs_index);
            difference.set(&lhs_index, lhs_item - rhs_item)?;
        }
    }

    lhs.shape = lhs_own_shape;
    lhs.strides = lhs_own_strides;
    rhs.shape = rhs_own_shape;
    rhs.strides = rhs_own_strides;
    Ok((difference, lhs, rhs))
}
type AddRetType<T> = Result<(Matrix<T>, Matrix<T>, Matrix<T>), MatrixError>;
/// Element-wise sum `lhs + rhs` with NumPy-style broadcasting (see [`broadcast`]).
///
/// Returns `(sum, lhs, rhs)`. `sum` is a new row-major matrix with the broadcast shape.
/// `lhs` and `rhs` come back with their original shape and strides.
///
/// # Errors
/// * `MatrixError::BroadcastError` if the shapes are not broadcast-compatible.
/// * Any error raised while writing into the result matrix.
pub fn add<T>(mut lhs: Matrix<T>, mut rhs: Matrix<T>) -> AddRetType<T>
where
    T: Clone + Default + Add + Add<Output = T>,
    <T as Add>::Output: Clone + Default,
{
    let (out_shape, lhs_view_strides, rhs_view_strides) =
        broadcast(lhs.shape(), lhs.layout, rhs.shape(), rhs.layout)?;

    // Temporarily turn both operands into broadcast views of `out_shape`. Their own geometry
    // is restored below. Previously the broadcast shape/strides leaked back to the caller
    // while `size` stayed stale, so reusing a returned operand gave wrong shapes in later
    // operations (and `flatten` on it panicked).
    let lhs_own_shape = std::mem::replace(&mut lhs.shape, out_shape.clone());
    let lhs_own_strides = std::mem::replace(&mut lhs.strides, lhs_view_strides);
    let rhs_own_shape = std::mem::replace(&mut rhs.shape, out_shape.clone());
    let rhs_own_strides = std::mem::replace(&mut rhs.strides, rhs_view_strides);

    let mut total: Matrix<T> = Matrix::new(out_shape, Layout::RowMajor);
    {
        let lhs_iter: MatrixIter<T> = MatrixIter {
            mat: &lhs,
            index: vec![0; lhs.shape().len()],
            current_el: None,
            empty: false,
        };
        let rhs_iter: MatrixIter<T> = MatrixIter {
            mat: &rhs,
            index: vec![0; rhs.shape().len()],
            current_el: None,
            empty: false,
        };
        for ((lhs_item, lhs_index), (rhs_item, rhs_index)) in lhs_iter.zip(rhs_iter) {
            // Both views share `out_shape`, so they walk identical index sequences.
            assert_eq!(lhs_index, rhs_index);
            total.set(&lhs_index, lhs_item + rhs_item)?;
        }
    }

    lhs.shape = lhs_own_shape;
    lhs.strides = lhs_own_strides;
    rhs.shape = rhs_own_shape;
    rhs.strides = rhs_own_strides;
    Ok((total, lhs, rhs))
}
/// Multiplies every stored element of `lhs` by `rhs` in place and returns the matrix.
pub fn multiply_scalar<T>(mut lhs: Matrix<T>, rhs: T) -> Matrix<T>
where
    T: Clone + Default + Mul + Mul<Output = T> + MulAssign,
    <T as Mul>::Output: Clone + Default,
{
    lhs.data
        .iter_mut()
        .for_each(|element| *element = element.clone() * rhs.clone());
    lhs
}
pub fn multiply_scalar_generic<T>(mut lhs: Matrix<i32>, rhs: T) -> Matrix<T> where T: Clone + Default + Mul<i32, Output = T>, i32: Mul<T>{
let mut new_matrix: Matrix<T> = Matrix::new(lhs.shape().clone(), lhs.layout.clone());
for i in 0..lhs.data.len(){
new_matrix.data[i] = rhs.clone() * lhs.data[i].clone();
}
new_matrix
}
/// Multiplies every stored element of `lhs` by the `i32` scalar `rhs` (via `T: Mul<i32>`)
/// in place and returns the matrix.
pub fn multiply_scalar_diff_type<T>(mut lhs: Matrix<T>, rhs: i32) -> Matrix<T>
where
    T: Clone + Default + Mul<i32, Output = T>,
{
    for element in lhs.data.iter_mut() {
        *element = element.clone() * rhs;
    }
    lhs
}
type MulRetType2D<T> = Result<(Matrix<T>, Matrix<T>, Matrix<T>), MatrixError>;
/// Matrix product of a `[rows, inner]` matrix and an `[inner, cols]` matrix.
///
/// Returns `(product, lhs, rhs)`. `product` is a row-major `[rows, cols]` matrix whose
/// entries are accumulated from `T::default()`. The operands are handed back unchanged.
///
/// # Errors
/// `MatrixError::MatmulShapeError` if either operand is not 2-D, or if
/// `lhs.shape[1] != rhs.shape[0]`.
pub fn multiply_2d<T>(lhs: Matrix<T>, rhs: Matrix<T>) -> MulRetType2D<T>
where
    T: Display + Clone + Default + Mul + Mul<Output = T> + MulAssign + AddAssign,
    <T as Mul>::Output: Clone + Default,
{
    if lhs.shape.len() != 2 || rhs.shape.len() != 2 {
        return Err(MatrixError::MatmulShapeError);
    }
    // The old code broadcast the operands against each other, which rejected ordinary
    // `[m,k] x [k,n]` products and summed over the column count instead of `inner`.
    // Matmul only requires the inner dimensions to agree.
    let (rows, inner, cols) = (lhs.shape[0], lhs.shape[1], rhs.shape[1]);
    if rhs.shape[0] != inner {
        return Err(MatrixError::MatmulShapeError);
    }

    let mut product: Matrix<T> = Matrix::new(vec![rows, cols], Layout::RowMajor);
    // Reuse index buffers instead of allocating two `Vec`s per multiply-add.
    let mut lhs_idx = vec![0usize; 2];
    let mut rhs_idx = vec![0usize; 2];
    for i in 0..rows {
        lhs_idx[0] = i;
        for j in 0..cols {
            rhs_idx[1] = j;
            let mut acc = T::default();
            for k in 0..inner {
                lhs_idx[1] = k;
                rhs_idx[0] = k;
                acc += lhs.get_copy(&lhs_idx)? * rhs.get_copy(&rhs_idx)?;
            }
            product.set(&vec![i, j], acc)?;
        }
    }
    Ok((product, lhs, rhs))
}
type MulRetType1D<T> = Result<(T, Matrix<T>, Matrix<T>), MatrixError>;
/// Dot product of two 1-D matrices of equal length.
///
/// Returns `(dot, lhs, rhs)`, where `dot = sum(lhs[i] * rhs[i])` is accumulated from
/// `T::default()`. The operands are handed back unchanged.
///
/// # Errors
/// `MatrixError::MatmulShapeError` if either operand is not 1-D or their lengths differ.
pub fn multiply_1d<T>(lhs: Matrix<T>, rhs: Matrix<T>) -> MulRetType1D<T>
where
    T: Display + Clone + Default + Mul + Mul<Output = T> + MulAssign + AddAssign,
    <T as Mul>::Output: Clone + Default,
{
    // The old guard parsed as `a || (b && len_equal)`: it accepted 1-D vectors of different
    // lengths and rejected only 2-D `rhs` of *equal* leading extent.
    if lhs.shape.len() != 1 || rhs.shape.len() != 1 || lhs.shape[0] != rhs.shape[0] {
        return Err(MatrixError::MatmulShapeError);
    }
    let mut dot = T::default();
    let mut idx = vec![0usize];
    for i in 0..lhs.shape[0] {
        idx[0] = i;
        // Previously this multiplied `lhs[i]` by itself and never read `rhs`.
        dot += lhs.get_copy(&idx)? * rhs.get_copy(&idx)?;
    }
    Ok((dot, lhs, rhs))
}
/// Cursor that walks every multi-index of `mat` with the last axis varying fastest,
/// yielding `(element clone, index)` pairs.
///
/// Every call site in this file constructs it with `index` set to one zero per axis,
/// `current_el: None` and `empty: false`. The `Iterator` impl advances `index` and sets
/// `empty` once all indices have been produced.
#[derive(Debug, Clone)]
pub struct MatrixIter<'a, T> where T: Clone + Default + 'static{
/// Matrix being traversed; elements are read through its `get_copy`.
pub mat: &'a Matrix<T>,
/// Multi-index to be yielded by the next call to `next`.
pub index: Vec<usize>,
/// Copy of the most recently yielded item (`next` returns a clone of it).
pub current_el: Option<(T, Vec<usize>)>,
/// Becomes `true` once the cursor has rolled over past the final index.
pub empty: bool,
}
impl<T> Iterator for MatrixIter<'_, T> where T: Clone + Default{
    type Item = (T, Vec<usize>);

    /// Yields the element under the cursor together with a copy of the cursor, then
    /// advances the cursor like an odometer (last axis fastest).
    ///
    /// Returns `None` once every index has been produced, or if the cursor is not a valid
    /// index of `mat` (for example when some axis has extent 0).
    fn next(&mut self) -> Option<Self::Item> {
        if self.empty || self.mat.check_bounds(&self.index).is_err() {
            return None;
        }
        // `check_bounds` just succeeded, so this lookup cannot fail.
        let value = self.mat.get_copy(&self.index).unwrap();
        self.current_el = Some((value, self.index.clone()));

        // Odometer step: bump the last axis and carry into earlier axes when one wraps.
        // Iterating over `0..rank` also handles rank 0. The old `len() - 1` start index
        // underflowed there: it panicked in debug builds and never terminated in release.
        let dims = self.mat.shape();
        let mut rolled_over = true;
        for axis in (0..dims.len()).rev() {
            if self.index[axis] + 1 < dims[axis] {
                self.index[axis] += 1;
                rolled_over = false;
                break;
            }
            self.index[axis] = 0;
        }
        if rolled_over {
            self.empty = true;
        }
        self.current_el.clone()
    }
}