use crate::tensors::Dimension;
use crate::tensors::dimensions;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
use std::marker::PhantomData;
/// A view over a tensor `source` that reports its dimensions under different names.
///
/// Element access is forwarded unchanged to `source`. Only the dimension names are replaced,
/// position by position, with `dimensions`. This covers the names in `view_shape` and the
/// names in a `DataLayout::Linear` order. The length of each dimension still comes from the
/// source.
#[derive(Clone, Debug)]
pub struct TensorRename<T, S, const D: usize> {
    // The wrapped tensor that every element access is delegated to.
    source: S,
    // Replacement name for each dimension, by position. `from` and `set_names` reject
    // arrays containing duplicate names before storing them here.
    dimensions: [Dimension; D],
    // Ties the element type `T` to this struct without storing any `T`.
    _type: PhantomData<T>,
}
impl<T, S, const D: usize> TensorRename<T, S, D>
where
S: TensorRef<T, D>,
{
#[track_caller]
pub fn from(source: S, dimensions: [Dimension; D]) -> TensorRename<T, S, D> {
if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
panic!("Dimension names must all be unique: {:?}", &dimensions);
}
TensorRename {
source,
dimensions,
_type: PhantomData,
}
}
pub fn source(self) -> S {
self.source
}
pub fn source_ref(&self) -> &S {
&self.source
}
pub fn source_ref_mut(&mut self) -> &mut S {
&mut self.source
}
pub fn get_names(&self) -> &[Dimension; D] {
&self.dimensions
}
#[track_caller]
pub fn set_names(&mut self, dimensions: [Dimension; D]) {
if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
panic!("Dimension names must all be unique: {:?}", &dimensions);
}
self.dimensions = dimensions;
}
}
// SAFETY: every element access is forwarded untouched to `source`, which is itself a
// `TensorRef`. The indexes accepted and the references returned are exactly the source's.
// `view_shape` keeps the source's lengths and only swaps names. `TensorRename::from` and
// `set_names` refuse duplicate names.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRename<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source.get_reference(indexes)
    }

    /// The source's shape, with each name replaced by the renamed one at the same position.
    /// Lengths are taken unchanged from the source.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        let source_shape = self.source.view_shape();
        std::array::from_fn(|i| (self.dimensions[i], source_shape[i].1))
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: this view reports the same lengths as the source and forwards the indexes
        // without reordering. The caller's guarantees therefore carry over to the source.
        unsafe { self.source.get_reference_unchecked(indexes) }
    }

    /// The source's data layout, with any dimension names translated to this view's names.
    fn data_layout(&self) -> DataLayout<D> {
        match self.source.data_layout() {
            // A linear layout is expressed in the source's dimension names. Renaming does not
            // move dimensions, so resolve each source name to its position first. Then report
            // this view's name for that position.
            DataLayout::Linear(source_order) => {
                let shape = self.source.view_shape();
                let positions: [usize; D] = source_order.map(|name| {
                    dimensions::position_of(&shape, name)
                        .unwrap_or_else(|| panic!(
                            "Source implementation contained dimension {} in data_layout that was not in the view_shape {:?} which breaks the contract of TensorRef",
                            name, &shape
                        ))
                });
                DataLayout::Linear(positions.map(|position| self.dimensions[position]))
            }
            // Any other layout is returned exactly as the source reported it.
            other => other,
        }
    }
}
// SAFETY: mutable access is forwarded to `source` with identical indexes, and `source` is
// itself a `TensorMut`. This view reports the source's lengths, so every index valid here is
// valid for the source. No references are handed out beyond what the source itself returns.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRename<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        <S as TensorMut<T, D>>::get_reference_mut(&mut self.source, indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the indexes are forwarded as-is. The caller's guarantees for this view
        // therefore apply directly to the source.
        unsafe { <S as TensorMut<T, D>>::get_reference_unchecked_mut(&mut self.source, indexes) }
    }
}
#[test]
fn test_renamed_view_shape() {
    use crate::tensors::Tensor;
    // Rename ("a", "b") to ("b", "c"): the names change positionally while the lengths are
    // carried over from the source tensor.
    let source = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
    let renamed = source.rename_view(["b", "c"]);
    let renamed_shape = renamed.shape();
    assert_eq!([("b", 2), ("c", 2)], renamed_shape);
}