use crate::{
common::Numeric,
error::{Error, Result},
implementations::HeapMatrix,
matrix::Matrix,
};
use std::{
fmt::{Debug, Display, Formatter},
ops::*,
};
/// A fixed-size matrix whose elements are held inline in a `[T; X * Y]` array.
///
/// `X` is the number of columns and `Y` the number of rows; elements are laid
/// out row-major, so element `(row, col)` lives at index `row * X + col`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct StackMatrix<T: Numeric, const X: usize, const Y: usize>
where
    [T; X * Y]: Sized,
{
    /// Row-major element storage.
    pub(crate) data: [T; X * Y],
    /// Column count; the constructors in this file always set it to `X`.
    pub(crate) x_len: usize,
    /// Row count; the constructors in this file always set it to `Y`.
    pub(crate) y_len: usize,
}
impl<T: Numeric, const X: usize, const Y: usize> Add for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Element-wise sum of two matrices of the same `X`×`Y` type.
    ///
    /// # Panics
    /// Panics if `mat_add` reports an error. Both operands share the same
    /// compile-time shape, so this is treated as an invariant violation
    /// (NOTE(review): assumes `mat_add` only fails on shape mismatch — confirm).
    fn add(self, rhs: Self) -> Self::Output {
        self.mat_add(rhs)
            .expect("StackMatrix + StackMatrix: operands of the same type have equal dimensions")
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Add<HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Result<Self>;

    /// Element-wise sum with a heap matrix.
    ///
    /// The heap operand's shape is only known at runtime, so the fallible
    /// result of `mat_add` is handed back to the caller rather than unwrapped.
    fn add(self, rhs: HeapMatrix<T>) -> Self::Output {
        self.mat_add(rhs)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Add<&Self> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Element-wise sum with a borrowed matrix of the same `X`×`Y` type.
    ///
    /// # Panics
    /// Panics if `mat_add` reports an error, which is treated as an invariant
    /// violation because both operands share the same compile-time shape
    /// (NOTE(review): assumes `mat_add` only fails on shape mismatch — confirm).
    fn add(self, rhs: &Self) -> Self::Output {
        // `StackMatrix` derives `Copy`, so dereferencing the borrow is a plain copy.
        self.mat_add(*rhs)
            .expect("StackMatrix + &StackMatrix: operands of the same type have equal dimensions")
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Add<&HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Result<Self>;

    /// Element-wise sum with a borrowed heap matrix; errors from `mat_add`
    /// are forwarded to the caller.
    fn add(self, rhs: &HeapMatrix<T>) -> Self::Output {
        // `mat_add` is given an owned operand, so clone the borrowed one first.
        let owned = rhs.clone();
        self.mat_add(owned)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Sub for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Element-wise difference of two matrices of the same `X`×`Y` type.
    ///
    /// # Panics
    /// Panics if `mat_sub` reports an error, which is treated as an invariant
    /// violation because both operands share the same compile-time shape
    /// (NOTE(review): assumes `mat_sub` only fails on shape mismatch — confirm).
    fn sub(self, rhs: Self) -> Self::Output {
        self.mat_sub(rhs)
            .expect("StackMatrix - StackMatrix: operands of the same type have equal dimensions")
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Sub<HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Result<Self>;

    /// Element-wise difference with a heap matrix.
    ///
    /// The heap operand's shape is only known at runtime, so the fallible
    /// result of `mat_sub` is handed back to the caller rather than unwrapped.
    fn sub(self, rhs: HeapMatrix<T>) -> Self::Output {
        self.mat_sub(rhs)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Sub<&Self> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Element-wise difference with a borrowed matrix of the same `X`×`Y` type.
    ///
    /// # Panics
    /// Panics if `mat_sub` reports an error, which is treated as an invariant
    /// violation because both operands share the same compile-time shape
    /// (NOTE(review): assumes `mat_sub` only fails on shape mismatch — confirm).
    fn sub(self, rhs: &Self) -> Self::Output {
        // `StackMatrix` derives `Copy`, so dereferencing the borrow is a plain copy.
        self.mat_sub(*rhs)
            .expect("StackMatrix - &StackMatrix: operands of the same type have equal dimensions")
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Sub<&HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Result<Self>;

    /// Element-wise difference with a borrowed heap matrix; errors from
    /// `mat_sub` are forwarded to the caller.
    fn sub(self, rhs: &HeapMatrix<T>) -> Self::Output {
        // `mat_sub` is given an owned operand, so clone the borrowed one first.
        let owned = rhs.clone();
        self.mat_sub(owned)
    }
}
impl<T: Numeric, const X: usize, const Y: usize, const Z: usize, const W: usize>
    Mul<StackMatrix<T, Z, W>> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
    [T; Z * W]: Sized,
    [T; Z * Y]: Sized,
{
    /// The product has the right operand's column count (`Z`) and the left
    /// operand's row count (`Y`).
    type Output = Result<StackMatrix<T, Z, Y>>;

    /// Matrix product `self * rhs`.
    ///
    /// Compatibility of the inner dimensions (`X` vs `W`) is not enforced by
    /// the types here; any failure reported by `mat_mul` is returned as `Err`.
    fn mul(self, rhs: StackMatrix<T, Z, W>) -> Self::Output {
        self.mat_mul(rhs)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Mul<HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    /// The product's shape depends on a runtime-sized operand, so it is
    /// produced as a `HeapMatrix`.
    type Output = Result<HeapMatrix<T>>;

    /// Matrix product `self * rhs`; errors from `mat_mul` are forwarded.
    fn mul(self, rhs: HeapMatrix<T>) -> Self::Output {
        self.mat_mul(rhs)
    }
}
impl<T: Numeric, const X: usize, const Y: usize, const Z: usize, const W: usize>
    Mul<&StackMatrix<T, Z, W>> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
    [T; Z * W]: Sized,
    [T; Z * Y]: Sized,
{
    /// The product has the right operand's column count (`Z`) and the left
    /// operand's row count (`Y`).
    type Output = Result<StackMatrix<T, Z, Y>>;

    /// Matrix product `self * rhs` with a borrowed right operand; errors from
    /// `mat_mul` are forwarded.
    fn mul(self, rhs: &StackMatrix<T, Z, W>) -> Self::Output {
        // `StackMatrix` derives `Copy`, so the borrow is simply copied out.
        let rhs_copy = *rhs;
        self.mat_mul(rhs_copy)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Mul<&HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    /// The product's shape depends on a runtime-sized operand, so it is
    /// produced as a `HeapMatrix`.
    type Output = Result<HeapMatrix<T>>;

    /// Matrix product `self * rhs` with a borrowed heap operand; errors from
    /// `mat_mul` are forwarded.
    fn mul(self, rhs: &HeapMatrix<T>) -> Self::Output {
        // `mat_mul` is given an owned operand, so clone the borrowed one first.
        let owned = rhs.clone();
        self.mat_mul(owned)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Add<T> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Adds the scalar `rhs` to every element, returning a matrix of the same shape.
    fn add(self, rhs: T) -> Self::Output {
        // Transform a copy of the inline array in place. The shape is fixed by
        // the type, so there is no need to collect into a heap `Vec` and
        // re-validate its length (which also required an `unwrap`).
        let mut data = self.data;
        for value in &mut data {
            *value = *value + rhs;
        }
        Self::new_1d(data)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Sub<T> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Subtracts the scalar `rhs` from every element, returning a matrix of the same shape.
    fn sub(self, rhs: T) -> Self::Output {
        // Transform a copy of the inline array in place. The shape is fixed by
        // the type, so there is no need to collect into a heap `Vec` and
        // re-validate its length (which also required an `unwrap`).
        let mut data = self.data;
        for value in &mut data {
            *value = *value - rhs;
        }
        Self::new_1d(data)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Mul<T> for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    type Output = Self;

    /// Multiplies every element by the scalar `rhs`, returning a matrix of the same shape.
    fn mul(self, rhs: T) -> Self::Output {
        // Transform a copy of the inline array in place. The shape is fixed by
        // the type, so there is no need to collect into a heap `Vec` and
        // re-validate its length (which also required an `unwrap`).
        let mut data = self.data;
        for value in &mut data {
            *value = *value * rhs;
        }
        Self::new_1d(data)
    }
}
impl<T: Numeric, const X: usize, const Y: usize> StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    /// Builds a matrix from `Y` rows of `X` elements each, stored row-major.
    pub fn new(data: [[T; X]; Y]) -> Self {
        // Seed with `T::default()` rather than `data[0][0]`: indexing the first
        // element panics when `X == 0` or `Y == 0`.
        let mut dat: [T; X * Y] = [T::default(); X * Y];
        // Flattening walks row 0 first, so the layout matches `row * X + col`.
        for (dst, src) in dat.iter_mut().zip(data.iter().flatten()) {
            *dst = *src;
        }
        Self::new_1d(dat)
    }

    /// Wraps an already-flattened, row-major array of `X * Y` elements.
    pub const fn new_1d(data: [T; X * Y]) -> Self {
        Self {
            data,
            x_len: X,
            y_len: Y,
        }
    }

    /// Builds a matrix from a flat, row-major slice.
    ///
    /// # Errors
    /// Returns `Error::IncorrectLength` if `data.len() != X * Y`.
    pub fn new_from_slice(data: &[T]) -> Result<Self> {
        if data.len() != X * Y {
            return Err(Error::IncorrectLength);
        }
        let mut array: [T; X * Y] = [T::default(); X * Y];
        // The length check above guarantees `copy_from_slice` will not panic.
        array.copy_from_slice(data);
        Ok(Self::new_1d(array))
    }
}
impl<T: Numeric, const X: usize, const Y: usize> PartialEq<HeapMatrix<T>>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    /// A stack matrix equals a heap matrix when both dimensions match and the
    /// element storage compares equal.
    fn eq(&self, other: &HeapMatrix<T>) -> bool {
        // Shape is checked first; the element comparison only runs when it matches.
        let same_shape = X == other.x_len && Y == other.y_len;
        same_shape && other.data == self.data
    }
}
impl<T: Numeric, const X: usize, const Y: usize> Display for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    /// Renders the matrix using its `to_printable` representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let printable = self.to_printable();
        write!(f, "{printable}")
    }
}
impl<'a, T: 'a + Numeric, const X: usize, const Y: usize> Matrix<'a, T>
    for StackMatrix<T, X, Y>
where
    [T; X * Y]: Sized,
{
    fn get_data(&self) -> &[T] { &self.data }
    fn get_data_mut(&mut self) -> &mut [T] { &mut self.data }
    fn get_x_len(&self) -> usize { self.x_len }
    fn get_y_len(&self) -> usize { self.y_len }

    /// Builds a matrix from `Y` row slices of `X` elements each.
    ///
    /// # Errors
    /// Returns `Error::IncorrectLength` unless there are exactly `Y` rows and
    /// every row is exactly `X` long.
    fn mat_new(data: &[&[T]]) -> Result<Self> {
        // Validate every row, not just `data[0]`: indexing the first row panics
        // on empty input, a short later row would panic during the copy, and a
        // long later row would be silently truncated.
        if data.len() != Y || data.iter().any(|row| row.len() != X) {
            return Err(Error::IncorrectLength);
        }
        // Fill the flat array directly, which also works when `X == 0` or `Y == 0`.
        let mut array: [T; X * Y] = [T::default(); X * Y];
        for (dst, src) in array.iter_mut().zip(data.iter().flat_map(|row| row.iter())) {
            *dst = *src;
        }
        Ok(Self::new_1d(array))
    }

    /// Builds a matrix from a flat, row-major slice of `X * Y` elements.
    ///
    /// `_columns` and `_rows` are ignored: the shape comes from the type, and
    /// only the slice length is validated (by `new_from_slice`).
    fn mat_new_1d(data: &[T], _columns: usize, _rows: usize) -> Result<Self> {
        Self::new_from_slice(data)
    }

    /// Owned-row counterpart of `mat_new`, with the same validation.
    ///
    /// # Errors
    /// Returns `Error::IncorrectLength` unless there are exactly `Y` rows and
    /// every row is exactly `X` long.
    fn mat_new_vec(data: Vec<Vec<T>>) -> Result<Self> {
        // Borrow each row as a slice so the shape checks and the copy live in one place.
        let rows: Vec<&[T]> = data.iter().map(Vec::as_slice).collect();
        <Self as Matrix<'a, T>>::mat_new(&rows)
    }
}