use crate::decomposition::lu_pivot_decomposition;
use crate::utils::{Arr2DError, back_substitution, forward_substitution};
use std::{
any::type_name,
fmt::{self, Display},
ops::{Div, Index, IndexMut, Mul},
};
/// A dense two-dimensional array stored in row-major order.
///
/// The element at `(row, col)` lives at flat index `row * width + col` of the
/// private storage vector.
///
/// NOTE(review): `height` and `width` are public while the storage is private,
/// so external code can make the shape disagree with the number of stored
/// elements; most methods assume `inner.len() == height * width`.
#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)]
pub struct Arr2D<T> {
    /// Flat row-major element storage.
    inner: Vec<T>,
    /// Number of rows.
    pub height: usize,
    /// Number of columns.
    pub width: usize,
}
impl<T> Arr2D<T> {
    /// Creates an empty `0 x 0` array.
    pub fn new() -> Self {
        Arr2D {
            inner: Vec::new(),
            height: 0,
            width: 0,
        }
    }

    /// Returns the shape as `(height, width)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.height, self.width)
    }

    /// Returns the number of stored elements.
    pub fn size(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the array has no rows or no columns.
    pub fn is_empty(&self) -> bool {
        self.height == 0 || self.width == 0
    }

    /// Returns a clone of the largest element, or `None` if the array is empty.
    ///
    /// Elements are compared with `PartialOrd`; on ties the later element wins.
    /// Comparisons involving NaN are `false`, so for floats the result can
    /// depend on where a NaN appears.
    pub fn max(&self) -> Option<T>
    where
        T: std::cmp::PartialOrd + Clone,
    {
        if self.is_empty() {
            return None;
        }
        // `cloned()` instead of `unwrap()`: if the public `height`/`width`
        // fields disagree with the storage, report `None` rather than panic.
        self.inner
            .iter()
            .reduce(|a, b| if a > b { a } else { b })
            .cloned()
    }

    /// Returns a clone of the smallest element, or `None` if the array is empty.
    ///
    /// Elements are compared with `PartialOrd`; on ties the later element wins.
    /// Comparisons involving NaN are `false`, so for floats the result can
    /// depend on where a NaN appears.
    pub fn min(&self) -> Option<T>
    where
        T: std::cmp::PartialOrd + Clone,
    {
        if self.is_empty() {
            return None;
        }
        // See `max` for why this is `cloned()` rather than `unwrap()`.
        self.inner
            .iter()
            .reduce(|a, b| if a < b { a } else { b })
            .cloned()
    }

    /// Computes the inverse of a square matrix using an LU decomposition with
    /// pivoting, after converting every element to `f64`.
    ///
    /// Column `j` of the inverse is obtained by solving with column `j` of the
    /// permutation matrix `p` as right-hand side: forward substitution with
    /// `l`, then back substitution with `u`.
    /// NOTE(review): assumes `lu_pivot_decomposition` returns `(L, U, P)` with
    /// `P * A = L * U` — confirm against `crate::decomposition`.
    ///
    /// # Errors
    /// * `NonSquareMatrix` if `height != width`.
    /// * `ConversionFailed` if an element cannot be converted to `f64`.
    /// * `SingularMatrix` for any error reported by the decomposition.
    pub fn inverse(&self) -> Result<Arr2D<f64>, Arr2DError>
    where
        Arr2D<f64>: for<'a> TryFrom<&'a Arr2D<T>>,
    {
        if self.height != self.width {
            return Err(Arr2DError::NonSquareMatrix);
        }
        let coeff_matrix: Arr2D<f64> =
            self.try_into().map_err(|_| Arr2DError::ConversionFailed {
                from: type_name::<T>(),
                to: type_name::<f64>(),
            })?;
        let size = self.height;
        let (l, u, p) =
            lu_pivot_decomposition(&coeff_matrix).map_err(|_| Arr2DError::SingularMatrix)?;
        let mut inverse_matrix = Arr2D::full(0.0, size, size);
        for j in 0..size {
            // Right-hand side: column `j` of the permutation matrix.
            let b_prime: Vec<f64> = (0..size).map(|i| p[i][j]).collect();
            let mut y = vec![0.0; size];
            forward_substitution(&l, size, &b_prime, &mut y);
            let mut x_j = vec![0.0; size];
            back_substitution(&u, size, &y, &mut x_j);
            for (i, value) in x_j.into_iter().enumerate() {
                inverse_matrix[i][j] = value;
            }
        }
        Ok(inverse_matrix)
    }

    /// Reinterprets the row-major storage as `height` rows, keeping the element
    /// order and deriving the new width.
    ///
    /// # Errors
    /// `InvalidReshape` if `height` is zero or does not divide the element count.
    pub fn reshape(&mut self, height: usize) -> Result<(), Arr2DError> {
        let size = self.height * self.width;
        // A zero height is never a valid target. Checking it explicitly also
        // avoids a `0 / 0` panic: `0.is_multiple_of(0)` is true for an empty array.
        if height == 0 || !size.is_multiple_of(height) {
            return Err(Arr2DError::InvalidReshape {
                size,
                new_height: height,
            });
        }
        self.height = height;
        self.width = size / height;
        Ok(())
    }

    /// Applies `f` to every element, returning a new array of the same shape.
    pub fn map<F, U>(&self, f: F) -> Arr2D<U>
    where
        F: Fn(&T) -> U,
    {
        let inner: Vec<_> = self.inner.iter().map(f).collect();
        Arr2D {
            inner,
            height: self.height,
            width: self.width,
        }
    }

    /// Returns an iterator over the rows, each yielded as a `&[T]` of length `width`.
    pub fn rows(&self) -> Arr2DRows<'_, T> {
        Arr2DRows {
            data: &self.inner,
            width: self.width,
            remaining: self.height,
        }
    }

    /// Returns an iterator over the rows, each yielded as a `&mut [T]` of length `width`.
    pub fn rows_mut(&mut self) -> Arr2DRowsMut<'_, T> {
        Arr2DRowsMut {
            data: self.inner.as_mut_slice(),
            width: self.width,
            remaining: self.height,
        }
    }

    /// Builds a `height x width` array from row-major `inner`, padding any
    /// missing trailing elements with `default_val`.
    ///
    /// # Errors
    /// `InvalidShape` if `inner` holds more than `height * width` elements, or
    /// if `height * width` is zero.
    pub fn from_flat<D>(
        inner: D,
        default_val: T,
        height: usize,
        width: usize,
    ) -> Result<Self, Arr2DError>
    where
        D: AsRef<[T]>,
        T: Clone,
    {
        let values = inner.as_ref();
        let arr_size = height * width;
        if values.len() > arr_size || arr_size == 0 {
            return Err(Arr2DError::InvalidShape {
                input_size: values.len(),
                output_size: arr_size,
            });
        }
        // Allocate once; `resize` is a no-op when the input already fills the shape.
        let mut data = Vec::with_capacity(arr_size);
        data.extend_from_slice(values);
        data.resize(arr_size, default_val);
        Ok(Self {
            inner: data,
            height,
            width,
        })
    }

    /// Matrix product `self · rhs`.
    ///
    /// If either operand is `1 x 1`, it is treated as a scalar and multiplied
    /// into every element of the other operand. The scalar is kept on the left
    /// of `*`.
    ///
    /// # Errors
    /// `InvalidDotShape` if neither operand is `1 x 1` and `self.width != rhs.height`.
    pub fn dot(&self, rhs: &Self) -> Result<Self, Arr2DError>
    where
        T: Copy + std::default::Default + std::ops::AddAssign + std::ops::Mul<Output = T>,
    {
        let lhs_is_scalar = self.height == 1 && self.width == 1;
        let rhs_is_scalar = rhs.height == 1 && rhs.width == 1;
        if lhs_is_scalar || rhs_is_scalar {
            // Broadcast whichever operand is actually 1x1. A 1xN or Nx1
            // operand is a vector and goes through the regular product below.
            let (matrix, scalar) = if rhs_is_scalar {
                (self, rhs[(0, 0)])
            } else {
                (rhs, self[(0, 0)])
            };
            return Ok(matrix.map(|&x| scalar * x));
        }
        if self.width != rhs.height {
            return Err(Arr2DError::InvalidDotShape {
                lhs: self.width,
                rhs: rhs.height,
            });
        }
        let mut result = Arr2D::full(T::default(), self.height, rhs.width);
        for i in 0..self.height {
            for j in 0..rhs.width {
                let mut sum = T::default();
                for k in 0..self.width {
                    sum += self[i][k] * rhs[k][j];
                }
                result[i][j] = sum;
            }
        }
        Ok(result)
    }
}
impl<'b, T> Mul<&'b Arr2D<T>> for &Arr2D<T>
where
T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy + std::ops::AddAssign,
{
type Output = Arr2D<T>;
fn mul(self, rhs: &'b Arr2D<T>) -> Arr2D<T> {
self.dot(rhs).unwrap_or_default()
}
}
/// `a * b` on owned arrays: forwards to the `&a * &b` implementation.
///
/// Like that implementation, this returns an empty `0 x 0` array on a shape
/// mismatch instead of panicking.
impl<T> Mul<Arr2D<T>> for Arr2D<T>
where
    T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy + std::ops::AddAssign,
{
    type Output = Arr2D<T>;
    fn mul(self, rhs: Arr2D<T>) -> Arr2D<T> {
        &self * &rhs
    }
}
/// `a * &b`: forwards to the `&a * &b` implementation.
///
/// Like that implementation, this returns an empty `0 x 0` array on a shape
/// mismatch instead of panicking.
impl<'b, T> Mul<&'b Arr2D<T>> for Arr2D<T>
where
    T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy + std::ops::AddAssign,
{
    type Output = Arr2D<T>;
    fn mul(self, rhs: &'b Arr2D<T>) -> Arr2D<T> {
        &self * rhs
    }
}
/// `&a * b`: forwards to the `&a * &b` implementation.
///
/// Like that implementation, this returns an empty `0 x 0` array on a shape
/// mismatch instead of panicking.
impl<T> Mul<Arr2D<T>> for &Arr2D<T>
where
    T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy + std::ops::AddAssign,
{
    type Output = Arr2D<T>;
    fn mul(self, rhs: Arr2D<T>) -> Arr2D<T> {
        self * &rhs
    }
}
/// `a * scalar` on an owned array: forwards to the borrowed `&a * scalar` impl.
impl<T> Mul<T> for Arr2D<T>
where
    T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy,
{
    type Output = Arr2D<T>;
    fn mul(self, rhs: T) -> Arr2D<T> {
        <&Arr2D<T> as Mul<T>>::mul(&self, rhs)
    }
}
/// `&a * scalar`: multiplies every element by `rhs` (element on the left of
/// `*`) and returns a new array of the same shape.
impl<T> Mul<T> for &Arr2D<T>
where
    T: Mul<Output = T> + Clone + std::default::Default + std::marker::Copy,
{
    type Output = Arr2D<T>;
    fn mul(self, rhs: T) -> Arr2D<T> {
        let mut scaled = Arr2D::full(T::default(), self.height, self.width);
        for (r, out_row) in scaled.rows_mut().enumerate() {
            for (c, slot) in out_row.iter_mut().enumerate() {
                *slot = self[r][c] * rhs;
            }
        }
        scaled
    }
}
/// `&a / scalar`: divides every element by `rhs` and returns a new array of
/// the same shape.
impl<T> Div<T> for &Arr2D<T>
where
    T: Div<Output = T> + Clone + std::default::Default + std::marker::Copy,
{
    type Output = Arr2D<T>;
    fn div(self, rhs: T) -> Arr2D<T> {
        let mut quotient = Arr2D::full(T::default(), self.height, self.width);
        for (r, out_row) in quotient.rows_mut().enumerate() {
            for (c, slot) in out_row.iter_mut().enumerate() {
                *slot = self[r][c] / rhs;
            }
        }
        quotient
    }
}
/// `a / scalar` on an owned array: forwards to the borrowed `&a / scalar` impl.
impl<T> Div<T> for Arr2D<T>
where
    T: Div<Output = T> + Clone + std::default::Default + std::marker::Copy,
{
    type Output = Arr2D<T>;
    fn div(self, rhs: T) -> Self {
        <&Arr2D<T> as Div<T>>::div(&self, rhs)
    }
}
impl<T: Copy> Arr2D<T> {
    /// Returns the single element of a `1 x 1` array, or `None` for any other shape.
    pub fn as_scalar(&self) -> Option<T> {
        if self.height == 1 && self.width == 1 {
            Some(self.inner[0])
        } else {
            None
        }
    }

    /// Returns the first stored element without checking the shape.
    ///
    /// # Panics
    /// Panics if the array stores no elements.
    pub fn as_scalar_unchecked(&self) -> T {
        self.inner[0]
    }

    /// Returns a new `width x height` array that is the transpose of `self`.
    pub fn transpose(&self) -> Arr2D<T> {
        let mut new_inner = Vec::with_capacity(self.inner.len());
        for col in 0..self.width {
            for row in 0..self.height {
                new_inner.push(self[(row, col)]);
            }
        }
        Arr2D {
            inner: new_inner,
            height: self.width,
            width: self.height,
        }
    }

    /// Transposes `self` in place. A new buffer is still allocated.
    pub fn transpose_mut(&mut self) {
        // Reuse `transpose` so both methods cannot drift apart.
        *self = self.transpose();
    }

    /// Swaps rows `a` and `b` in place.
    ///
    /// # Panics
    /// Panics if `a` or `b` is not less than `height`, including for zero-width
    /// arrays, where slicing alone would not catch it.
    pub fn swap_rows(&mut self, a: usize, b: usize) {
        assert!(
            a < self.height && b < self.height,
            "Out of bound row swap ({a},{b}) in Arr2D of shape ({},{})",
            self.height,
            self.width
        );
        if a == b {
            return;
        }
        let (lo, hi) = if a < b { (a, b) } else { (b, a) };
        let w = self.width;
        // Split at the start of the higher row so both rows can be borrowed
        // mutably at the same time.
        let (head, tail) = self.inner.split_at_mut(hi * w);
        head[lo * w..(lo + 1) * w].swap_with_slice(&mut tail[..w]);
    }

    /// Creates a `height x width` array with every element set to `val`.
    pub fn full(val: T, height: usize, width: usize) -> Self {
        Arr2D {
            inner: vec![val; height * width],
            height,
            width,
        }
    }

    /// Creates a `size x size` identity matrix, using `T::from(1)` on the
    /// diagonal and `T::from(0)` elsewhere.
    pub fn identity(size: usize) -> Self
    where
        T: From<i32>,
    {
        let one = T::from(1);
        let mut ident_mat: Arr2D<T> = Arr2D::full(T::from(0), size, size);
        for i in 0..size {
            ident_mat[(i, i)] = one;
        }
        ident_mat
    }
}
/// Iterator over the rows of an `Arr2D`, yielding each row as a shared slice.
///
/// Created by `Arr2D::rows` (and `IntoIterator for &Arr2D`).
pub struct Arr2DRows<'a, T> {
    /// Row-major elements that have not been yielded yet.
    data: &'a [T],
    /// Length of every yielded row.
    width: usize,
    /// Number of rows still to yield.
    remaining: usize,
}
impl<'a, T> Iterator for Arr2DRows<'a, T> {
    type Item = &'a [T];

    /// Yields the next row as a slice of `width` elements.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        // With `width == 0`, `split_at(0)` gives an empty head and leaves
        // `data` untouched, so zero-width arrays need no special case.
        let (row, rest) = self.data.split_at(self.width);
        self.data = rest;
        Some(row)
    }

    /// The number of rows left is known exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
/// Iterator over the rows of an `Arr2D`, yielding each row as a mutable slice.
///
/// Created by `Arr2D::rows_mut` (and `IntoIterator for &mut Arr2D`).
pub struct Arr2DRowsMut<'a, T> {
    /// Row-major elements that have not been yielded yet.
    data: &'a mut [T],
    /// Length of every yielded row.
    width: usize,
    /// Number of rows still to yield.
    remaining: usize,
}
impl<'a, T> Iterator for Arr2DRowsMut<'a, T> {
    type Item = &'a mut [T];

    /// Yields the next row as a mutable slice of `width` elements.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        // Move the slice out of `self` (leaving an empty one behind) so the
        // yielded row keeps the full `'a` lifetime instead of borrowing `self`.
        let pending = std::mem::take(&mut self.data);
        let (row, rest) = pending.split_at_mut(self.width);
        self.data = rest;
        self.remaining -= 1;
        Some(row)
    }

    /// The number of rows left is known exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
impl<T> TryFrom<Vec<Vec<T>>> for Arr2D<T> {
    type Error = Arr2DError;

    /// Flattens a vector of rows into row-major storage.
    ///
    /// An empty outer vector gives a `0 x 0` array. The width is taken from
    /// the first row, and any row of a different length yields
    /// `InconsistentRowLengths`.
    fn try_from(values: Vec<Vec<T>>) -> Result<Self, Self::Error> {
        let (height, width) = match values.first() {
            Some(first_row) => (values.len(), first_row.len()),
            None => {
                return Ok(Self {
                    inner: vec![],
                    height: 0,
                    width: 0,
                })
            }
        };
        let mut flat: Vec<T> = Vec::with_capacity(height * width);
        for row in values.into_iter() {
            if row.len() != width {
                return Err(Arr2DError::InconsistentRowLengths);
            }
            flat.extend(row);
        }
        Ok(Self {
            inner: flat,
            height,
            width,
        })
    }
}
impl<T: Clone, U: TryFrom<T>> TryFrom<&Vec<Vec<T>>> for Arr2D<U> {
    type Error = Arr2DError;

    /// Flattens borrowed rows into row-major storage, converting each element
    /// with `U::try_from`.
    ///
    /// Rows are processed in order: each row's length is checked against the
    /// first row (`InconsistentRowLengths`) before its elements are converted
    /// (`ConversionFailed` on the first failure). An empty outer vector gives
    /// a `0 x 0` array.
    fn try_from(values: &Vec<Vec<T>>) -> Result<Self, Self::Error> {
        let Some(first_row) = values.first() else {
            return Ok(Self {
                inner: vec![],
                height: 0,
                width: 0,
            });
        };
        let width = first_row.len();
        let convert = |x: &T| -> Result<U, Arr2DError> {
            U::try_from(x.clone()).map_err(|_| Arr2DError::ConversionFailed {
                from: type_name::<T>(),
                to: type_name::<U>(),
            })
        };
        let mut inner = Vec::with_capacity(values.len() * width);
        for row in values.iter() {
            if row.len() != width {
                return Err(Arr2DError::InconsistentRowLengths);
            }
            for x in row.iter() {
                inner.push(convert(x)?);
            }
        }
        Ok(Self {
            inner,
            height: values.len(),
            width,
        })
    }
}
impl<T: Clone, U: TryFrom<T>> TryFrom<&Arr2D<T>> for Arr2D<U> {
    type Error = Arr2DError;

    /// Converts every element with `U::try_from`, preserving the shape.
    ///
    /// # Errors
    /// `ConversionFailed` for the first element, in row-major order, that
    /// cannot be converted.
    fn try_from(arr: &Arr2D<T>) -> Result<Self, Self::Error> {
        // No empty-array shortcut: an empty array converts to an empty array of
        // the same shape (e.g. 0 x 3 stays 0 x 3), consistent with `Arr2D::map`.
        let inner = arr
            .inner
            .iter()
            .map(|x| {
                U::try_from(x.clone()).map_err(|_| Arr2DError::ConversionFailed {
                    from: type_name::<T>(),
                    to: type_name::<U>(),
                })
            })
            .collect::<Result<Vec<U>, Arr2DError>>()?;
        Ok(Self {
            inner,
            height: arr.height,
            width: arr.width,
        })
    }
}
impl<'a, T> IntoIterator for &'a Arr2D<T> {
type Item = &'a [T];
type IntoIter = Arr2DRows<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.rows()
}
}
impl<'a, T> IntoIterator for &'a mut Arr2D<T> {
type Item = &'a mut [T];
type IntoIter = Arr2DRowsMut<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.rows_mut()
}
}
impl<T, const M: usize, const N: usize> From<&[[T; N]; M]> for Arr2D<T>
where
T: Clone,
{
fn from(values: &[[T; N]; M]) -> Self {
let mut inner = Vec::with_capacity(M * N);
for row in values.iter() {
inner.extend_from_slice(row);
}
Self {
inner,
height: M,
width: N,
}
}
}
impl<T> Index<(usize, usize)> for Arr2D<T> {
    type Output = T;

    /// Returns the element at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= height` or `col >= width`, even when the flat offset
    /// would still land inside the storage.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        assert!(
            row < self.height && col < self.width,
            "Out of bound index ({row},{col}) into Arr2D of shape ({},{})",
            self.height,
            self.width
        );
        let offset = row * self.width + col;
        &self.inner[offset]
    }
}
impl<T> IndexMut<(usize, usize)> for Arr2D<T> {
    /// Returns a mutable reference to the element at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= height` or `col >= width`.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        assert!(
            row < self.height && col < self.width,
            "Out of bound index ({row},{col}) into Arr2D of shape ({},{})",
            self.height,
            self.width
        );
        let offset = row * self.width + col;
        &mut self.inner[offset]
    }
}
impl<T> Index<usize> for Arr2D<T> {
    type Output = [T];

    /// Returns row `row` as a slice of `width` elements.
    ///
    /// # Panics
    /// Panics if `row >= height`.
    fn index(&self, row: usize) -> &Self::Output {
        assert!(
            row < self.height,
            "Out of bound row index {row} into Arr2D of shape ({},{})",
            self.height,
            self.width
        );
        let start = row * self.width;
        &self.inner[start..start + self.width]
    }
}
impl<T> IndexMut<usize> for Arr2D<T> {
    /// Returns row `row` as a mutable slice of `width` elements.
    ///
    /// # Panics
    /// Panics if `row >= height`.
    fn index_mut(&mut self, row: usize) -> &mut Self::Output {
        assert!(
            row < self.height,
            "Out of bound row index {row} into Arr2D of shape ({},{})",
            self.height,
            self.width
        );
        let start = row * self.width;
        let end = start + self.width;
        &mut self.inner[start..end]
    }
}
impl<T: Display> Display for Arr2D<T> {
    /// Formats the array as right-aligned columns, for example:
    ///
    /// ```text
    /// [[    1.2, 34.5678 ]
    ///  [ 789.02,   0.123 ]]
    /// ```
    ///
    /// A precision in the format spec (e.g. `{:.2}`) is applied to every
    /// element. An empty array (no rows or no columns) is written as `[]`. No
    /// trailing newline is emitted in either case.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.height == 0 || self.width == 0 {
            return write!(f, "[]");
        }
        let precision = f.precision();
        // Render every element once, so the measured widths and the written
        // text are guaranteed to agree.
        let cells: Vec<String> = self
            .inner
            .iter()
            .take(self.height * self.width)
            .map(|item| match precision {
                Some(p) => format!("{item:.p$}"),
                None => format!("{item}"),
            })
            .collect();
        // Measure in chars, not bytes: `{:>width$}` pads by char count, so byte
        // lengths would misalign non-ASCII output.
        let mut col_widths = vec![0; self.width];
        for row in cells.chunks(self.width) {
            for (w, cell) in col_widths.iter_mut().zip(row) {
                *w = (*w).max(cell.chars().count());
            }
        }
        for (r, row) in cells.chunks(self.width).enumerate() {
            f.write_str(if r == 0 { "[[ " } else { " [ " })?;
            for (c, cell) in row.iter().enumerate() {
                if c > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{cell:>width$}", width = col_widths[c])?;
            }
            f.write_str(if r + 1 == self.height { " ]]" } else { " ]\n" })?;
        }
        Ok(())
    }
}
/// Element-wise rounding to a fixed number of decimal places.
pub trait Rounding {
    /// Returns a new value with every element rounded to `decimals` decimal places.
    fn round_to_decimal(&self, decimals: u32) -> Self;
}
impl Rounding for Arr2D<f64> {
    /// Rounds every element to `decimals` decimal places. Half-way cases round
    /// away from zero, as with `f64::round`.
    ///
    /// If the scaled value `val * 10^decimals` is not finite (very large
    /// `decimals` or values, NaN, or infinities), the element is returned
    /// unchanged instead of turning into NaN via `inf / inf`.
    fn round_to_decimal(&self, decimals: u32) -> Arr2D<f64> {
        // Clamp before casting so a huge `decimals` cannot wrap to a negative exponent.
        let exponent = decimals.min(i32::MAX as u32) as i32;
        let factor = 10.0_f64.powi(exponent);
        // `map` already builds a new array; no need to clone `self` first.
        self.map(|&val| {
            let scaled = val * factor;
            if scaled.is_finite() {
                scaled.round() / factor
            } else {
                val
            }
        })
    }
}
impl<T: PartialEq> PartialEq<Vec<Vec<T>>> for Arr2D<T> {
fn eq(&self, other: &Vec<Vec<T>>) -> bool {
if self.height != other.len() {
return false;
}
if self.height == 0 {
return true; }
let width = self.width;
if other.iter().any(|row| row.len() != width) {
return false;
}
for r in 0..self.height {
for c in 0..self.width {
if self[r][c] != other[r][c] {
return false;
}
}
}
true
}
}
impl<T: PartialEq> PartialEq<Arr2D<T>> for Vec<Vec<T>> {
fn eq(&self, other: &Arr2D<T>) -> bool {
other == self
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Arr2D`: construction, indexing, iteration, arithmetic,
    // formatting and linear-algebra helpers.
    use super::*;

    #[test]
    fn test_getting_shape() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data.shape(), (2, 3));
        let data: Arr2D<i32> = Arr2D::from(&[[1, 2]; 0]);
        assert_eq!(data.shape(), (0, 2));
        let data: Arr2D<i32> = Arr2D::from(&[[]; 0]);
        assert_eq!(data.shape(), (0, 0));
        let data: Arr2D<i32> = Arr2D::from(&[[], []]);
        assert_eq!(data.shape(), (2, 0));
    }

    #[test]
    fn test_getting_size() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data.size(), 6);
        let data: Arr2D<i32> = Arr2D::from(&[[1, 2]; 0]);
        assert_eq!(data.size(), 0);
        let data: Arr2D<i32> = Arr2D::from(&[[]; 0]);
        assert_eq!(data.size(), 0);
        let data: Arr2D<i32> = Arr2D::from(&[[], []]);
        assert_eq!(data.size(), 0);
    }

    #[test]
    fn test_indexing_item() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data[(0, 0)], 1);
        assert_eq!(data[(0, 1)], 2);
        assert_eq!(data[(0, 2)], 3);
        assert_eq!(data[(1, 0)], 6);
        assert_eq!(data[(1, 1)], 5);
        assert_eq!(data[(1, 2)], 4);
    }

    #[test]
    fn test_2d_indexing_item() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data[0][0], 1);
        assert_eq!(data[0][1], 2);
        assert_eq!(data[0][2], 3);
        assert_eq!(data[1][0], 6);
        assert_eq!(data[1][1], 5);
        assert_eq!(data[1][2], 4);
    }

    #[test]
    fn test_mut_indexing_item() {
        let mut data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        data[(1, 2)] = 10;
        data[(0, 0)] = 11;
        assert_eq!(data[(0, 0)], 11);
        assert_eq!(data[(0, 1)], 2);
        assert_eq!(data[(0, 2)], 3);
        assert_eq!(data[(1, 0)], 6);
        assert_eq!(data[(1, 1)], 5);
        assert_eq!(data[(1, 2)], 10);
    }

    #[test]
    fn test_2d_mut_indexing_item() {
        let mut data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        data[(1, 2)] = 10;
        data[(0, 0)] = 11;
        assert_eq!(data[0][0], 11);
        assert_eq!(data[0][1], 2);
        assert_eq!(data[0][2], 3);
        assert_eq!(data[1][0], 6);
        assert_eq!(data[1][1], 5);
        assert_eq!(data[1][2], 10);
    }

    #[test]
    fn test_indexing_row() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data[0], [1, 2, 3]);
        assert_eq!(data[1], [6, 5, 4]);
    }

    #[test]
    fn test_mut_indexing_row() {
        let mut data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        data[0][2] = 9;
        data[1][0] = 10;
        assert_eq!(data[0], [1, 2, 9]);
        assert_eq!(data[1], [10, 5, 4]);
    }

    #[test]
    #[should_panic]
    fn test_index_out_of_bounds_panics() {
        let data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let _ = data[(2, 0)];
    }

    #[test]
    #[should_panic]
    fn test_mut_index_out_of_bounds_panics() {
        let mut data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let _ = &mut data[(2, 0)];
    }

    #[test]
    #[should_panic]
    fn test_row_index_out_of_bounds_panics() {
        let data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let _ = &data[2];
    }

    #[test]
    #[should_panic]
    fn test_row_mut_index_out_of_bounds_panics() {
        let mut data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let _ = &mut data[2];
    }

    #[test]
    fn test_rows_iterator_returns_slices() {
        let data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let rows: Vec<&[i32]> = data.rows().collect();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], &[1, 2, 3]);
        assert_eq!(rows[1], &[4, 5, 6]);
    }

    #[test]
    fn test_iterating_through_items() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        let mut expected = [1, 2, 3, 6, 5, 4].iter();
        for row in data.rows() {
            for item in row {
                assert_eq!(item, expected.next().unwrap());
            }
        }
    }

    #[test]
    fn test_rows_mut_iterator_allows_mutation() {
        let mut data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        for row in data.rows_mut() {
            row.reverse();
        }
        let expected = Arr2D::from(&[[3, 2, 1], [6, 5, 4]]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_reshape() {
        let mut data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        data.reshape(3).unwrap();
        let expected = Arr2D::from(&[[1, 2], [3, 6], [5, 4]]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_reshape_invalid_height_errors() {
        let mut data = Arr2D::from(&[[1, 2], [3, 4]]);
        let err = data
            .reshape(3)
            .expect_err("reshape should fail when new height mismatches size");
        assert!(matches!(
            err,
            Arr2DError::InvalidReshape {
                size: 4,
                new_height: 3
            }
        ));
    }

    #[test]
    fn test_transpose() {
        let mut data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        data.transpose_mut();
        let expected = Arr2D::from(&[[1, 6], [2, 5], [3, 4]]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_transpose_returns_new_array() {
        let data = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        let transposed = data.transpose();
        assert_eq!(transposed, Arr2D::from(&[[1, 6], [2, 5], [3, 4]]));
        assert_eq!(data.shape(), (2, 3));
    }

    #[test]
    fn test_swap_rows() {
        let mut data = Arr2D::from(&[[1, 2], [3, 4], [5, 6]]);
        data.swap_rows(2, 0);
        assert_eq!(data, Arr2D::from(&[[5, 6], [3, 4], [1, 2]]));
        data.swap_rows(1, 1);
        assert_eq!(data, Arr2D::from(&[[5, 6], [3, 4], [1, 2]]));
    }

    #[test]
    fn test_identity() {
        let ident: Arr2D<i32> = Arr2D::identity(3);
        assert_eq!(ident, Arr2D::from(&[[1, 0, 0], [0, 1, 0], [0, 0, 1]]));
    }

    #[test]
    fn test_as_scalar() {
        assert_eq!(Arr2D::from(&[[7]]).as_scalar(), Some(7));
        assert_eq!(Arr2D::from(&[[1, 2]]).as_scalar(), None);
    }

    #[test]
    fn test_arr_from_vec() {
        let data = Arr2D::try_from(vec![vec![1, 2, 3], vec![6, 5, 4]]).unwrap();
        let expected = Arr2D::from(&[[1, 2, 3], [6, 5, 4]]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_try_from_inconsistent_rows_returns_error() {
        let err = Arr2D::try_from(vec![vec![1, 2, 3], vec![4, 5]])
            .expect_err("rows with different widths should error");
        assert!(matches!(err, Arr2DError::InconsistentRowLengths));
    }

    #[test]
    fn test_try_from_empty_vec_creates_empty_arr() {
        let data = Arr2D::try_from(Vec::<Vec<i32>>::new()).unwrap();
        assert_eq!(data.rows().count(), 0);
    }

    #[test]
    fn test_eq_nested_vec() {
        let data = Arr2D::from(&[[1, 2], [3, 4]]);
        assert!(data == vec![vec![1, 2], vec![3, 4]]);
        assert!(vec![vec![1, 2], vec![3, 4]] == data);
        assert!(data != vec![vec![1, 2], vec![3, 5]]);
    }

    #[test]
    fn test_map_transforms_elements() {
        let data = Arr2D::from(&[[1, 2], [3, 4]]);
        let mapped = data.map(|value| value * 2);
        let expected = Arr2D::from(&[[2, 4], [6, 8]]);
        assert_eq!(mapped, expected);
    }

    #[test]
    fn test_full() {
        let data = Arr2D::full(10, 3, 4);
        for row in data.rows() {
            for item in row {
                assert_eq!(*item, 10);
            }
        }
    }

    #[test]
    fn test_from_flat() {
        let data = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_ref() {
        let data = Arr2D::from_flat(&vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_slice() {
        let data = Arr2D::from_flat(&[1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_with_default() {
        let data = Arr2D::from_flat(vec![1, 2, 3, 4], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[1, 2, 3], [4, 0, 0]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_with_default_ref() {
        let data = Arr2D::from_flat(&vec![1, 2, 3, 4], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[1, 2, 3], [4, 0, 0]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_full_zeros() {
        let flat_data = Arr2D::from_flat(&vec![], 0, 2, 3).unwrap();
        let full_data = Arr2D::full(0, 2, 3);
        assert_eq!(flat_data, full_data);
    }

    #[test]
    fn test_from_flat_slice_full_zeros() {
        let data = Arr2D::from_flat(&[], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[0, 0, 0], [0, 0, 0]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_from_flat_slice_full_zeros_no_ref() {
        let data = Arr2D::from_flat([], 0, 2, 3).unwrap();
        let out = Arr2D::from(&[[0, 0, 0], [0, 0, 0]]);
        assert_eq!(data, out);
    }

    #[test]
    fn test_mat_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![7, 8, 9, 10, 11, 12], 0, 3, 2).unwrap();
        let expected = Arr2D::from_flat(vec![58, 64, 139, 154], 0, 2, 2).unwrap();
        let res = arr1.dot(&arr2).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_vec_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![3, 4, 2], 0, 1, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![13, 9, 7, 15, 8, 7, 4, 6, 6, 4, 0, 3], 0, 3, 4).unwrap();
        let expected = Arr2D::from_flat(vec![83, 63, 37, 75], 0, 1, 4).unwrap();
        let res = arr1.dot(&arr2).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_scalar_mul_mat() {
        let scal = Arr2D::from_flat(vec![2], 0, 1, 1).unwrap();
        let mat = Arr2D::from_flat(vec![4, 0, 1, -9], 0, 2, 2).unwrap();
        let expected = Arr2D::from_flat(vec![8, 0, 2, -18], 0, 2, 2).unwrap();
        let res = scal.dot(&mat).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_dot_invalid_shape_errors() {
        let data = Arr2D::from(&[[1, 2, 3], [4, 5, 6]]);
        let err = data
            .dot(&data)
            .expect_err("2x3 by 2x3 product should be rejected");
        assert!(matches!(err, Arr2DError::InvalidDotShape { lhs: 3, rhs: 2 }));
    }

    #[test]
    fn test_mul_trait_scalar() {
        let scal = 2;
        let mat = Arr2D::from_flat(vec![4, 0, 1, -9], 0, 2, 2).unwrap();
        let expected = Arr2D::from_flat(vec![8, 0, 2, -18], 0, 2, 2).unwrap();
        let res = mat * scal;
        assert_eq!(res, expected)
    }

    #[test]
    fn test_mul_trait_scalar_borrowed() {
        let scal = 2;
        let mat = Arr2D::from_flat(vec![4, 0, 1, -9], 0, 2, 2).unwrap();
        let expected = Arr2D::from_flat(vec![8, 0, 2, -18], 0, 2, 2).unwrap();
        let res = &mat * scal;
        assert_eq!(res, expected)
    }

    #[test]
    fn test_mul_trait_mat_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![7, 8, 9, 10, 11, 12], 0, 3, 2).unwrap();
        let expected = Arr2D::from_flat(vec![58, 64, 139, 154], 0, 2, 2).unwrap();
        let res = arr1 * arr2;
        assert_eq!(res, expected);
    }

    #[test]
    fn test_borrowed_mul_trait_mat_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![7, 8, 9, 10, 11, 12], 0, 3, 2).unwrap();
        let expected = Arr2D::from_flat(vec![58, 64, 139, 154], 0, 2, 2).unwrap();
        let res = arr1 * &arr2;
        assert_eq!(res, expected);
    }

    #[test]
    fn test_borrowed_mul_trait_mat_mul_mat_2() {
        let arr1 = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![7, 8, 9, 10, 11, 12], 0, 3, 2).unwrap();
        let expected = Arr2D::from_flat(vec![58, 64, 139, 154], 0, 2, 2).unwrap();
        let res = &arr1 * arr2;
        assert_eq!(res, expected);
    }

    #[test]
    fn test_double_borrowed_mul_trait_mat_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![1, 2, 3, 4, 5, 6], 0, 2, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![7, 8, 9, 10, 11, 12], 0, 3, 2).unwrap();
        let expected = Arr2D::from_flat(vec![58, 64, 139, 154], 0, 2, 2).unwrap();
        let res = &arr1 * &arr2;
        assert_eq!(res, expected);
    }

    #[test]
    fn test_mul_trait_vec_mul_mat() {
        let arr1 = Arr2D::from_flat(vec![3, 4, 2], 0, 1, 3).unwrap();
        let arr2 = Arr2D::from_flat(vec![13, 9, 7, 15, 8, 7, 4, 6, 6, 4, 0, 3], 0, 3, 4).unwrap();
        let expected = Arr2D::from_flat(vec![83, 63, 37, 75], 0, 1, 4).unwrap();
        let res = arr1 * arr2;
        assert_eq!(res, expected);
    }

    #[test]
    fn test_div_trait_mat_div_scalar() {
        let mat = Arr2D::from_flat(vec![1.778, 0.0, 1.778], 0.0, 3, 1).unwrap();
        let result = mat / 1.778;
        let expected = Arr2D::from(&[[1.0], [0.0], [1.0]]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_borrowed_div_trait_mat_div_scalar() {
        let mat = Arr2D::from_flat(vec![1.778, 0.0, 1.778], 0.0, 3, 1).unwrap();
        let result = &mat / 1.778;
        let expected = Arr2D::from(&[[1.0], [0.0], [1.0]]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_display() {
        let data = Arr2D::from(&[[1.2, 34.5678], [789.02, 0.123]]);
        let out = format!("{data}");
        // Columns are right-aligned to the widest entry in each column.
        let expected = "[[    1.2, 34.5678 ]\n [ 789.02,   0.123 ]]";
        assert_eq!(out, expected);
    }

    #[test]
    fn test_int_max() {
        let data = Arr2D::from(&[[12], [10], [11], [20], [9]]);
        let result = data.max().unwrap();
        let expected = 20;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_float_max() {
        let data = Arr2D::from(&[[23.3], [12.4], [23.4], [10.4]]);
        let result = data.max().unwrap();
        let expected = 23.4;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_int_min() {
        let data = Arr2D::from(&[[12], [10], [11], [20], [9]]);
        let result = data.min().unwrap();
        let expected = 9;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_float_min() {
        let data = Arr2D::from(&[[23.3], [12.4], [23.4], [10.4]]);
        let result = data.min().unwrap();
        let expected = 10.4;
        assert_eq!(result, expected);
    }

    #[test]
    fn test_max_min_empty() {
        let data: Arr2D<i32> = Arr2D::new();
        assert_eq!(data.max(), None);
        assert_eq!(data.min(), None);
    }

    #[test]
    fn test_known_inverse() {
        let matrix = Arr2D::from(&[[3, 0, 2], [2, 0, -2], [0, 1, 1]]);
        let result = matrix.inverse().unwrap();
        let rounded_result = result.round_to_decimal(1);
        let expected = Arr2D::from(&[[0.2, 0.2, 0.0], [-0.2, 0.3, 1.0], [0.2, -0.3, 0.0]]);
        assert_eq!(rounded_result, expected);
    }

    #[test]
    fn test_known_inverse_2() {
        let matrix = Arr2D::from(&[
            [3.556, -1.778, 0.0],
            [-1.778, 3.556, -1.778],
            [0.0, -1.778, 3.556],
        ]);
        let result = matrix.inverse().unwrap();
        let rounded_result = result.round_to_decimal(3);
        let expected = Arr2D::from(&[
            [0.422, 0.281, 0.141],
            [0.281, 0.562, 0.281],
            [0.141, 0.281, 0.422],
        ]);
        assert_eq!(rounded_result, expected);
    }

    #[test]
    fn test_inverse_non_square_errors() {
        let matrix = Arr2D::from(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(matches!(matrix.inverse(), Err(Arr2DError::NonSquareMatrix)));
    }
}