easy_ml/tensors/views/renamed.rs
1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
6/**
7 * A combination of new dimension names and a tensor.
8 *
9 * The provided dimension names override the dimension names in the
10 * [`view_shape`](TensorRef::view_shape) of the TensorRef exposed.
11 *
12 * ```
13 * use easy_ml::tensors::Tensor;
14 * use easy_ml::tensors::views::{TensorRename, TensorView};
15 * let a_b = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
16 * let b_c = TensorView::from(TensorRename::from(&a_b, ["b", "c"]));
17 * let also_b_c = a_b.rename_view(["b", "c"]);
18 * assert_ne!(a_b, b_c);
19 * assert_eq!(b_c, also_b_c);
20 * ```
21 */
22#[derive(Clone, Debug)]
23pub struct TensorRename<T, S, const D: usize> {
24 source: S,
25 dimensions: [Dimension; D],
26 _type: PhantomData<T>,
27}
28
29impl<T, S, const D: usize> TensorRename<T, S, D>
30where
31 S: TensorRef<T, D>,
32{
33 /**
34 * Creates a TensorRename from a source and a list of dimension names to override the
35 * view_shape with.
36 *
37 * # Panics
38 *
39 * - If a dimension name is not unique
40 */
41 #[track_caller]
42 pub fn from(source: S, dimensions: [Dimension; D]) -> TensorRename<T, S, D> {
43 if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
44 panic!("Dimension names must all be unique: {:?}", &dimensions);
45 }
46 TensorRename {
47 source,
48 dimensions,
49 _type: PhantomData,
50 }
51 }
52
53 /**
54 * Consumes the TensorRename, yielding the source it was created from.
55 */
56 pub fn source(self) -> S {
57 self.source
58 }
59
60 /**
61 * Gives a reference to the TensorRename's source (in which the dimension names may be
62 * different).
63 */
64 pub fn source_ref(&self) -> &S {
65 &self.source
66 }
67
68 /**
69 * Gives a mutable reference to the TensorRename's source (in which the dimension names may be
70 * different).
71 */
72 // # Safety
73 //
74 // Although we're giving out a mutable reference here and thus the Tensor could be modified
75 // by the caller, it's impossible to change the dimensionality of the source due to this
76 // being determind at compile time by const generics, so as we only need our names to
77 // match the dimensionality of the source after any modifications we don't have any edge
78 // cases that could make them invalid.
79 pub fn source_ref_mut(&mut self) -> &mut S {
80 &mut self.source
81 }
82
83 /**
84 * Gets the dimension names this TensorRename is overriding the source it
85 * was created from with.
86 */
87 pub fn get_names(&self) -> &[Dimension; D] {
88 &self.dimensions
89 }
90
91 // # Safety
92 //
93 // Giving out a mutable reference to our dimension names could allow a caller to make
94 // them non unique which would invalidate our TensorRef implementation. However, a setter
95 // method is fine because we can ensure this invariant is not broken.
96 /**
97 * Sets the dimension names this TensorRename is overriding the source it
98 * was created from with.
99 *
100 * # Panics
101 *
102 * - If a dimension name is not unique
103 */
104 #[track_caller]
105 pub fn set_names(&mut self, dimensions: [Dimension; D]) {
106 if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
107 panic!("Dimension names must all be unique: {:?}", &dimensions);
108 }
109 self.dimensions = dimensions;
110 }
111}
112
113// # Safety
114//
115// The type implementing TensorRef must implement it correctly, so by delegating to it
116// without changing any indexes or introducing interior mutability, and ensuring we do
117// not introduce non unique dimension names, we implement TensorRef correctly as well.
118/**
119 * A TensorRename implements TensorRef, with the dimension names the TensorRename was created
120 * with overriding the dimension names in the original source.
121 */
122unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRename<T, S, D>
123where
124 S: TensorRef<T, D>,
125{
126 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
127 self.source.get_reference(indexes)
128 }
129
130 fn view_shape(&self) -> [(Dimension, usize); D] {
131 let mut shape = self.source.view_shape();
132 for (i, element) in shape.iter_mut().enumerate() {
133 *element = (self.dimensions[i], element.1);
134 }
135 shape
136 }
137
138 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
139 unsafe { self.source.get_reference_unchecked(indexes) }
140 }
141
142 fn data_layout(&self) -> DataLayout<D> {
143 let data_layout = self.source.data_layout();
144 match data_layout {
145 DataLayout::Linear(order) => {
146 let shape = self.source.view_shape();
147 // Map the dimension name order to position in the view shape instead of name
148 let order_d: [usize; D] = std::array::from_fn(|i| {
149 let name = order[i];
150 dimensions::position_of(&shape, name)
151 .unwrap_or_else(|| panic!(
152 "Source implementation contained dimension {} in data_layout that was not in the view_shape {:?} which breaks the contract of TensorRef",
153 name, &shape
154 ))
155 });
156 // TensorRename doesn't move dimensions around, so now we can map from position
157 // order to our new dimension names.
158 DataLayout::Linear(std::array::from_fn(|i| self.dimensions[order_d[i]]))
159 }
160 _ => data_layout,
161 }
162 }
163}
164
165// # Safety
166//
167// The type implementing TensorMut must implement it correctly, so by delegating to it
168// without changing any indexes or introducing interior mutability, and ensuring we do
169// not introduce non unique dimension names, we implement TensorMut correctly as well.
170/**
171 * A TensorRename implements TensorMut, with the dimension names the TensorRename was created
172 * with overriding the dimension names in the original source.
173 */
174unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRename<T, S, D>
175where
176 S: TensorMut<T, D>,
177{
178 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
179 self.source.get_reference_mut(indexes)
180 }
181
182 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
183 unsafe { self.source.get_reference_unchecked_mut(indexes) }
184 }
185}
186
#[test]
fn test_renamed_view_shape() {
    use crate::tensors::Tensor;
    // Renaming swaps names positionally while every length is kept from the source.
    let original = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
    let renamed = original.rename_view(["b", "c"]);
    assert_eq!([("b", 2), ("c", 2)], renamed.shape());
}