use crate::tensors;
use crate::tensors::Dimension;
use crate::tensors::dimensions;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
use std::marker::PhantomData;
/// A view over a `source` tensor that presents the source's elements under a different shape,
/// which may have a different number of dimensions (`D` for the source, `D2` for this view).
///
/// Elements are matched by flat row-major position (last dimension varies fastest): the n-th
/// element of this view, in row-major order of `shape`, is the n-th element of the source in
/// row-major order of the source's own shape. The constructors panic unless both shapes cover
/// the same number of elements.
#[derive(Clone, Debug)]
pub struct TensorReshape<T, S, const D: usize, const D2: usize> {
/// The tensor whose data is viewed.
source: S,
/// The shape presented by this view.
shape: [(Dimension, usize); D2],
/// Strides of `shape`, used to flatten an index into this view into a row-major position.
strides: [usize; D2],
/// Strides of the source's shape, used to turn a row-major position back into a source index.
source_strides: [usize; D],
_type: PhantomData<T>,
}
impl<T, S, const D: usize, const D2: usize> TensorReshape<T, S, D, D2>
where
    S: TensorRef<T, D>,
{
    /// Creates a view of `source` that presents its elements with the given `shape`.
    ///
    /// Elements keep their row-major order; only the shape used to index them changes.
    ///
    /// # Panics
    ///
    /// - If any dimension name in `shape` appears more than once.
    /// - If `shape` does not describe exactly as many elements as `source` contains.
    #[track_caller]
    pub fn from(source: S, shape: [(Dimension, usize); D2]) -> TensorReshape<T, S, D, D2> {
        if dimensions::has_duplicates(&shape) {
            panic!("Dimension names must all be unique: {:?}", &shape);
        }

        let source_shape = source.view_shape();
        let required = dimensions::elements(&shape);
        let available = dimensions::elements(&source_shape);
        if required != available {
            panic!(
                "Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
                &shape,
                &required,
                &available,
                &source_shape
            );
        }

        TensorReshape {
            source_strides: tensors::compute_strides(&source_shape),
            strides: tensors::compute_strides(&shape),
            shape,
            source,
            _type: PhantomData,
        }
    }

    /// Consumes this view, returning the tensor it was built over.
    // May have no callers inside the crate; kept as part of the view's API.
    #[allow(dead_code)]
    pub fn source(self) -> S {
        let TensorReshape { source, .. } = self;
        source
    }

    /// Borrows the tensor this view was built over.
    // May have no callers inside the crate; kept as part of the view's API.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}
impl<T, S, const D: usize> TensorReshape<T, S, D, D>
where
S: TensorRef<T, D>,
{
#[track_caller]
pub fn from_existing_dimensions(source: S, lengths: [usize; D]) -> TensorReshape<T, S, D, D> {
let previous_shape = source.view_shape();
let shape = std::array::from_fn(|n| (previous_shape[n].0, lengths[0]));
let existing_one_dimensional_length = dimensions::elements(&source.view_shape());
let given_one_dimensional_length = dimensions::elements(&shape);
if given_one_dimensional_length != existing_one_dimensional_length {
panic!(
"Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
&shape,
&given_one_dimensional_length,
&existing_one_dimensional_length,
&source.view_shape()
);
}
let source_strides = tensors::compute_strides(&source.view_shape());
TensorReshape {
source,
shape,
strides: tensors::compute_strides(&shape),
source_strides,
_type: PhantomData,
}
}
}
/// Converts the flat position `nth` into a D-dimensional index, given the row-major `strides`
/// of the shape being indexed.
///
/// Each dimension's index is how many whole strides fit into what remains of the position.
/// The remainder carries on to the next, finer dimension. Strides must be non-zero for any
/// dimension this actually divides by.
fn unflatten<const D: usize>(nth: usize, strides: &[usize; D]) -> [usize; D] {
    let mut remaining = nth;
    // `array::map` visits elements in order, so the remainder flows from the outermost
    // dimension to the innermost one.
    (*strides).map(|stride| {
        let steps = remaining / stride;
        remaining %= stride;
        steps
    })
}
#[test]
fn unflatten_produces_indices_in_n_dimensions() {
    // Checks that flat position n unflattens to `expected[n]`, for every n in `expected`.
    fn assert_positions<const D: usize>(strides: [usize; D], expected: &[[usize; D]]) {
        for (nth, indexes) in expected.iter().enumerate() {
            assert_eq!(*indexes, unflatten(nth, &strides));
        }
    }

    assert_positions(
        tensors::compute_strides(&[("x", 2), ("y", 2)]),
        &[[0, 0], [0, 1], [1, 0], [1, 1]],
    );
    assert_positions(
        tensors::compute_strides(&[("x", 3), ("y", 2)]),
        &[[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]],
    );
    assert_positions(
        tensors::compute_strides(&[("x", 2), ("y", 3)]),
        &[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    );
    assert_positions(
        tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 1)]),
        &[
            [0, 0, 0],
            [0, 1, 0],
            [0, 2, 0],
            [1, 0, 0],
            [1, 1, 0],
            [1, 2, 0],
        ],
    );
    assert_positions(
        tensors::compute_strides(&[("batch", 1), ("x", 2), ("y", 3)]),
        &[
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
        ],
    );
    assert_positions(
        tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 2)]),
        &[
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [0, 2, 0],
            [0, 2, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
            [1, 2, 0],
            [1, 2, 1],
        ],
    );
}
unsafe impl<T, S, const D: usize, const D2: usize> TensorRef<T, D2> for TensorReshape<T, S, D, D2>
where
    S: TensorRef<T, D>,
{
    /// Looks up an element by its index in the reshaped view.
    ///
    /// Returns `None` when `indexes` is out of bounds for this view's shape. Otherwise returns
    /// whatever the source returns for the corresponding source index.
    fn get_reference(&self, indexes: [usize; D2]) -> Option<&T> {
        // Bounds-check and flatten against this view's shape, then map the row-major position
        // back into the source's own index space.
        tensors::get_index_direct(&indexes, &self.strides, &self.shape)
            .map(|position| unflatten(position, &self.source_strides))
            .and_then(|source_indexes| self.source.get_reference(source_indexes))
    }

    /// The shape this view was constructed with.
    fn view_shape(&self) -> [(Dimension, usize); D2] {
        self.shape
    }

    /// Looks up an element by its index in the reshaped view without bounds checking.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D2]) -> &T {
        // SAFETY: the unchecked API presumes the caller passes `indexes` within `view_shape()`.
        // The constructors panic unless this shape and the source's shape hold the same number
        // of elements. So the flattened position is below that count, and `unflatten` maps it
        // to an in-bounds index for the source.
        unsafe {
            let position = tensors::get_index_direct_unchecked(&indexes, &self.strides);
            let source_indexes = unflatten(position, &self.source_strides);
            self.source.get_reference_unchecked(source_indexes)
        }
    }

    /// A reshape reorganises indexing, so no simple memory layout is reported.
    fn data_layout(&self) -> DataLayout<D2> {
        DataLayout::Other
    }
}
unsafe impl<T, S, const D: usize, const D2: usize> TensorMut<T, D2> for TensorReshape<T, S, D, D2>
where
    S: TensorMut<T, D>,
{
    /// Mutably looks up an element by its index in the reshaped view.
    ///
    /// Returns `None` when `indexes` is out of bounds for this view's shape.
    fn get_reference_mut(&mut self, indexes: [usize; D2]) -> Option<&mut T> {
        let position = tensors::get_index_direct(&indexes, &self.strides, &self.shape)?;
        let source_indexes = unflatten(position, &self.source_strides);
        self.source.get_reference_mut(source_indexes)
    }

    /// Mutably looks up an element by its index in the reshaped view without bounds checking.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D2]) -> &mut T {
        // SAFETY: the unchecked API presumes the caller passes `indexes` within `view_shape()`.
        // The constructors ensure this view and the source hold the same number of elements,
        // so the unflattened position is an in-bounds index for the source.
        unsafe {
            let position = tensors::get_index_direct_unchecked(&indexes, &self.strides);
            let source_indexes = unflatten(position, &self.source_strides);
            self.source.get_reference_unchecked_mut(source_indexes)
        }
    }
}