use crate::tensors::views::{TensorMut, TensorRef};
use crate::tensors::Dimension;
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
#[derive(Clone, Debug)]
pub struct TensorAccess<T, S, const D: usize> {
source: S,
dimension_mapping: [usize; D],
_type: PhantomData<T>,
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
S: TensorRef<T, D>,
{
#[track_caller]
pub fn from(source: S, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
match TensorAccess::try_from(source, dimensions) {
Err(error) => panic!("{}", error),
Ok(success) => success,
}
}
pub fn try_from(
source: S,
dimensions: [Dimension; D],
) -> Result<TensorAccess<T, S, D>, InvalidDimensionsError<D>> {
Ok(TensorAccess {
dimension_mapping: dimension_mapping(&source.view_shape(), &dimensions).ok_or_else(
|| InvalidDimensionsError {
actual: source.view_shape(),
requested: dimensions,
},
)?,
source,
_type: PhantomData,
})
}
pub fn shape(&self) -> [(Dimension, usize); D] {
let memory_shape = self.source.view_shape();
#[allow(clippy::clone_on_copy)]
let mut shape = memory_shape.clone();
for d in 0..D {
shape[d] = memory_shape[self.dimension_mapping[d]];
}
shape
}
}
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct InvalidDimensionsError<const D: usize> {
actual: [(Dimension, usize); D],
requested: [Dimension; D],
}
impl<const D: usize> Error for InvalidDimensionsError<D> {}
impl<const D: usize> fmt::Display for InvalidDimensionsError<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Requested dimension order: {:?} does not match the set of dimensions in the source: {:?}",
&self.actual, &self.requested
)
}
}
#[test]
fn test_sync() {
fn assert_sync<T: Sync>() {}
assert_sync::<InvalidDimensionsError<3>>();
}
#[test]
fn test_send() {
fn assert_send<T: Send>() {}
assert_send::<InvalidDimensionsError<3>>();
}
fn dimension_mapping<const D: usize>(
memory: &[(Dimension, usize); D],
requested: &[Dimension; D],
) -> Option<[usize; D]> {
let mut mapping = [0; D];
for d in 0..D {
let dimension = memory[d].0;
let order = if requested[d] == dimension {
d
} else {
let (n, _) = requested
.iter()
.enumerate()
.find(|(_, d)| **d == dimension)?;
n
};
mapping[d] = order;
}
Some(mapping)
}
#[inline]
fn map_dimensions<const D: usize>(
dimension_mapping: &[usize; D],
indexes: &[usize; D],
) -> [usize; D] {
let mut lookup = [0; D];
for d in 0..D {
lookup[d] = indexes[dimension_mapping[d]];
}
lookup
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
S: TensorRef<T, D>,
{
pub fn try_get_reference(&self, indexes: [usize; D]) -> Option<&T> {
self.source
.get_reference(map_dimensions(&self.dimension_mapping, &indexes))
}
#[track_caller]
pub fn get_reference(&self, indexes: [usize; D]) -> &T {
match self.try_get_reference(indexes) {
Some(reference) => reference,
None => panic!(
"Unable to index with {:?}, Tensor dimensions are {:?}.",
indexes,
self.shape()
),
}
}
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
S: TensorRef<T, D>,
T: Clone,
{
#[track_caller]
pub fn get(&self, indexes: [usize; D]) -> T {
match self.try_get_reference(indexes) {
Some(reference) => reference.clone(),
None => panic!(
"Unable to index with {:?}, Tensor dimensions are {:?}.",
indexes,
self.shape()
),
}
}
}
impl<T, S, const D: usize> TensorAccess<T, S, D>
where
S: TensorMut<T, D>,
{
pub fn try_get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
self.source
.get_reference_mut(map_dimensions(&self.dimension_mapping, &indexes))
}
#[track_caller]
pub fn get_reference_mut(&mut self, indexes: [usize; D]) -> &mut T {
match self.try_get_reference_mut(indexes) {
Some(reference) => reference,
None => panic!("Unable to index with {:?}", indexes),
}
}
}
#[test]
fn test_dimension_mapping() {
use crate::tensors::{dimension, of};
let mapping = dimension_mapping(
&[of("x", 0), of("y", 0), of("z", 0)],
&[dimension("x"), dimension("y"), dimension("z")],
);
assert_eq!([0, 1, 2], mapping.unwrap());
let mapping = dimension_mapping(
&[of("x", 0), of("y", 0), of("z", 0)],
&[dimension("z"), dimension("y"), dimension("x")],
);
assert_eq!([2, 1, 0], mapping.unwrap());
}