use crate::tensors::Dimension;
use crate::tensors::dimensions;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
use std::marker::PhantomData;
/// A view over a tensor in which some dimensions are reversed.
///
/// Indexing the view with `i` along a reversed dimension of length `L` reads index `L - 1 - i`
/// of the source; dimensions that are not reversed are indexed unchanged. The view reports the
/// same shape as its source.
#[derive(Clone, Debug)]
pub struct TensorReverse<T, S, const D: usize> {
/// The tensor being viewed.
source: S,
/// `reversed[d]` is true when dimension `d` (in the source's shape order) is reversed.
reversed: [bool; D],
_type: PhantomData<T>,
}
impl<T, S, const D: usize> TensorReverse<T, S, D>
where
    S: TensorRef<T, D>,
{
    /// Creates a view over `source` in which every dimension named in `dimensions` is indexed
    /// back to front, while all other dimensions keep their original order.
    ///
    /// # Panics
    ///
    /// - If `dimensions` names the same dimension more than once.
    /// - If any name in `dimensions` is not a dimension of the source's shape.
    #[track_caller]
    pub fn from(source: S, dimensions: &[Dimension]) -> TensorReverse<T, S, D> {
        if dimensions::has_duplicates_names(dimensions) {
            panic!("Dimension names must all be unique: {:?}", dimensions);
        }

        let shape = source.view_shape();
        let unknown = dimensions
            .iter()
            .find(|name| !dimensions::contains(&shape, name));
        if let Some(missing) = unknown {
            panic!(
                "Dimension names to reverse must be in the source: {:?} is not in {:?}",
                missing, shape
            );
        }

        // One flag per source dimension, in shape order: set when that dimension was requested.
        let flags = shape.map(|(name, _)| dimensions.contains(&name));

        TensorReverse {
            source,
            reversed: flags,
            _type: PhantomData,
        }
    }

    /// Consumes this view and returns the wrapped source.
    pub fn source(self) -> S {
        let TensorReverse { source, .. } = self;
        source
    }

    /// Borrows the wrapped source.
    pub fn source_ref(&self) -> &S {
        let inner = &self.source;
        inner
    }

    /// Mutably borrows the wrapped source.
    pub fn source_ref_mut(&mut self) -> &mut S {
        let inner = &mut self.source;
        inner
    }
}
/// Maps indexes into a reversed view back onto the corresponding indexes of its source.
///
/// For every dimension `d` with `reversed[d]` set, an index `i` that is in range for that
/// dimension (`i < shape[d].1`) is mapped to `shape[d].1 - 1 - i`. Indexes for dimensions that
/// are not reversed are returned unchanged.
///
/// An index that is out of range for a reversed dimension is also returned unchanged, so it
/// stays out of range for `shape` and a checked lookup against that shape can reject it.
/// Computing `length - 1 - index` for such an index would underflow: it would panic in debug
/// builds and wrap in release builds. The same applies to a reversed dimension of length 0.
pub(crate) fn reverse_indexes<const D: usize>(
    indexes: &[usize; D],
    shape: &[(Dimension, usize); D],
    reversed: &[bool; D],
) -> [usize; D] {
    std::array::from_fn(|d| {
        let index = indexes[d];
        let length = shape[d].1;
        if reversed[d] && index < length {
            // index < length implies length >= 1, so neither subtraction can underflow.
            length - 1 - index
        } else {
            index
        }
    })
}
/// Read access through a reversed view: each index is translated with `reverse_indexes` and the
/// lookup is forwarded to the source.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorReverse<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Reversing dimensions changes which element sits at each index, never the shape.
        self.source.view_shape()
    }

    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let shape = self.view_shape();
        let source_indexes = reverse_indexes(&indexes, &shape, &self.reversed);
        self.source.get_reference(source_indexes)
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        let shape = self.view_shape();
        let source_indexes = reverse_indexes(&indexes, &shape, &self.reversed);
        // SAFETY: the caller guarantees `indexes` are in range for `self.view_shape()`, which is
        // exactly the source's shape. An in-range index `i` of a reversed dimension of length
        // `L` maps to `L - 1 - i`, which is also in range, and other indexes pass through
        // unchanged, so `source_indexes` are valid for the source.
        unsafe { self.source.get_reference_unchecked(source_indexes) }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // No specific layout is claimed for a reversed view; this is returned unconditionally.
        DataLayout::Other
    }
}
/// Write access through a reversed view: each index is translated with `reverse_indexes` and the
/// mutable lookup is forwarded to the source.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorReverse<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let shape = self.view_shape();
        let source_indexes = reverse_indexes(&indexes, &shape, &self.reversed);
        self.source.get_reference_mut(source_indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        let shape = self.view_shape();
        let source_indexes = reverse_indexes(&indexes, &shape, &self.reversed);
        // SAFETY: the caller guarantees `indexes` are in range for `self.view_shape()`, which is
        // the source's shape. Reversal maps an in-range index `i` of a dimension of length `L`
        // to `L - 1 - i`, which is still in range, so `source_indexes` are valid for the source.
        unsafe { self.source.get_reference_unchecked_mut(source_indexes) }
    }
}
#[test]
fn test_reversed_tensors() {
    use crate::tensors::Tensor;
    let tensor = Tensor::from([("a", 2), ("b", 3), ("c", 2)], (0..12).collect());
    let forward: Vec<_> = tensor.iter().collect();
    assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], forward);

    // Reversing "a" and "c" swaps the two outer blocks of six and each adjacent pair within
    // them, while "b" keeps its order: the value at (a, b, c) is (1 - a) * 6 + b * 2 + (1 - c).
    let reversed = tensor.reverse_owned(&["a", "c"]);
    let backward: Vec<_> = reversed.iter().collect();
    assert_eq!(vec![7, 6, 9, 8, 11, 10, 1, 0, 3, 2, 5, 4], backward);
}