use crate::matrices::views::DataLayout as MDataLayout;
use crate::matrices::views::{MatrixMut, MatrixRef, NoInteriorMutability};
use crate::matrices::{Column, Row};
use crate::tensors::views::DataLayout as TDataLayout;
use crate::tensors::views::{TensorMut, TensorRef};
use crate::tensors::{Dimension, InvalidShapeError};
use std::marker::PhantomData;
/// Adapter that presents a matrix view `S` as a two dimensional tensor view.
///
/// `N` provides the two dimension names reported in the tensor shape (see
/// [`DimensionNames`]). `T` is the element type and is only tracked through
/// [`PhantomData`].
#[derive(Clone, Debug)]
pub struct TensorRefMatrix<T, S, N> {
    /// Wrapped matrix view; every element access is forwarded to it.
    source: S,
    /// Provider of the dimension names, consumed as `[row_name, column_name]`.
    names: N,
    /// Marker for the element type `T`.
    _type: PhantomData<T>,
}
/// Source of the two dimension names used when a matrix is viewed as a tensor.
pub trait DimensionNames {
    /// Returns the names as `[row_dimension, column_dimension]`.
    fn names(&self) -> [Dimension; 2];
}
/// Dimension names `"row"` and `"column"`, used by `TensorRefMatrix::from`.
///
/// This is a zero sized marker, so it is also `Copy`, `Default` and comparable,
/// which lets callers reuse one value across several `with_names` calls.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct RowAndColumn;
/// Fixed naming: the first dimension is `"row"`, the second `"column"`.
impl DimensionNames for RowAndColumn {
    fn names(&self) -> [Dimension; 2] {
        ["row", "column"]
    }
}
/// Lets an explicit `[row_name, column_name]` array serve directly as the names.
impl DimensionNames for [Dimension; 2] {
    fn names(&self) -> [Dimension; 2] {
        let [row, column] = *self;
        [row, column]
    }
}
impl<T, S> TensorRefMatrix<T, S, RowAndColumn>
where
    S: MatrixRef<T> + NoInteriorMutability,
{
    /// Wraps `source` using [`RowAndColumn`] as the dimension names.
    ///
    /// # Errors
    ///
    /// Returns whatever [`InvalidShapeError`] `with_names` reports for the
    /// source's shape.
    pub fn from(source: S) -> Result<Self, InvalidShapeError<2>> {
        Self::with_names(source, RowAndColumn)
    }
}
impl<T, S, N> TensorRefMatrix<T, S, N>
where
    S: MatrixRef<T> + NoInteriorMutability,
    N: DimensionNames,
{
    /// Wraps `source` as a 2D tensor whose dimensions are named by `names`.
    ///
    /// # Errors
    ///
    /// Builds an [`InvalidShapeError`] from the two names paired with the
    /// source's row and column counts. If `is_valid()` reports false, that
    /// error is returned instead of the wrapper.
    pub fn with_names(source: S, names: N) -> Result<Self, InvalidShapeError<2>> {
        let [row_name, column_name] = names.names();
        let shape = InvalidShapeError::new([
            (row_name, source.view_rows()),
            (column_name, source.view_columns()),
        ]);
        if !shape.is_valid() {
            return Err(shape);
        }
        Ok(Self {
            source,
            names,
            _type: PhantomData,
        })
    }
}
unsafe impl<T, S, N> TensorRef<T, 2> for TensorRefMatrix<T, S, N>
where
    S: MatrixRef<T> + NoInteriorMutability,
    N: DimensionNames,
{
    /// Index `[row, column]` is forwarded to the matrix's checked accessor.
    fn get_reference(&self, indexes: [usize; 2]) -> Option<&T> {
        let [row, column] = indexes;
        self.source.try_get_reference(row, column)
    }

    /// Pairs the configured names with the source's current row/column counts.
    fn view_shape(&self) -> [(Dimension, usize); 2] {
        let rows = self.source.view_rows();
        let columns = self.source.view_columns();
        // NOTE(review): names are re-queried here on every call, while
        // `with_names` validated only the answer given at construction. This is
        // sound only if `DimensionNames::names` always returns the same pair.
        // Consider caching the validated names.
        let [row_name, column_name] = self.names.names();
        [(row_name, rows), (column_name, columns)]
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; 2]) -> &T {
        let [row, column] = indexes;
        // SAFETY: the caller guarantees `indexes` is in bounds for `view_shape`,
        // which reports the source's rows then columns, so `(row, column)` is
        // in bounds for the source matrix.
        unsafe { self.source.get_reference_unchecked(row, column) }
    }

    /// Row major sources become `Linear([row, column])`. Column major sources
    /// become `Linear([column, row])`. Anything else is reported as `Other`.
    fn data_layout(&self) -> TDataLayout<2> {
        let [row_name, column_name] = self.names.names();
        let order = match self.source.data_layout() {
            MDataLayout::RowMajor => [row_name, column_name],
            MDataLayout::ColumnMajor => [column_name, row_name],
            MDataLayout::Other => return TDataLayout::Other,
        };
        TDataLayout::Linear(order)
    }
}
unsafe impl<T, S, N> TensorMut<T, 2> for TensorRefMatrix<T, S, N>
where
    S: MatrixMut<T> + NoInteriorMutability,
    N: DimensionNames,
{
    /// Index `[row, column]` is forwarded to the matrix's checked mutable accessor.
    fn get_reference_mut(&mut self, indexes: [usize; 2]) -> Option<&mut T> {
        let [row, column] = indexes;
        self.source.try_get_reference_mut(row, column)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; 2]) -> &mut T {
        let [row, column] = indexes;
        // SAFETY: the caller guarantees `indexes` is in bounds for the tensor
        // shape, which mirrors the source's (rows, columns) in that order.
        unsafe { self.source.get_reference_unchecked_mut(row, column) }
    }
}
/// Adapter that presents a two dimensional tensor view `S` as a matrix view.
///
/// The tensor's first dimension is treated as rows and its second as columns.
/// `T` is the element type and is only tracked through [`PhantomData`].
/// It derives `Clone` and `Debug`, matching `TensorRefMatrix`.
#[derive(Clone, Debug)]
pub struct MatrixRefTensor<T, S> {
    /// Wrapped tensor view; every element access is forwarded to it.
    source: S,
    /// Marker for the element type `T`.
    _type: PhantomData<T>,
}
impl<T, S> MatrixRefTensor<T, S>
where
    S: TensorRef<T, 2>,
{
    /// Wraps a 2D tensor view so it can be used wherever a matrix view is expected.
    ///
    /// This is infallible and performs no shape checks.
    pub fn from(source: S) -> Self {
        Self {
            source,
            _type: PhantomData,
        }
    }
}
unsafe impl<T, S> MatrixRef<T> for MatrixRefTensor<T, S>
where
    S: TensorRef<T, 2>,
{
    /// `(row, column)` is forwarded to the tensor's checked accessor as `[row, column]`.
    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
        self.source.get_reference([row, column])
    }

    /// Length of the tensor's first dimension.
    fn view_rows(&self) -> Row {
        let [(_, rows), _] = self.source.view_shape();
        rows
    }

    /// Length of the tensor's second dimension.
    fn view_columns(&self) -> Column {
        let [_, (_, columns)] = self.source.view_shape();
        columns
    }

    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
        // SAFETY: the caller guarantees `(row, column)` is within
        // `view_rows()` x `view_columns()`, which are exactly the tensor's first
        // and second dimension lengths, so `[row, column]` is in bounds.
        unsafe { self.source.get_reference_unchecked([row, column]) }
    }

    /// Maps the tensor's layout onto a matrix layout.
    ///
    /// `Linear([rows, columns])` maps to `RowMajor` and `Linear([columns, rows])`
    /// maps to `ColumnMajor`. Any other layout maps to `Other`.
    fn data_layout(&self) -> MDataLayout {
        // Query the shape and layout once each rather than repeatedly.
        let [(rows_dimension, _), (columns_dimension, _)] = self.source.view_shape();
        match self.source.data_layout() {
            TDataLayout::Linear([first, second])
                if first == rows_dimension && second == columns_dimension =>
            {
                MDataLayout::RowMajor
            }
            TDataLayout::Linear([first, second])
                if first == columns_dimension && second == rows_dimension =>
            {
                MDataLayout::ColumnMajor
            }
            TDataLayout::Linear(_) | TDataLayout::NonLinear | TDataLayout::Other => {
                MDataLayout::Other
            }
        }
    }
}
unsafe impl<T, S> MatrixMut<T> for MatrixRefTensor<T, S>
where
    S: TensorMut<T, 2>,
{
    /// `(row, column)` is forwarded to the tensor's checked mutable accessor.
    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
        let indexes = [row, column];
        self.source.get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
        let indexes = [row, column];
        // SAFETY: the caller guarantees `(row, column)` is within the matrix
        // bounds, which are the tensor's first and second dimension lengths.
        unsafe { self.source.get_reference_unchecked_mut(indexes) }
    }
}
// SAFETY: `MatrixRefTensor` adds no interior mutability of its own. Its only
// state is `source` and a `PhantomData`. NOTE(review): this relies on the
// `TensorRef` contract guaranteeing that the source's shape cannot change
// through a shared reference — confirm against the `TensorRef` safety docs.
unsafe impl<T, S> NoInteriorMutability for MatrixRefTensor<T, S>
where
    S: TensorRef<T, 2>,
{
}