use std::marker::PhantomData;
use std::num::NonZeroUsize;
use crate::matrices::iterators::*;
use crate::matrices::{Column, Matrix, Row};
pub mod erased;
mod map;
mod partitions;
mod ranges;
mod reverse;
pub mod traits;
pub(crate) use map::*;
pub use partitions::*;
pub use ranges::*;
pub use reverse::*;
/// Shared (read-only) access to a two dimensional collection of `T` values addressed by
/// a (`Row`, `Column`) pair.
///
/// `MatrixView` and the free functions in this file (`format_view`, `matrix_equality`) are
/// written purely in terms of this trait.
///
/// NOTE(review): this is an `unsafe trait`, but the invariants an implementation must uphold
/// are not written down in this file. Presumably `view_rows`/`view_columns` must accurately
/// describe which indexes `try_get_reference` accepts - confirm against the implementations.
pub unsafe trait MatrixRef<T>: NoInteriorMutability {
/// Returns a reference to the value at (`row`, `column`), or `None` if there is no such value.
///
/// Callers in this file treat `None` as "index not in range".
fn try_get_reference(&self, row: Row, column: Column) -> Option<&T>;
/// The number of rows this source exposes.
fn view_rows(&self) -> Row;
/// The number of columns this source exposes.
fn view_columns(&self) -> Column;
/// Returns a reference to the value at (`row`, `column`) without the `Option` wrapper of
/// `try_get_reference`.
///
/// # Safety
///
/// NOTE(review): the exact contract is not stated in this file. The name suggests no bounds
/// checking is performed, so callers presumably must pass an index that is in range for
/// `view_rows()` and `view_columns()` - confirm before relying on this.
#[allow(clippy::missing_safety_doc)] unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T;
/// Reports how this source describes its data layout.
///
/// Code in this file uses the answer only to choose between column major and row major
/// iteration order.
fn data_layout(&self) -> DataLayout;
}
/// Mutable access to a two dimensional collection of `T` values, extending `MatrixRef`.
///
/// NOTE(review): as with `MatrixRef`, the invariants that make implementing this `unsafe trait`
/// sound are not stated in this file - confirm against the implementations.
pub unsafe trait MatrixMut<T>: MatrixRef<T> {
/// Returns a mutable reference to the value at (`row`, `column`), or `None` if there is no
/// such value.
///
/// Callers in this file treat `None` as "index not in range".
fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T>;
/// Returns a mutable reference to the value at (`row`, `column`) without the `Option` wrapper
/// of `try_get_reference_mut`.
///
/// # Safety
///
/// NOTE(review): the exact contract is not stated in this file. The name suggests no bounds
/// checking is performed, so callers presumably must pass an index that is in range for
/// `view_rows()` and `view_columns()` - confirm before relying on this.
#[allow(clippy::missing_safety_doc)] unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T;
}
/// Marker trait with no methods that every `MatrixRef` implementation must also implement.
///
/// NOTE(review): going by the name, implementing this presumably promises that the data reachable
/// through the type cannot change behind a shared reference. The precise promise is not stated
/// in this file - confirm before implementing it for a new type.
pub unsafe trait NoInteriorMutability {}
/// The data layout a `MatrixRef` reports through `MatrixRef::data_layout`.
///
/// Within this file only `ColumnMajor` is treated specially: `map_mut_with_index` and
/// `matrix_equality` iterate in column major order for it and in row major order for
/// every other variant.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum DataLayout {
/// Presumably: elements of the same row are adjacent in memory - TODO confirm.
RowMajor,
/// Presumably: elements of the same column are adjacent in memory - TODO confirm.
ColumnMajor,
/// Any layout that is neither of the above; handled like `RowMajor` by the code in this file.
Other,
}
/// A wrapper around a source `S` that provides the convenience API (indexing, iterators,
/// ranges, masks, reversal, mapping, formatting and equality) defined in the `impl` blocks of
/// this file, for any `S` implementing `MatrixRef<T>` (and `MatrixMut<T>` for the mutable parts).
#[derive(Clone, Debug)]
pub struct MatrixView<T, S> {
// The wrapped source; every method of the view delegates to it.
source: S,
// Records the element type `T`, which does not otherwise appear in the fields.
_type: PhantomData<T>,
}
/// Construction, read access, reference iterators and view composition, available for every
/// `MatrixView` whose source implements `MatrixRef<T>`.
impl<T, S> MatrixView<T, S>
where
    S: MatrixRef<T>,
{
    /// Wraps `source` in a `MatrixView`, taking ownership of it.
    pub fn from(source: S) -> MatrixView<T, S> {
        MatrixView {
            _type: PhantomData,
            source,
        }
    }

    /// Consumes the view and hands back the source it was wrapping.
    pub fn source(self) -> S {
        let MatrixView { source, .. } = self;
        source
    }

    /// Borrows the wrapped source.
    pub fn source_ref(&self) -> &S {
        let MatrixView { source, .. } = self;
        source
    }

    /// Mutably borrows the wrapped source.
    pub fn source_ref_mut(&mut self) -> &mut S {
        let MatrixView { source, .. } = self;
        source
    }

    /// Returns the `(rows, columns)` pair reported by the source.
    pub fn size(&self) -> (Row, Column) {
        let rows = self.source.view_rows();
        let columns = self.source.view_columns();
        (rows, columns)
    }

    /// Returns the number of rows reported by the source.
    pub fn rows(&self) -> Row {
        MatrixRef::view_rows(&self.source)
    }

    /// Returns the number of columns reported by the source.
    pub fn columns(&self) -> Column {
        MatrixRef::view_columns(&self.source)
    }

    /// Returns the data layout reported by the source.
    pub fn data_layout(&self) -> DataLayout {
        MatrixRef::data_layout(&self.source)
    }

    /// Returns a reference to the value at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if the source has no value at that index (its `try_get_reference` returns `None`).
    #[track_caller]
    pub fn get_reference(&self, row: Row, column: Column) -> &T {
        // No closure is used for the failure path so that `#[track_caller]` keeps pointing the
        // panic at whoever called this method.
        if let Some(reference) = self.source.try_get_reference(row, column) {
            return reference;
        }
        panic!(
            "Index ({},{}) not in range, MatrixView range is (0,0) to ({},{}).",
            row,
            column,
            self.rows(),
            self.columns()
        )
    }

    /// Returns a reference to the value at (`row`, `column`), or `None` if the source has no
    /// value at that index.
    pub fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
        let source = self.source_ref();
        source.try_get_reference(row, column)
    }

    /// Returns a reference to the value at (`row`, `column`) by forwarding straight to the
    /// source's `MatrixRef::get_reference_unchecked`.
    ///
    /// # Safety
    ///
    /// The caller must uphold whatever the source's `get_reference_unchecked` requires.
    /// NOTE(review): presumably that means the index must be in range for this view - confirm
    /// against the `MatrixRef` contract.
    pub unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
        let source = self.source_ref();
        unsafe { source.get_reference_unchecked(row, column) }
    }

    /// Creates a `ColumnReferenceIterator` over the given `column` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a column that is out of range - confirm against `ColumnReferenceIterator::from`.
    #[track_caller]
    pub fn column_reference_iter(&self, column: Column) -> ColumnReferenceIterator<'_, T, S> {
        ColumnReferenceIterator::from(self.source_ref(), column)
    }

    /// Creates a `RowReferenceIterator` over the given `row` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a row that is out of range - confirm against `RowReferenceIterator::from`.
    #[track_caller]
    pub fn row_reference_iter(&self, row: Row) -> RowReferenceIterator<'_, T, S> {
        RowReferenceIterator::from(self.source_ref(), row)
    }

    /// Creates a `ColumnMajorReferenceIterator` over the whole view.
    pub fn column_major_reference_iter(&self) -> ColumnMajorReferenceIterator<'_, T, S> {
        ColumnMajorReferenceIterator::from(self.source_ref())
    }

    /// Creates a `RowMajorReferenceIterator` over the whole view.
    pub fn row_major_reference_iter(&self) -> RowMajorReferenceIterator<'_, T, S> {
        RowMajorReferenceIterator::from(self.source_ref())
    }

    /// Creates a `DiagonalReferenceIterator` over the whole view.
    pub fn diagonal_reference_iter(&self) -> DiagonalReferenceIterator<'_, T, S> {
        DiagonalReferenceIterator::from(self.source_ref())
    }

    /// Builds a new view over a `MatrixRange` that borrows this view's source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixRange::from`, which defines how they
    /// are interpreted.
    pub fn range<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, &S>>
    where
        R: Into<IndexRange>,
    {
        let ranged = MatrixRange::from(self.source_ref(), rows, columns);
        MatrixView::from(ranged)
    }

    /// Builds a new view over a `MatrixRange` that mutably borrows this view's source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixRange::from`, which defines how they
    /// are interpreted.
    pub fn range_mut<R>(&mut self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, &mut S>>
    where
        R: Into<IndexRange>,
    {
        let ranged = MatrixRange::from(self.source_ref_mut(), rows, columns);
        MatrixView::from(ranged)
    }

    /// Consumes this view and builds a new view over a `MatrixRange` that owns the source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixRange::from`, which defines how they
    /// are interpreted.
    pub fn range_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, S>>
    where
        R: Into<IndexRange>,
    {
        let ranged = MatrixRange::from(self.source(), rows, columns);
        MatrixView::from(ranged)
    }

    /// Builds a new view over a `MatrixMask` that borrows this view's source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixMask::from`, which defines how they
    /// are interpreted.
    pub fn mask<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, &S>>
    where
        R: Into<IndexRange>,
    {
        let masked = MatrixMask::from(self.source_ref(), rows, columns);
        MatrixView::from(masked)
    }

    /// Builds a new view over a `MatrixMask` that mutably borrows this view's source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixMask::from`, which defines how they
    /// are interpreted.
    pub fn mask_mut<R>(&mut self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, &mut S>>
    where
        R: Into<IndexRange>,
    {
        let masked = MatrixMask::from(self.source_ref_mut(), rows, columns);
        MatrixView::from(masked)
    }

    /// Consumes this view and builds a new view over a `MatrixMask` that owns the source.
    ///
    /// `rows` and `columns` are passed unchanged to `MatrixMask::from`, which defines how they
    /// are interpreted.
    pub fn mask_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, S>>
    where
        R: Into<IndexRange>,
    {
        let masked = MatrixMask::from(self.source(), rows, columns);
        MatrixView::from(masked)
    }

    /// Builds a view over `MatrixMask::start_and_end_of_rows`, borrowing this view's source and
    /// retaining `retain` rows at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_rows(&self, retain: usize) -> MatrixView<T, MatrixMask<T, &S>> {
        let retain = non_zero_rows_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_rows(
            &self.source,
            Some(retain),
        ))
    }

    /// Builds a view over `MatrixMask::start_and_end_of_rows`, mutably borrowing this view's
    /// source and retaining `retain` rows at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_rows_mut(
        &mut self,
        retain: usize,
    ) -> MatrixView<T, MatrixMask<T, &mut S>> {
        let retain = non_zero_rows_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_rows(
            &mut self.source,
            Some(retain),
        ))
    }

    /// Consumes this view and builds a view over `MatrixMask::start_and_end_of_rows` that owns
    /// the source, retaining `retain` rows at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_rows_owned(self, retain: usize) -> MatrixView<T, MatrixMask<T, S>> {
        let retain = non_zero_rows_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_rows(self.source, Some(retain)))
    }

    /// Builds a view over `MatrixMask::start_and_end_of_columns`, borrowing this view's source
    /// and retaining `retain` columns at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_columns(&self, retain: usize) -> MatrixView<T, MatrixMask<T, &S>> {
        let retain = non_zero_columns_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_columns(
            &self.source,
            Some(retain),
        ))
    }

    /// Builds a view over `MatrixMask::start_and_end_of_columns`, mutably borrowing this view's
    /// source and retaining `retain` columns at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_columns_mut(
        &mut self,
        retain: usize,
    ) -> MatrixView<T, MatrixMask<T, &mut S>> {
        let retain = non_zero_columns_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_columns(
            &mut self.source,
            Some(retain),
        ))
    }

    /// Consumes this view and builds a view over `MatrixMask::start_and_end_of_columns` that
    /// owns the source, retaining `retain` columns at the start and end.
    ///
    /// # Panics
    ///
    /// Panics if `retain` is 0.
    #[track_caller]
    pub fn start_and_end_of_columns_owned(self, retain: usize) -> MatrixView<T, MatrixMask<T, S>> {
        let retain = non_zero_columns_to_retain(retain);
        MatrixView::from(MatrixMask::start_and_end_of_columns(
            self.source,
            Some(retain),
        ))
    }

    /// Builds a new view over a `MatrixReverse` that borrows this view's source, configured by
    /// `reverse`.
    pub fn reverse(&self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, &S>> {
        let reversed = MatrixReverse::from(self.source_ref(), reverse);
        MatrixView::from(reversed)
    }

    /// Builds a new view over a `MatrixReverse` that mutably borrows this view's source,
    /// configured by `reverse`.
    pub fn reverse_mut(&mut self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, &mut S>> {
        let reversed = MatrixReverse::from(self.source_ref_mut(), reverse);
        MatrixView::from(reversed)
    }

    /// Consumes this view and builds a new view over a `MatrixReverse` that owns the source,
    /// configured by `reverse`.
    pub fn reverse_owned(self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, S>> {
        let reversed = MatrixReverse::from(self.source(), reverse);
        MatrixView::from(reversed)
    }
}

/// Converts the number of rows to retain into a `NonZeroUsize`, panicking for 0.
///
/// `#[track_caller]` so that, together with the attribute on the calling method, the panic is
/// still reported at the call site of the public `start_and_end_of_rows*` method.
#[track_caller]
fn non_zero_rows_to_retain(retain: usize) -> NonZeroUsize {
    match NonZeroUsize::new(retain) {
        Some(retain) => retain,
        None => panic!(
            "Number of rows to retain at start and end of matrix view must be at least 1, 0 rows retained would remove all elements"
        ),
    }
}

/// Converts the number of columns to retain into a `NonZeroUsize`, panicking for 0.
///
/// `#[track_caller]` so that, together with the attribute on the calling method, the panic is
/// still reported at the call site of the public `start_and_end_of_columns*` method.
#[track_caller]
fn non_zero_columns_to_retain(retain: usize) -> NonZeroUsize {
    match NonZeroUsize::new(retain) {
        Some(retain) => retain,
        None => panic!(
            "Number of columns to retain at start and end of matrix view must be at least 1, 0 columns retained would remove all elements"
        ),
    }
}
/// Methods that hand out clones of the elements, available when `T: Clone`.
impl<T, S> MatrixView<T, S>
where
    T: Clone,
    S: MatrixRef<T>,
{
    /// Returns a clone of the value at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if the source has no value at that index (its `try_get_reference` returns `None`).
    #[track_caller]
    pub fn get(&self, row: Row, column: Column) -> T {
        // No closure is used for the failure path so that `#[track_caller]` keeps pointing the
        // panic at whoever called this method.
        if let Some(reference) = self.source.try_get_reference(row, column) {
            return reference.clone();
        }
        panic!(
            "Index ({},{}) not in range, MatrixView range is (0,0) to ({},{}).",
            row,
            column,
            self.rows(),
            self.columns()
        )
    }

    /// Creates a new `Matrix` holding the transpose of this view: its size is
    /// `(self.columns(), self.rows())` and its element at `(i, j)` is a clone of this view's
    /// element at `(j, i)`.
    pub fn transpose(&self) -> Matrix<T> {
        let transposed_size = (self.columns(), self.rows());
        // The index handed to the closure is in terms of the new matrix, so it is swapped back
        // before reading from this view.
        Matrix::from_fn(transposed_size, |(new_row, new_column)| {
            self.get(new_column, new_row)
        })
    }

    /// Creates a `ColumnIterator` over the given `column` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a column that is out of range - confirm against `ColumnIterator::from`.
    #[track_caller]
    pub fn column_iter(&self, column: Column) -> ColumnIterator<'_, T, S> {
        ColumnIterator::from(self.source_ref(), column)
    }

    /// Creates a `RowIterator` over the given `row` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a row that is out of range - confirm against `RowIterator::from`.
    #[track_caller]
    pub fn row_iter(&self, row: Row) -> RowIterator<'_, T, S> {
        RowIterator::from(self.source_ref(), row)
    }

    /// Creates a `ColumnMajorIterator` over the whole view.
    pub fn column_major_iter(&self) -> ColumnMajorIterator<'_, T, S> {
        ColumnMajorIterator::from(self.source_ref())
    }

    /// Creates a `RowMajorIterator` over the whole view.
    pub fn row_major_iter(&self) -> RowMajorIterator<'_, T, S> {
        RowMajorIterator::from(self.source_ref())
    }

    /// Creates a `DiagonalIterator` over the whole view.
    pub fn diagonal_iter(&self) -> DiagonalIterator<'_, T, S> {
        DiagonalIterator::from(self.source_ref())
    }

    /// Creates a new `Matrix` of the same size as this view by applying `mapping_function` to
    /// each element, visiting them through `row_major_iter`.
    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Matrix<U>
    where
        U: Clone,
    {
        let elements = self.row_major_iter();
        let mapped = elements.map(mapping_function).collect();
        Matrix::from_flat_row_major(self.size(), mapped)
    }

    /// Creates a new `Matrix` of the same size as this view by applying `mapping_function` to
    /// each element together with its row and column, visiting them through `row_major_iter`.
    pub fn map_with_index<U>(&self, mapping_function: impl Fn(T, Row, Column) -> U) -> Matrix<U>
    where
        U: Clone,
    {
        let indexed = self.row_major_iter().with_index();
        let mapped = indexed
            .map(|((row, column), value)| mapping_function(value, row, column))
            .collect();
        Matrix::from_flat_row_major(self.size(), mapped)
    }
}
/// Mutable element access, available when the source implements `MatrixMut<T>`.
impl<T, S> MatrixView<T, S>
where
    S: MatrixMut<T>,
{
    /// Returns a mutable reference to the value at (`row`, `column`).
    ///
    /// # Panics
    ///
    /// Panics if the source has no value at that index (its `try_get_reference_mut` returns
    /// `None`).
    #[track_caller]
    pub fn get_reference_mut(&mut self, row: Row, column: Column) -> &mut T {
        // The size is read up front: the mutable borrow below is returned from the function, so
        // `self` cannot be inspected again on the failure path.
        let (rows, columns) = self.size();
        if let Some(reference) = self.source.try_get_reference_mut(row, column) {
            return reference;
        }
        panic!(
            "Index ({},{}) not in range, MatrixView range is (0,0) to ({},{}).",
            row, column, rows, columns
        )
    }

    /// Overwrites the value at (`row`, `column`) with `value`.
    ///
    /// # Panics
    ///
    /// Panics if the source has no value at that index (its `try_get_reference_mut` returns
    /// `None`).
    #[track_caller]
    pub fn set(&mut self, row: Row, column: Column, value: T) {
        if let Some(slot) = self.source.try_get_reference_mut(row, column) {
            *slot = value;
            return;
        }
        panic!(
            "Index ({},{}) not in range, MatrixView range is (0,0) to ({},{}).",
            row,
            column,
            self.rows(),
            self.columns()
        )
    }

    /// Returns a mutable reference to the value at (`row`, `column`), or `None` if the source
    /// has no value at that index.
    pub fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
        let source = self.source_ref_mut();
        source.try_get_reference_mut(row, column)
    }

    /// Returns a mutable reference to the value at (`row`, `column`) by forwarding straight to
    /// the source's `MatrixMut::get_reference_unchecked_mut`.
    ///
    /// # Safety
    ///
    /// The caller must uphold whatever the source's `get_reference_unchecked_mut` requires.
    /// NOTE(review): presumably that means the index must be in range for this view - confirm
    /// against the `MatrixMut` contract.
    pub unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
        let source = self.source_ref_mut();
        unsafe { source.get_reference_unchecked_mut(row, column) }
    }
}
/// Iterators over mutable references to the elements, available when the source implements
/// `MatrixMut<T>`.
impl<T, S> MatrixView<T, S>
where
    S: MatrixMut<T> + NoInteriorMutability,
{
    /// Creates a `ColumnReferenceMutIterator` over the given `column` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a column that is out of range - confirm against `ColumnReferenceMutIterator::from`.
    #[track_caller]
    pub fn column_reference_mut_iter(
        &mut self,
        column: Column,
    ) -> ColumnReferenceMutIterator<'_, T, S> {
        ColumnReferenceMutIterator::from(self.source_ref_mut(), column)
    }

    /// Creates a `RowReferenceMutIterator` over the given `row` of this view.
    ///
    /// NOTE(review): marked `#[track_caller]`, so the iterator constructor presumably panics for
    /// a row that is out of range - confirm against `RowReferenceMutIterator::from`.
    #[track_caller]
    pub fn row_reference_mut_iter(&mut self, row: Row) -> RowReferenceMutIterator<'_, T, S> {
        RowReferenceMutIterator::from(self.source_ref_mut(), row)
    }

    /// Creates a `ColumnMajorReferenceMutIterator` over the whole view.
    pub fn column_major_reference_mut_iter(&mut self) -> ColumnMajorReferenceMutIterator<'_, T, S> {
        ColumnMajorReferenceMutIterator::from(self.source_ref_mut())
    }

    /// Creates a `RowMajorReferenceMutIterator` over the whole view.
    pub fn row_major_reference_mut_iter(&mut self) -> RowMajorReferenceMutIterator<'_, T, S> {
        RowMajorReferenceMutIterator::from(self.source_ref_mut())
    }

    /// Creates a `DiagonalReferenceMutIterator` over the whole view.
    pub fn diagonal_reference_mut_iter(&mut self) -> DiagonalReferenceMutIterator<'_, T, S> {
        DiagonalReferenceMutIterator::from(self.source_ref_mut())
    }
}
/// In-place mapping of the elements, available when `T: Clone` and the source implements
/// `MatrixMut<T>`.
impl<T, S> MatrixView<T, S>
where
    T: Clone,
    S: MatrixMut<T>,
{
    /// Replaces every element with the result of calling `mapping_function` on a clone of it.
    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
        self.map_mut_with_index(|value, _row, _column| mapping_function(value))
    }

    /// Replaces every element with the result of calling `mapping_function` on a clone of it
    /// together with its row and column.
    ///
    /// Elements are visited in column major order when the source reports
    /// `DataLayout::ColumnMajor`, and in row major order for any other layout.
    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn(T, Row, Column) -> T) {
        if matches!(self.data_layout(), DataLayout::ColumnMajor) {
            let elements = self.column_major_reference_mut_iter().with_index();
            for ((row, column), element) in elements {
                *element = mapping_function(element.clone(), row, column);
            }
        } else {
            let elements = self.row_major_reference_mut_iter().with_index();
            for ((row, column), element) in elements {
                *element = mapping_function(element.clone(), row, column);
            }
        }
    }
}
/// Writes every element of `view` to `f`, one row per line, enclosed in `[ ` and ` ]`.
///
/// Elements within a row are separated by `", "`, and every row after the first is prefixed
/// with a single space. If the formatter carries a precision it is applied to each element,
/// otherwise elements are written with their plain `Display` output.
///
/// # Panics
///
/// Panics if `view.try_get_reference` returns `None` for an index inside the size the view
/// itself reported.
pub(crate) fn format_view<T, S>(view: &S, f: &mut std::fmt::Formatter) -> std::fmt::Result
where
    T: std::fmt::Display,
    S: MatrixRef<T>,
{
    let (rows, columns) = (view.view_rows(), view.view_columns());
    write!(f, "[ ")?;
    for row in 0..rows {
        if row != 0 {
            write!(f, " ")?;
        }
        for column in 0..columns {
            let value = view.try_get_reference(row, column).unwrap_or_else(|| {
                panic!(
                    "Expected ({},{}) to be in range of (0,0) to ({},{})",
                    row, column, rows, columns
                )
            });
            if let Some(precision) = f.precision() {
                write!(f, "{:.*}", precision, value)?;
            } else {
                write!(f, "{}", value)?;
            }
            // Separator only between elements, never after the last one in the row.
            if column + 1 < columns {
                write!(f, ", ")?;
            }
        }
        // Line break only between rows, never after the last one.
        if row + 1 < rows {
            writeln!(f)?;
        }
    }
    write!(f, " ]")
}
/// Any `MatrixView` of `Display` elements is formatted by `format_view`, so a precision given to
/// the formatter is applied to each element.
impl<T, S> std::fmt::Display for MatrixView<T, S>
where
    T: std::fmt::Display,
    S: MatrixRef<T>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let source = self.source_ref();
        format_view(source, f)
    }
}
/// Checks the `Display` output of a 2x2 view, both with an explicit precision and without one.
#[test]
fn printing_matrices() {
    use crate::matrices::Matrix;
    let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let view = MatrixView::from(Matrix::from(rows));
    // With a precision every element is padded to that many decimal places.
    assert_eq!("[ 1.000, 2.000\n 3.000, 4.000 ]", format!("{:.3}", view));
    // Without one the elements use their plain `Display` output.
    assert_eq!("[ 1, 2\n 3, 4 ]", view.to_string());
}
/// Compares two `MatrixRef` sources element by element.
///
/// Returns `false` straight away if the row counts or the column counts differ. Otherwise the
/// elements are compared pairwise with `==`: in column major order if both sources report
/// `DataLayout::ColumnMajor`, and in row major order for every other combination of layouts.
#[inline]
pub(crate) fn matrix_equality<T, S1, S2>(left: &S1, right: &S2) -> bool
where
    T: PartialEq,
    S1: MatrixRef<T>,
    S2: MatrixRef<T>,
{
    let same_shape = left.view_rows() == right.view_rows()
        && left.view_columns() == right.view_columns();
    if !same_shape {
        return false;
    }
    let both_column_major = matches!(
        (left.data_layout(), right.data_layout()),
        (DataLayout::ColumnMajor, DataLayout::ColumnMajor)
    );
    if both_column_major {
        let lhs = ColumnMajorReferenceIterator::from(left);
        let rhs = ColumnMajorReferenceIterator::from(right);
        lhs.zip(rhs).all(|(a, b)| a == b)
    } else {
        let lhs = RowMajorReferenceIterator::from(left);
        let rhs = RowMajorReferenceIterator::from(right);
        lhs.zip(rhs).all(|(a, b)| a == b)
    }
}
/// Two `MatrixView`s are equal if `matrix_equality` holds for their sources; the sources do not
/// need to be of the same type.
impl<T, S1, S2> PartialEq<MatrixView<T, S2>> for MatrixView<T, S1>
where
    T: PartialEq,
    S1: MatrixRef<T>,
    S2: MatrixRef<T>,
{
    #[inline]
    fn eq(&self, other: &MatrixView<T, S2>) -> bool {
        let (left, right) = (self.source_ref(), other.source_ref());
        matrix_equality(left, right)
    }
}
/// A `MatrixView` is equal to a `Matrix` if `matrix_equality` holds between the view's source
/// and the matrix.
impl<T, S> PartialEq<Matrix<T>> for MatrixView<T, S>
where
    T: PartialEq,
    S: MatrixRef<T>,
{
    #[inline]
    fn eq(&self, other: &Matrix<T>) -> bool {
        // `other` is already a reference; borrowing it again would compare through a
        // `&&Matrix<T>` and instantiate `matrix_equality` over `&Matrix<T>` for no benefit.
        matrix_equality(&self.source, other)
    }
}
/// A `Matrix` is equal to a `MatrixView` if `matrix_equality` holds between the matrix and the
/// view's source.
impl<T, S> PartialEq<MatrixView<T, S>> for Matrix<T>
where
    T: PartialEq,
    S: MatrixRef<T>,
{
    #[inline]
    fn eq(&self, other: &MatrixView<T, S>) -> bool {
        // `self` is already a reference; borrowing it again would compare through a
        // `&&Matrix<T>` and instantiate `matrix_equality` over `&Matrix<T>` for no benefit.
        matrix_equality(self, &other.source)
    }
}
/// Checks that a `MatrixView` can wrap a boxed `dyn MatrixMut` trait object and still be read
/// from and written to.
#[test]
fn creating_matrix_views_erased() {
    let matrix = Matrix::from(vec![vec![1.0]]);
    let boxed: Box<dyn MatrixMut<f32>> = Box::new(matrix);
    let mut view = MatrixView::from(boxed);
    let incremented = view.get(0, 0) + 1.0;
    view.set(0, 0, incremented);
    assert_eq!(2.0, view.get(0, 0));
}