#[cfg(feature = "serde")]
use serde::Serialize;
mod errors;
pub mod iterators;
pub mod operations;
pub mod slices;
pub mod views;
pub use errors::ScalarConversionError;
use crate::linear_algebra;
use crate::matrices::iterators::*;
use crate::matrices::slices::Slice2D;
use crate::matrices::views::{
IndexRange, MatrixMask, MatrixPart, MatrixQuadrants, MatrixRange, MatrixReverse, MatrixView,
Reverse,
};
use crate::numeric::extra::{Real, RealRef};
use crate::numeric::{Numeric, NumericRef};
use std::num::NonZeroUsize;
/// A dense two dimensional matrix of elements of type `T`.
///
/// Elements are stored in a single `Vec` in row-major order: the element at
/// (row, column) lives at index `column + row * columns` (see `get_index`).
/// The public constructors assert that at least one element is supplied and that
/// `data.len() == rows * columns`, so a `Matrix` is at least 1x1.
///
/// NOTE: with the `serde` feature the derived `Serialize` writes the fields in
/// declaration order (data, rows, columns), so this order is part of the format.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Matrix<T> {
// Elements in row-major order; length is rows * columns.
data: Vec<T>,
// Number of rows.
rows: Row,
// Number of columns.
columns: Column,
}
/// Index / count type for matrix rows (a plain `usize`).
pub type Row = usize;
/// Index / count type for matrix columns (a plain `usize`).
pub type Column = usize;
impl<T> Matrix<T> {
/// Creates a 1x1 matrix containing only `value`.
pub fn from_scalar(value: T) -> Matrix<T> {
    Matrix {
        rows: 1,
        columns: 1,
        data: vec![value],
    }
}
/// Creates a row vector: a 1 x `values.len()` matrix holding `values` in order.
///
/// # Panics
///
/// If `values` is empty.
#[track_caller]
pub fn row(values: Vec<T>) -> Matrix<T> {
    assert!(!values.is_empty(), "No values provided");
    let columns = values.len();
    Matrix {
        rows: 1,
        columns,
        data: values,
    }
}
/// Creates a column vector: a `values.len()` x 1 matrix holding `values` in order.
///
/// # Panics
///
/// If `values` is empty.
#[track_caller]
pub fn column(values: Vec<T>) -> Matrix<T> {
    assert!(!values.is_empty(), "No values provided");
    let rows = values.len();
    Matrix {
        rows,
        columns: 1,
        data: values,
    }
}
/// Creates a matrix from a list of rows, where every inner `Vec` is one row.
///
/// # Panics
///
/// If no rows are given, if the first row is empty, or if the rows do not all
/// have the same length.
#[track_caller]
pub fn from(values: Vec<Vec<T>>) -> Matrix<T> {
    assert!(!values.is_empty(), "No rows defined");
    assert!(!values[0].is_empty(), "No column defined");
    let columns = values[0].len();
    assert!(
        values.iter().all(|row| row.len() == columns),
        "Inconsistent size"
    );
    let rows = values.len();
    // Every row has exactly `columns` elements, so concatenating the rows in order
    // produces the row-major layout directly.
    let mut data = Vec::with_capacity(rows * columns);
    data.extend(values.into_iter().flatten());
    Matrix {
        data,
        rows,
        columns,
    }
}
/// Creates a `size.0` x `size.1` matrix from `values` given in row-major order.
///
/// # Panics
///
/// If `values.len()` is not exactly `size.0 * size.1` (including when that product
/// overflows `usize`), or if `values` is empty.
#[track_caller]
pub fn from_flat_row_major(size: (Row, Column), values: Vec<T>) -> Matrix<T> {
    let (rows, columns) = size;
    // checked_mul: in release builds a plain `*` wraps, and a wrapped product could
    // spuriously equal a short `values.len()`. That would create a matrix whose
    // data is shorter than rows * columns, which every accessor assumes it is not.
    assert!(
        rows.checked_mul(columns) == Some(values.len()),
        "Inconsistent size, attempted to construct a {}x{} matrix but provided with {} elements.",
        rows,
        columns,
        values.len()
    );
    assert!(!values.is_empty(), "No values provided");
    Matrix {
        data: values,
        rows,
        columns,
    }
}
/// Creates a matrix of the given size by calling `producer` once per (row, column)
/// index and storing each result at that index.
///
/// NOTE(review): this relies on `ShapeIterator` yielding indexes in row-major order
/// (column fastest), matching the layout `from_flat_row_major` expects — confirm.
///
/// # Panics
///
/// Whenever `from_flat_row_major` would, e.g. if the size has a 0 dimension.
#[track_caller]
pub fn from_fn<F>(size: (Row, Column), mut producer: F) -> Matrix<T>
where
    F: FnMut((Row, Column)) -> T,
{
    use crate::tensors::indexing::ShapeIterator;
    let (rows, columns) = size;
    let mut data = Vec::with_capacity(rows * columns);
    let indexes = ShapeIterator::from([("row", rows), ("column", columns)]);
    data.extend(indexes.map(|[row, column]| producer((row, column))));
    Matrix::from_flat_row_major(size, data)
}
#[deprecated(
since = "1.1.0",
note = "Incorrect use of terminology, a unit matrix is another term for an identity matrix, please use `from_scalar` instead"
)]
pub fn unit(value: T) -> Matrix<T> {
Matrix::from_scalar(value)
}
pub fn size(&self) -> (Row, Column) {
(self.rows, self.columns)
}
pub fn rows(&self) -> Row {
self.rows
}
pub fn columns(&self) -> Column {
self.columns
}
fn get_index(&self, row: Row, column: Column) -> usize {
column + (row * self.columns())
}
#[allow(dead_code)]
fn get_row_column(&self, index: usize) -> (Row, Column) {
(index / self.columns(), index % self.columns())
}
/// Returns a reference to the element at (`row`, `column`).
///
/// # Panics
///
/// If `row` or `column` is out of range.
#[track_caller]
pub fn get_reference(&self, row: Row, column: Column) -> &T {
    assert!(row < self.rows(), "Row out of index");
    assert!(column < self.columns(), "Column out of index");
    let index = self.get_index(row, column);
    &self.data[index]
}
/// Returns a mutable reference to the element at (`row`, `column`).
///
/// # Panics
///
/// If `row` or `column` is out of range.
#[track_caller]
pub fn get_reference_mut(&mut self, row: Row, column: Column) -> &mut T {
    assert!(row < self.rows(), "Row out of index");
    assert!(column < self.columns(), "Column out of index");
    let index = self.get_index(row, column);
    &mut self.data[index]
}
/// Checked lookup: `None` when (`row`, `column`) is out of range.
pub(crate) fn _try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
    if row >= self.rows() || column >= self.columns() {
        return None;
    }
    Some(&self.data[self.get_index(row, column)])
}
/// Returns a reference to the element at (`row`, `column`) without bounds checking.
///
/// # Safety
///
/// The caller must guarantee `row < self.rows()` and `column < self.columns()`;
/// otherwise the computed index may lie outside `data`, which is undefined behaviour.
pub(crate) unsafe fn _get_reference_unchecked(&self, row: Row, column: Column) -> &T {
    let index = self.get_index(row, column);
    // SAFETY: row/column are in range per the caller's contract, so
    // index <= (rows - 1) * columns + (columns - 1) < rows * columns == data.len().
    unsafe { self.data.get_unchecked(index) }
}
/// Overwrites the element at (`row`, `column`) with `value`.
///
/// # Panics
///
/// If `row` or `column` is out of range.
#[track_caller]
pub fn set(&mut self, row: Row, column: Column, value: T) {
    // get_reference_mut performs the same bounds asserts; #[track_caller] on both
    // keeps the panic location at our caller.
    *self.get_reference_mut(row, column) = value;
}
/// Checked mutable lookup: `None` when (`row`, `column`) is out of range.
pub(crate) fn _try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
    if row >= self.rows() || column >= self.columns() {
        return None;
    }
    let index = self.get_index(row, column);
    Some(&mut self.data[index])
}
/// Returns a mutable reference to the element at (`row`, `column`) without bounds
/// checking.
///
/// # Safety
///
/// The caller must guarantee `row < self.rows()` and `column < self.columns()`;
/// otherwise the computed index may lie outside `data`, which is undefined behaviour.
pub(crate) unsafe fn _get_reference_unchecked_mut(
    &mut self,
    row: Row,
    column: Column,
) -> &mut T {
    let index = self.get_index(row, column);
    // SAFETY: row/column are in range per the caller's contract, so index < data.len().
    unsafe { self.data.get_unchecked_mut(index) }
}
/// Removes the row at index `row`, shifting later rows up.
///
/// # Panics
///
/// If the matrix has only one row, or if `row` is out of range. Without the range
/// check an out-of-range `row` would remove nothing but still decrement the row
/// count, leaving `data.len() != rows * columns`.
#[track_caller]
pub fn remove_row(&mut self, row: Row) {
    assert!(self.rows() > 1);
    assert!(row < self.rows(), "Row out of index");
    let columns = self.columns();
    // retain visits elements in storage (row-major) order, so index / columns is
    // the row of the element being considered.
    let mut index = 0;
    self.data.retain(|_| {
        let keep = index / columns != row;
        index += 1;
        keep
    });
    self.rows -= 1;
}
/// Removes the column at index `column`, shifting later columns left.
///
/// # Panics
///
/// If the matrix has only one column, or if `column` is out of range. Without the
/// range check an out-of-range `column` would remove nothing but still decrement the
/// column count, leaving `data.len() != rows * columns`.
#[track_caller]
pub fn remove_column(&mut self, column: Column) {
    assert!(self.columns() > 1);
    assert!(column < self.columns(), "Column out of index");
    let columns = self.columns();
    // retain visits elements in storage (row-major) order, so index % columns is
    // the column of the element being considered.
    let mut index = 0;
    self.data.retain(|_| {
        let keep = index % columns != column;
        index += 1;
        keep
    });
    self.columns -= 1;
}
/// Creates a [`ColumnReferenceIterator`] over column `column` of this matrix.
/// `#[track_caller]` so any panic raised by the constructor (presumably for an
/// out-of-range column — confirm in `iterators`) reports the caller's location.
#[track_caller]
pub fn column_reference_iter(&self, column: Column) -> ColumnReferenceIterator<'_, T> {
ColumnReferenceIterator::new(self, column)
}
/// Creates a [`RowReferenceIterator`] over row `row` of this matrix.
/// `#[track_caller]` so any panic raised by the constructor (presumably for an
/// out-of-range row — confirm in `iterators`) reports the caller's location.
#[track_caller]
pub fn row_reference_iter(&self, row: Row) -> RowReferenceIterator<'_, T> {
RowReferenceIterator::new(self, row)
}
/// Creates a [`ColumnReferenceMutIterator`] over column `column`, borrowing the
/// matrix mutably for the iterator's lifetime.
#[track_caller]
pub fn column_reference_mut_iter(
&mut self,
column: Column,
) -> ColumnReferenceMutIterator<'_, T> {
ColumnReferenceMutIterator::new(self, column)
}
/// Creates a [`RowReferenceMutIterator`] over row `row`, borrowing the matrix
/// mutably for the iterator's lifetime.
#[track_caller]
pub fn row_reference_mut_iter(&mut self, row: Row) -> RowReferenceMutIterator<'_, T> {
RowReferenceMutIterator::new(self, row)
}
/// Creates a [`ColumnMajorReferenceIterator`] over every element of the matrix.
pub fn column_major_reference_iter(&self) -> ColumnMajorReferenceIterator<'_, T> {
ColumnMajorReferenceIterator::new(self)
}
/// Creates a [`RowMajorReferenceIterator`] over every element of the matrix.
pub fn row_major_reference_iter(&self) -> RowMajorReferenceIterator<'_, T> {
RowMajorReferenceIterator::new(self)
}
/// Iterates the backing storage directly; since `data` is stored row-major this
/// yields `&T` in row-major order without going through a custom iterator.
pub(crate) fn direct_row_major_reference_iter(&self) -> std::slice::Iter<'_, T> {
self.data.iter()
}
/// Mutable counterpart of `direct_row_major_reference_iter`: `&mut T` over the
/// row-major backing storage.
pub(crate) fn direct_row_major_reference_iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
self.data.iter_mut()
}
/// Creates a [`ColumnMajorReferenceMutIterator`] over every element of the matrix.
pub fn column_major_reference_mut_iter(&mut self) -> ColumnMajorReferenceMutIterator<'_, T> {
ColumnMajorReferenceMutIterator::new(self)
}
/// Creates a [`RowMajorReferenceMutIterator`] over every element of the matrix
/// (yields `&mut T`, as used by `map_mut_with_index`).
pub fn row_major_reference_mut_iter(&mut self) -> RowMajorReferenceMutIterator<'_, T> {
RowMajorReferenceMutIterator::new(self)
}
/// Consumes the matrix into a [`ColumnMajorOwnedIterator`].
/// NOTE(review): the `T: Default` bound presumably lets the iterator move values
/// out by leaving defaults behind — confirm in `iterators`.
pub fn column_major_owned_iter(self) -> ColumnMajorOwnedIterator<T>
where
T: Default,
{
ColumnMajorOwnedIterator::new(self)
}
/// Consumes the matrix into a [`RowMajorOwnedIterator`].
/// NOTE(review): the `T: Default` bound presumably lets the iterator move values
/// out by leaving defaults behind — confirm in `iterators`.
pub fn row_major_owned_iter(self) -> RowMajorOwnedIterator<T>
where
T: Default,
{
RowMajorOwnedIterator::new(self)
}
/// Creates a [`DiagonalReferenceIterator`] over this matrix's diagonal.
pub fn diagonal_reference_iter(&self) -> DiagonalReferenceIterator<'_, T> {
DiagonalReferenceIterator::new(self)
}
/// Creates a [`DiagonalReferenceMutIterator`] over this matrix's diagonal.
pub fn diagonal_reference_mut_iter(&mut self) -> DiagonalReferenceMutIterator<'_, T> {
DiagonalReferenceMutIterator::new(self)
}
/// Shrinks this matrix in place to the rows and columns accepted by `slice`.
///
/// # Panics
///
/// If the slice would leave no rows, no columns, or no elements. The row and column
/// checks run before any element is removed, so a rejected slice leaves the matrix
/// unchanged. Previously the data was filtered first, and a failing assert left
/// `data.len()` out of step with `rows * columns`.
#[track_caller]
pub fn retain_mut(&mut self, slice: Slice2D) {
    let remaining_rows = (0..self.rows())
        .filter(|&row| slice.rows.accepts(row))
        .count();
    let remaining_columns = (0..self.columns())
        .filter(|&column| slice.columns.accepts(column))
        .count();
    assert!(
        remaining_rows > 0,
        "Provided slice must leave at least 1 row in the retained matrix"
    );
    assert!(
        remaining_columns > 0,
        "Provided slice must leave at least 1 column in the retained matrix"
    );
    let columns = self.columns();
    // retain visits elements in row-major storage order, so the running index
    // decomposes into (row, column) by division / remainder.
    let mut index = 0;
    self.data.retain(|_| {
        let keep = slice.accepts(index / columns, index % columns);
        index += 1;
        keep
    });
    assert!(
        !self.data.is_empty(),
        "Provided slice must leave at least 1 row and 1 column in the retained matrix"
    );
    self.rows = remaining_rows;
    self.columns = remaining_columns;
}
/// Consumes a 1x1 matrix and returns its only element, or a
/// `ScalarConversionError` for any other size.
pub fn try_into_scalar(self) -> Result<T, ScalarConversionError> {
    match self.size() {
        (1, 1) => Ok(self.data.into_iter().next().unwrap()),
        _ => Err(ScalarConversionError {}),
    }
}
/// Splits the matrix into a grid of non-overlapping mutable views.
///
/// `row_partitions` and `column_partitions` are the row and column indexes at which
/// to split. Each list must be strictly ascending, and every index must be at most
/// the length of its axis. The views are returned in row-major order of the grid:
/// all parts of the first band of rows (left to right), then the next band, and so
/// on. A part with no rows or no columns is reported as 0x0.
///
/// # Panics
///
/// If a partition index exceeds its axis length, or the indexes are not strictly
/// ascending.
#[track_caller]
pub fn partition(
    &mut self,
    row_partitions: &[Row],
    column_partitions: &[Column],
) -> Vec<MatrixView<T, MatrixPart<'_, T>>> {
    let rows = self.rows();
    let columns = self.columns();
    // Validates one axis. Each index is compared with the index immediately before
    // it (not only the first one), so every out-of-order pair is rejected here
    // rather than underflowing in the slicing below.
    fn check_axis(partitions: &[usize], length: usize) {
        let mut previous: Option<usize> = None;
        for &index in partitions {
            assert!(index <= length);
            if let Some(previous) = previous {
                assert!(index > previous, "{:?} must be ascending", partitions);
            }
            previous = Some(index);
        }
    }
    check_axis(row_partitions, rows);
    check_axis(column_partitions, columns);
    let row_slices = row_partitions.len() + 1;
    let column_slices = column_partitions.len() + 1;
    let total_slices = row_slices * column_slices;
    let mut slices: Vec<Vec<&mut [T]>> = Vec::with_capacity(total_slices);
    // The not-yet-assigned tail of the row-major data; each split hands its front
    // to one part and keeps the rest.
    let mut data: &mut [T] = self.data.as_mut_slice();
    let mut row_start = 0;
    for r in 0..row_slices {
        let row_end = row_partitions.get(r).cloned().unwrap_or(rows);
        let rows_included = row_end - row_start;
        for _ in 0..column_slices {
            slices.push(Vec::with_capacity(rows_included));
        }
        row_start = row_end;
        for _ in 0..rows_included {
            let mut column_start = 0;
            for c in 0..column_slices {
                let column_end = column_partitions.get(c).cloned().unwrap_or(columns);
                let columns_included = column_end - column_start;
                column_start = column_end;
                let (slice, rest) = data.split_at_mut(columns_included);
                slices[(r * column_slices) + c].push(slice);
                data = rest;
            }
        }
    }
    slices
        .into_iter()
        .map(|slices| {
            let rows = slices.len();
            let columns = slices.first().map(|columns| columns.len()).unwrap_or(0);
            if columns == 0 {
                MatrixView::from(MatrixPart::new(slices, 0, 0))
            } else {
                MatrixView::from(MatrixPart::new(slices, rows, columns))
            }
        })
        .collect()
}
/// Splits the matrix into four mutable quadrants around (`row`, `column`): rows
/// before `row` form the top half and columns before `column` the left half.
///
/// # Panics
///
/// If `row` exceeds the number of rows or `column` the number of columns (see
/// [`Matrix::partition`]).
#[track_caller]
#[allow(clippy::needless_lifetimes)] // explicit 'a kept to spell out that the quadrants borrow self
pub fn partition_quadrants<'a>(
    &'a mut self,
    row: Row,
    column: Column,
) -> MatrixQuadrants<'a, T> {
    // A single split on each axis gives exactly four parts, in row-major grid order.
    let mut parts = self.partition(&[row], &[column]).into_iter();
    let top_left = parts.next().unwrap();
    let top_right = parts.next().unwrap();
    let bottom_left = parts.next().unwrap();
    let bottom_right = parts.next().unwrap();
    MatrixQuadrants {
        top_left,
        top_right,
        bottom_left,
        bottom_right,
    }
}
/// Returns a view over this matrix restricted by `rows` and `columns`, built with
/// `MatrixRange::from` (see `views::MatrixRange` for the exact range semantics).
pub fn range<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, &Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixRange::from(self, rows, columns))
}
/// Mutable counterpart of [`Matrix::range`]; the view borrows the matrix mutably.
pub fn range_mut<R>(
&mut self,
rows: R,
columns: R,
) -> MatrixView<T, MatrixRange<T, &mut Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixRange::from(self, rows, columns))
}
/// Owning counterpart of [`Matrix::range`]; the view takes the matrix by value.
pub fn range_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixRange::from(self, rows, columns))
}
/// Returns a view over this matrix built with `MatrixMask::from`.
/// NOTE(review): by name the mask presumably hides the given row/column ranges
/// (the complement of `range`) — confirm in `views::MatrixMask`.
pub fn mask<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, &Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixMask::from(self, rows, columns))
}
/// Mutable counterpart of [`Matrix::mask`]; the view borrows the matrix mutably.
pub fn mask_mut<R>(
&mut self,
rows: R,
columns: R,
) -> MatrixView<T, MatrixMask<T, &mut Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixMask::from(self, rows, columns))
}
/// Owning counterpart of [`Matrix::mask`]; the view takes the matrix by value.
pub fn mask_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, Matrix<T>>>
where
R: Into<IndexRange>,
{
MatrixView::from(MatrixMask::from(self, rows, columns))
}
/// Returns a view built with `MatrixMask::start_and_end_of_rows` for `retain` rows.
/// NOTE(review): by name this presumably keeps only the first and last `retain`
/// rows — confirm in `views::MatrixMask`.
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_rows(&self, retain: usize) -> MatrixView<T, MatrixMask<T, &Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of rows to retain at start and end of matrix must be at least 1, 0 rows retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_rows(self, Some(retain)))
}
/// Mutable counterpart of [`Matrix::start_and_end_of_rows`].
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_rows_mut(
    &mut self,
    retain: usize,
) -> MatrixView<T, MatrixMask<T, &mut Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of rows to retain at start and end of matrix must be at least 1, 0 rows retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_rows(self, Some(retain)))
}
/// Owning counterpart of [`Matrix::start_and_end_of_rows`].
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_rows_owned(
    self,
    retain: usize,
) -> MatrixView<T, MatrixMask<T, Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of rows to retain at start and end of matrix must be at least 1, 0 rows retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_rows(self, Some(retain)))
}
/// Returns a view built with `MatrixMask::start_and_end_of_columns` for `retain`
/// columns.
/// NOTE(review): by name this presumably keeps only the first and last `retain`
/// columns — confirm in `views::MatrixMask`.
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_columns(
    &self,
    retain: usize,
) -> MatrixView<T, MatrixMask<T, &Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of columns to retain at start and end of matrix must be at least 1, 0 columns retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_columns(self, Some(retain)))
}
/// Mutable counterpart of [`Matrix::start_and_end_of_columns`].
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_columns_mut(
    &mut self,
    retain: usize,
) -> MatrixView<T, MatrixMask<T, &mut Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of columns to retain at start and end of matrix must be at least 1, 0 columns retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_columns(self, Some(retain)))
}
/// Owning counterpart of [`Matrix::start_and_end_of_columns`].
///
/// # Panics
///
/// If `retain` is 0.
#[track_caller]
pub fn start_and_end_of_columns_owned(
    self,
    retain: usize,
) -> MatrixView<T, MatrixMask<T, Matrix<T>>> {
    let retain = match NonZeroUsize::new(retain) {
        Some(count) => count,
        None => panic!(
            "Number of columns to retain at start and end of matrix must be at least 1, 0 columns retained would remove all elements"
        ),
    };
    MatrixView::from(MatrixMask::start_and_end_of_columns(self, Some(retain)))
}
/// Returns a view of this matrix built with `MatrixReverse::from` using the given
/// `Reverse` setting (see `views::Reverse` for which axes it flips).
pub fn reverse(&self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, &Matrix<T>>> {
MatrixView::from(MatrixReverse::from(self, reverse))
}
/// Mutable counterpart of [`Matrix::reverse`]; the view borrows the matrix mutably.
pub fn reverse_mut(
&mut self,
reverse: Reverse,
) -> MatrixView<T, MatrixReverse<T, &mut Matrix<T>>> {
MatrixView::from(MatrixReverse::from(self, reverse))
}
/// Owning counterpart of [`Matrix::reverse`]; the view takes the matrix by value.
pub fn reverse_owned(self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, Matrix<T>>> {
MatrixView::from(MatrixReverse::from(self, reverse))
}
/// Converts this matrix into a 2 dimensional `Tensor`, naming its first dimension
/// `rows` and its second `columns`.
///
/// # Errors
///
/// The `InvalidShapeError` from the `TryFrom<(Matrix<T>, [Dimension; 2])>`
/// conversion when the resulting shape is not valid.
pub fn into_tensor(
    self,
    rows: crate::tensors::Dimension,
    columns: crate::tensors::Dimension,
) -> Result<crate::tensors::Tensor<T, 2>, crate::tensors::InvalidShapeError<2>> {
    crate::tensors::Tensor::try_from((self, [rows, columns]))
}
}
impl<T: Clone> Matrix<T> {
/// Returns a new matrix that is the transpose of this one: element (i, j) of the
/// result is element (j, i) of `self`.
pub fn transpose(&self) -> Matrix<T> {
    let (rows, columns) = self.size();
    Matrix::from_fn((columns, rows), |(row, column)| self.get(column, row))
}
/// Transposes this matrix in place.
///
/// Square matrices are transposed by swapping elements across the diagonal, which
/// involves no clones. Non-square matrices change shape, so they are rebuilt from a
/// transposed copy.
pub fn transpose_mut(&mut self) {
    if self.rows() != self.columns() {
        let transposed = self.transpose();
        self.data = transposed.data;
        self.rows = transposed.rows;
        self.columns = transposed.columns;
    } else {
        let size = self.rows();
        for i in 0..size {
            // Only the strictly upper triangle: each off-diagonal pair is swapped
            // once, and the diagonal stays where it is.
            for j in (i + 1)..size {
                let upper = self.get_index(i, j);
                let lower = self.get_index(j, i);
                self.data.swap(upper, lower);
            }
        }
    }
}
/// Creates a [`ColumnIterator`] over column `column`.
/// NOTE(review): presumably yields cloned `T` values (this impl requires
/// `T: Clone`) — confirm in `iterators`.
#[track_caller]
pub fn column_iter(&self, column: Column) -> ColumnIterator<'_, T> {
ColumnIterator::new(self, column)
}
/// Creates a [`RowIterator`] over row `row`.
/// NOTE(review): presumably yields cloned `T` values — confirm in `iterators`.
#[track_caller]
pub fn row_iter(&self, row: Row) -> RowIterator<'_, T> {
RowIterator::new(self, row)
}
/// Creates a [`ColumnMajorIterator`] over every element of the matrix.
pub fn column_major_iter(&self) -> ColumnMajorIterator<'_, T> {
ColumnMajorIterator::new(self)
}
/// Creates a [`RowMajorIterator`] over every element of the matrix; it yields
/// owned `T` values (as relied on by `map_with_index`).
pub fn row_major_iter(&self) -> RowMajorIterator<'_, T> {
RowMajorIterator::new(self)
}
/// Creates a [`DiagonalIterator`] over this matrix's diagonal.
pub fn diagonal_iter(&self) -> DiagonalIterator<'_, T> {
DiagonalIterator::new(self)
}
/// Creates a `size.0` x `size.1` matrix with every element a clone of `value`.
///
/// # Panics
///
/// If either dimension is 0, or if `size.0 * size.1` overflows `usize`.
#[track_caller]
pub fn empty(value: T, size: (Row, Column)) -> Matrix<T> {
    let (rows, columns) = size;
    assert!(rows > 0 && columns > 0, "Size must be at least 1x1");
    // checked_mul: in release builds a plain `*` wraps, which would allocate fewer
    // than rows * columns elements and break the length invariant.
    let length = rows
        .checked_mul(columns)
        .expect("Size overflows the maximum number of matrix elements");
    Matrix {
        data: vec![value; length],
        rows,
        columns,
    }
}
/// Returns a clone of the element at (`row`, `column`).
///
/// # Panics
///
/// If `row` or `column` is out of range.
#[track_caller]
pub fn get(&self, row: Row, column: Column) -> T {
    let (rows, columns) = self.size();
    assert!(row < rows, "Row out of index, only have {} rows", rows);
    assert!(
        column < columns,
        "Column out of index, only have {} columns",
        columns
    );
    let index = self.get_index(row, column);
    self.data[index].clone()
}
/// Returns a clone of the only element of a 1x1 matrix.
///
/// # Panics
///
/// If the matrix is not 1x1.
#[track_caller]
pub fn scalar(&self) -> T {
    let (rows, columns) = self.size();
    assert!(
        rows == 1,
        "Cannot treat matrix as scalar as it has more than one row"
    );
    assert!(
        columns == 1,
        "Cannot treat matrix as scalar as it has more than one column"
    );
    self.get(0, 0)
}
/// Replaces every element in place with `mapping_function` applied to a clone of it.
pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
    self.data
        .iter_mut()
        .for_each(|element| *element = mapping_function(element.clone()));
}
/// Like [`Matrix::map_mut`], but the function also receives the element's row and
/// column.
pub fn map_mut_with_index(&mut self, mapping_function: impl Fn(T, Row, Column) -> T) {
    for ((row, column), element) in self.row_major_reference_mut_iter().with_index() {
        *element = mapping_function(element.clone(), row, column);
    }
}
/// Returns a new matrix of the same size, built from `mapping_function` applied to a
/// clone of each element.
pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Matrix<U>
where
    U: Clone,
{
    let mapped: Vec<U> = self.data.iter().cloned().map(mapping_function).collect();
    Matrix::from_flat_row_major(self.size(), mapped)
}
/// Like [`Matrix::map`], but the function also receives the element's row and column.
pub fn map_with_index<U>(&self, mapping_function: impl Fn(T, Row, Column) -> U) -> Matrix<U>
where
    U: Clone,
{
    let mapped: Vec<U> = self
        .row_major_iter()
        .with_index()
        .map(|((row, column), element)| mapping_function(element, row, column))
        .collect();
    Matrix::from_flat_row_major(self.size(), mapped)
}
/// Inserts a new row at index `row`, with every element a clone of `value`.
/// Rows at and after `row` shift down by one.
///
/// # Panics
///
/// If `row` is greater than the number of rows.
#[track_caller]
pub fn insert_row(&mut self, row: Row, value: T) {
    assert!(
        row <= self.rows(),
        "Row to insert must be <= to {}",
        self.rows()
    );
    let columns = self.columns();
    let start = self.get_index(row, 0);
    // In row-major storage a row is one contiguous run, so the new row can be
    // placed by detaching the tail once. Inserting element by element would shift
    // the tail `columns` times.
    let tail = self.data.split_off(start);
    self.data.extend(std::iter::repeat(value).take(columns));
    self.data.extend(tail);
    self.rows += 1;
}
/// Inserts a new row at index `row`, taking its elements left to right from
/// `values`. At most `columns` values are pulled from the iterator.
///
/// # Panics
///
/// If `row` is greater than the number of rows, or `values` yields fewer than
/// `columns` values. Both checks run before the matrix is modified.
#[track_caller]
pub fn insert_row_with<I>(&mut self, row: Row, values: I)
where
    I: Iterator<Item = T>,
{
    assert!(
        row <= self.rows(),
        "Row to insert must be <= to {}",
        self.rows()
    );
    let columns = self.columns();
    // Gather the row first. Running out of values then panics without leaving a
    // half-inserted row behind, where data.len() would no longer match rows * columns.
    let new_row: Vec<T> = values.take(columns).collect();
    assert!(
        new_row.len() == columns,
        "At least {} values must be provided",
        columns
    );
    let start = self.get_index(row, 0);
    let tail = self.data.split_off(start);
    self.data.extend(new_row);
    self.data.extend(tail);
    self.rows += 1;
}
/// Inserts a new column at index `column`, with every element a clone of `value`.
/// Columns at and after `column` shift right by one.
///
/// # Panics
///
/// If `column` is greater than the number of columns.
#[track_caller]
pub fn insert_column(&mut self, column: Column, value: T) {
    assert!(
        column <= self.columns(),
        "Column to insert must be <= to {}",
        self.columns()
    );
    let rows = self.rows();
    // Work from the last row upwards: each insert only shifts elements after it,
    // so the indexes of the rows still to visit (computed with the old column
    // count) stay valid.
    for row in (0..rows).rev() {
        let index = self.get_index(row, column);
        self.data.insert(index, value.clone());
    }
    self.columns += 1;
}
/// Inserts a new column at index `column`, taking its elements top to bottom from
/// `values`. Only the first `rows` values are pulled from the iterator, so it may
/// be longer than needed, or even infinite.
///
/// # Panics
///
/// If `column` is greater than the number of columns, or `values` yields fewer than
/// `rows` values.
#[track_caller]
pub fn insert_column_with<I>(&mut self, column: Column, values: I)
where
    I: Iterator<Item = T>,
{
    assert!(
        column <= self.columns(),
        "Column to insert must be <= to {}",
        self.columns()
    );
    let rows = self.rows();
    // take(rows): never drains an unbounded iterator. Surplus values are ignored
    // instead of shifting which values land in which row, matching insert_row_with.
    let new_column: Vec<T> = values.take(rows).collect();
    assert!(
        new_column.len() >= rows,
        "At least {} values must be provided",
        rows
    );
    // Insert bottom-up so the indexes of the rows still to visit (computed with the
    // old column count) are not shifted by earlier inserts; row i gets new_column[i].
    for (row, value) in (0..rows).rev().zip(new_column.into_iter().rev()) {
        let index = self.get_index(row, column);
        self.data.insert(index, value);
    }
    self.columns += 1;
}
pub fn retain(&self, slice: Slice2D) -> Matrix<T> {
let mut retained = self.clone();
retained.retain_mut(slice);
retained
}
}
impl<T: Clone> Clone for Matrix<T> {
fn clone(&self) -> Self {
self.map(|element| element)
}
}
impl<T: std::fmt::Display> std::fmt::Display for Matrix<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
crate::matrices::views::format_view(self, f)
}
}
impl<T> TryFrom<(Matrix<T>, [crate::tensors::Dimension; 2])> for crate::tensors::Tensor<T, 2> {
type Error = crate::tensors::InvalidShapeError<2>;
fn try_from(value: (Matrix<T>, [crate::tensors::Dimension; 2])) -> Result<Self, Self::Error> {
let (matrix, [row_name, column_name]) = value;
let shape = [(row_name, matrix.rows), (column_name, matrix.columns)];
let check = crate::tensors::InvalidShapeError::new(shape);
if !check.is_valid() {
return Err(check);
}
Ok(crate::tensors::Tensor::from(shape, matrix.data))
}
}
/// Linear algebra conveniences; each method delegates to the matching function in
/// `crate::linear_algebra`.
impl<T: Numeric> Matrix<T>
where
for<'a> &'a T: NumericRef<T>,
{
/// Computes the determinant via `linear_algebra::determinant`.
/// NOTE(review): `None` is presumably returned for non-square input — confirm.
pub fn determinant(&self) -> Option<T> {
linear_algebra::determinant::<T>(self)
}
/// Computes the inverse via `linear_algebra::inverse`.
/// NOTE(review): `None` is presumably returned when no inverse exists (singular or
/// non-square) — confirm.
pub fn inverse(&self) -> Option<Matrix<T>> {
linear_algebra::inverse::<T>(self)
}
/// Computes a covariance matrix via `linear_algebra::covariance_column_features`.
/// NOTE(review): by name, columns are the features and rows the samples — confirm.
pub fn covariance_column_features(&self) -> Matrix<T> {
linear_algebra::covariance_column_features::<T>(self)
}
/// Computes a covariance matrix via `linear_algebra::covariance_row_features`.
/// NOTE(review): by name, rows are the features and columns the samples — confirm.
pub fn covariance_row_features(&self) -> Matrix<T> {
linear_algebra::covariance_row_features::<T>(self)
}
}
impl<T: Real> Matrix<T>
where
    for<'a> &'a T: RealRef<T>,
{
    /// Computes the euclidean length of a row or column vector: the square root of
    /// the vector's dot product with itself.
    ///
    /// # Panics
    ///
    /// If the matrix is neither a single column nor a single row.
    #[track_caller]
    pub fn euclidean_length(&self) -> T {
        match (self.rows(), self.columns()) {
            // column vector: (1 x n) * (n x 1) yields the 1x1 sum of squares
            (_, 1) => (self.transpose() * self).scalar().sqrt(),
            // row vector: (1 x n) * (n x 1) with the transpose on the right
            (1, _) => (self * self.transpose()).scalar().sqrt(),
            (rows, columns) => panic!(
                "Cannot compute unit vector of a non vector, rows: {}, columns: {}",
                rows, columns
            ),
        }
    }
}
impl<T: Numeric> Matrix<T> {
    /// Creates a square matrix with `value` on the diagonal and zero elsewhere.
    ///
    /// # Panics
    ///
    /// If `size` is not square, or has a 0 dimension (via `Matrix::empty`).
    #[track_caller]
    pub fn diagonal(value: T, size: (Row, Column)) -> Matrix<T> {
        assert!(size.0 == size.1);
        let mut result = Matrix::empty(T::zero(), size);
        (0..size.0).for_each(|index| result.set(index, index, value.clone()));
        result
    }

    /// Creates a square matrix whose diagonal is `values` (in order) and whose other
    /// elements are zero.
    ///
    /// # Panics
    ///
    /// If `values` is empty (via `Matrix::empty`).
    pub fn from_diagonal(values: Vec<T>) -> Matrix<T> {
        let length = values.len();
        let mut result = Matrix::empty(T::zero(), (length, length));
        values
            .into_iter()
            .enumerate()
            .for_each(|(index, element)| result.set(index, index, element));
        result
    }
}
impl<T: PartialEq> PartialEq for Matrix<T> {
    /// Matrices are equal when they have the same size and equal elements at every
    /// (row, column) position.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.size() == other.size() && self.data == other.data
    }
}
// Compile-time checks: these tests build only if Matrix<f64> is Sync / Send.
#[test]
fn test_sync() {
    fn requires_sync<M: Sync>() {}
    requires_sync::<Matrix<f64>>();
}
#[test]
fn test_send() {
    fn requires_send<M: Send>() {}
    requires_send::<Matrix<f64>>();
}
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::matrices::{Column, Matrix, Row};
    use serde::{Deserialize, Deserializer};

    /// Mirror of `Matrix`'s serialized form. It is deserialized first so the fields
    /// can be validated before a `Matrix` is constructed.
    #[derive(Deserialize)]
    #[serde(rename = "Matrix")]
    struct MatrixDeserialize<T> {
        data: Vec<T>,
        rows: Row,
        columns: Column,
    }

    impl<'de, T> Deserialize<'de> for Matrix<T>
    where
        T: Deserialize<'de>,
    {
        /// Deserializes a matrix, rejecting input that is empty or whose data length
        /// is not exactly `rows * columns` (including when that product overflows).
        ///
        /// Invalid input is reported through `D::Error` rather than a panic, so
        /// malformed external data cannot crash the caller.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let MatrixDeserialize {
                data,
                rows,
                columns,
            } = MatrixDeserialize::<T>::deserialize(deserializer)?;
            if rows.checked_mul(columns) != Some(data.len()) {
                return Err(serde::de::Error::custom(format!(
                    "Inconsistent size, attempted to construct a {}x{} matrix but provided with {} elements.",
                    rows,
                    columns,
                    data.len()
                )));
            }
            if data.is_empty() {
                return Err(serde::de::Error::custom("No values provided"));
            }
            // The checks above mirror from_flat_row_major's assertions, so it
            // cannot panic here.
            Ok(Matrix::from_flat_row_major((rows, columns), data))
        }
    }
}
#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    fn assert_serialize<T: Serialize>() {}
    assert_serialize::<Matrix<f64>>();
}
#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
    assert_deserialize::<Matrix<f64>>();
}
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop() {
    #[rustfmt::skip]
    let matrix = Matrix::from(vec![
        vec![1, 2, 3, 4],
        vec![5, 6, 7, 8],
        vec![9, 10, 11, 12],
    ]);
    let encoded = toml::to_string(&matrix).unwrap();
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 4
"#,
    );
    let parsed: Result<Matrix<i32>, _> = toml::from_str(&encoded);
    assert!(parsed.is_ok());
    assert_eq!(matrix, parsed.unwrap())
}
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    // Inconsistent sizes are reported as a deserialization error, not a panic.
    let result: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 3
"#,
    );
    assert!(result.is_err());
}
#[cfg(all(feature = "serde", target_pointer_width = "64"))]
#[test]
fn test_deserialization_overflow_validation() {
    // 7 * 7905747460161236407 wraps to exactly 1 on 64-bit targets, so an unchecked
    // multiply would accept a single element for a huge matrix.
    let result: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = [1]
rows = 7
columns = 7905747460161236407
"#,
    );
    assert!(result.is_err());
}
#[test]
fn test_indexing() {
    let square = Matrix::from(vec![vec![1, 2], vec![3, 4]]);
    assert_eq!(square.get_index(0, 1), 1);
    assert_eq!(square.get_row_column(1), (0, 1));
    assert_eq!(square.get(0, 1), 2);
    let wide = Matrix::from(vec![vec![1, 2, 3], vec![5, 6, 7]]);
    assert_eq!(wide.get_index(1, 2), 5);
    assert_eq!(wide.get_row_column(5), (1, 2));
    assert_eq!(wide.get(1, 2), 7);
    let labelled = Matrix::from(vec![vec![0, 0], vec![0, 0], vec![0, 0]])
        .map_with_index(|_, r, c| format!("{:?}x{:?}", r, c));
    let expected = Matrix::from(vec![
        vec!["0x0", "0x1"],
        vec!["1x0", "1x1"],
        vec!["2x0", "2x1"],
    ])
    .map(|x| x.to_owned());
    assert_eq!(labelled, expected);
}