easy_ml/tensors/views/renamed.rs

1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
6/**
7 * A combination of new dimension names and a tensor.
8 *
9 * The provided dimension names override the dimension names in the
10 * [`view_shape`](TensorRef::view_shape) of the TensorRef exposed.
11 *
12 * ```
13 * use easy_ml::tensors::Tensor;
14 * use easy_ml::tensors::views::{TensorRename, TensorView};
15 * let a_b = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
16 * let b_c = TensorView::from(TensorRename::from(&a_b, ["b", "c"]));
17 * let also_b_c = a_b.rename_view(["b", "c"]);
18 * assert_ne!(a_b, b_c);
19 * assert_eq!(b_c, also_b_c);
20 * ```
21 */
22#[derive(Clone, Debug)]
23pub struct TensorRename<T, S, const D: usize> {
24    source: S,
25    dimensions: [Dimension; D],
26    _type: PhantomData<T>,
27}
28
29impl<T, S, const D: usize> TensorRename<T, S, D>
30where
31    S: TensorRef<T, D>,
32{
33    /**
34     * Creates a TensorRename from a source and a list of dimension names to override the
35     * view_shape with.
36     *
37     * # Panics
38     *
39     * - If a dimension name is not unique
40     */
41    #[track_caller]
42    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorRename<T, S, D> {
43        if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
44            panic!("Dimension names must all be unique: {:?}", &dimensions);
45        }
46        TensorRename {
47            source,
48            dimensions,
49            _type: PhantomData,
50        }
51    }
52
53    /**
54     * Consumes the TensorRename, yielding the source it was created from.
55     */
56    pub fn source(self) -> S {
57        self.source
58    }
59
60    /**
61     * Gives a reference to the TensorRename's source (in which the dimension names may be
62     * different).
63     */
64    pub fn source_ref(&self) -> &S {
65        &self.source
66    }
67
68    /**
69     * Gives a mutable reference to the TensorRename's source (in which the dimension names may be
70     * different).
71     */
72    // # Safety
73    //
74    // Although we're giving out a mutable reference here and thus the Tensor could be modified
75    // by the caller, it's impossible to change the dimensionality of the source due to this
76    // being determind at compile time by const generics, so as we only need our names to
77    // match the dimensionality of the source after any modifications we don't have any edge
78    // cases that could make them invalid.
79    pub fn source_ref_mut(&mut self) -> &mut S {
80        &mut self.source
81    }
82
83    /**
84     * Gets the dimension names this TensorRename is overriding the source it
85     * was created from with.
86     */
87    pub fn get_names(&self) -> &[Dimension; D] {
88        &self.dimensions
89    }
90
91    // # Safety
92    //
93    // Giving out a mutable reference to our dimension names could allow a caller to make
94    // them non unique which would invalidate our TensorRef implementation. However, a setter
95    // method is fine because we can ensure this invariant is not broken.
96    /**
97     * Sets the dimension names this TensorRename is overriding the source it
98     * was created from with.
99     *
100     * # Panics
101     *
102     * - If a dimension name is not unique
103     */
104    #[track_caller]
105    pub fn set_names(&mut self, dimensions: [Dimension; D]) {
106        if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
107            panic!("Dimension names must all be unique: {:?}", &dimensions);
108        }
109        self.dimensions = dimensions;
110    }
111}
112
113// # Safety
114//
115// The type implementing TensorRef must implement it correctly, so by delegating to it
116// without changing any indexes or introducing interior mutability, and ensuring we do
117// not introduce non unique dimension names, we implement TensorRef correctly as well.
118/**
119 * A TensorRename implements TensorRef, with the dimension names the TensorRename was created
120 * with overriding the dimension names in the original source.
121 */
122unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRename<T, S, D>
123where
124    S: TensorRef<T, D>,
125{
126    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
127        self.source.get_reference(indexes)
128    }
129
130    fn view_shape(&self) -> [(Dimension, usize); D] {
131        let mut shape = self.source.view_shape();
132        for (i, element) in shape.iter_mut().enumerate() {
133            *element = (self.dimensions[i], element.1);
134        }
135        shape
136    }
137
138    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
139        unsafe { self.source.get_reference_unchecked(indexes) }
140    }
141
142    fn data_layout(&self) -> DataLayout<D> {
143        let data_layout = self.source.data_layout();
144        match data_layout {
145            DataLayout::Linear(order) => {
146                let shape = self.source.view_shape();
147                // Map the dimension name order to position in the view shape instead of name
148                let order_d: [usize; D] = std::array::from_fn(|i| {
149                    let name = order[i];
150                    dimensions::position_of(&shape, name)
151                        .unwrap_or_else(|| panic!(
152                            "Source implementation contained dimension {} in data_layout that was not in the view_shape {:?} which breaks the contract of TensorRef",
153                            name, &shape
154                        ))
155                });
156                // TensorRename doesn't move dimensions around, so now we can map from position
157                // order to our new dimension names.
158                DataLayout::Linear(std::array::from_fn(|i| self.dimensions[order_d[i]]))
159            }
160            _ => data_layout,
161        }
162    }
163}
164
165// # Safety
166//
167// The type implementing TensorMut must implement it correctly, so by delegating to it
168// without changing any indexes or introducing interior mutability, and ensuring we do
169// not introduce non unique dimension names, we implement TensorMut correctly as well.
170/**
171 * A TensorRename implements TensorMut, with the dimension names the TensorRename was created
172 * with overriding the dimension names in the original source.
173 */
174unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRename<T, S, D>
175where
176    S: TensorMut<T, D>,
177{
178    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
179        self.source.get_reference_mut(indexes)
180    }
181
182    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
183        unsafe { self.source.get_reference_unchecked_mut(indexes) }
184    }
185}
186
#[test]
fn test_renamed_view_shape() {
    use crate::tensors::Tensor;
    // Rename dimensions ("a", "b") to ("b", "c"); the lengths must carry over unchanged.
    let original = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
    let renamed = original.rename_view(["b", "c"]);
    assert_eq!(renamed.shape(), [("b", 2), ("c", 2)]);
}