easy_ml/tensors/
dimensions.rs

1/*!
2 * Utilities to manipulate Dimensions.
3 *
4 * Contains a number of utility functions for manipulating the dimensions of tensors.
5 *
6 * # Terminology
7 *
8 * Tensors in Easy ML have a **shape** of type `[(Dimension, usize); D]`, where D is a compile time
9 * constant. This type defines a list of dimension names and the lengths along each dimension name.
 * A tensor with some shape stores data at indexes ranging from 0 to the length - 1 along each
 * dimension name.
12 * Often we want to call methods which only take dimension names in some order, so those
13 * **dimension names** have a type of `[Dimension; D]`. We also may want to index into a Tensor,
14 * which is done by providing the **index**es only, with a type of `[usize; D]`.
15 *
 * Dimensions and dimension names in Easy ML APIs are treated like lists: the order of the
 * dimensions makes a difference for equality definitions and mathematical operations, and can
 * be a factor in the order of iteration and indexing. However, many high level APIs that are
 * not directly involved with order or indexing require only a dimension name, and these are
 * usually less concerned with the order of the dimensions.
21 */
22
23use crate::tensors::Dimension;
24
25/**
26 * Returns the product of the dimension lengths in the provided shape.
27 *
28 * This is equal to the number of elements that will be stored for these dimensions.
29 * A 0 dimensional tensor stores exactly 1 element, a 1 dimensional tensor stores N elements,
30 * a 2 dimensional tensor stores NxM elements and so on.
31 */
32pub fn elements<const D: usize>(shape: &[(Dimension, usize); D]) -> usize {
33    shape.iter().map(|d| d.1).product()
34}
35
36/**
37 * Finds the position of the dimension name in the shape.
38 *
39 * `None` is returned if the dimension name is not in the shape.
40 */
41pub fn position_of<const D: usize>(
42    shape: &[(Dimension, usize); D],
43    dimension: Dimension,
44) -> Option<usize> {
45    shape.iter().position(|(d, _)| d == &dimension)
46}
47
48/**
49 * Checks if the dimension name is in the shape.
50 */
51pub fn contains<const D: usize>(shape: &[(Dimension, usize); D], dimension: Dimension) -> bool {
52    shape.iter().any(|(d, _)| d == &dimension)
53}
54
55/**
56 * Returns the length of the dimension name provided, if one is present in the shape.
57 */
58pub fn length_of<const D: usize>(
59    shape: &[(Dimension, usize); D],
60    dimension: Dimension,
61) -> Option<usize> {
62    shape
63        .iter()
64        .find(|(d, _)| *d == dimension)
65        .map(|(_, length)| *length)
66}
67
68/**
69 * Returns the last index of the dimension name provided, if one is present in the shape.
70 *
71 * This is always 1 less than the length, the 'index' in this sense is based on what the
72 * shape is, not any implementation index. If for some reason a shape with a 0
73 * length was given, the last index will saturate at 0.
74 */
75pub fn last_index_of<const D: usize>(
76    shape: &[(Dimension, usize); D],
77    dimension: Dimension,
78) -> Option<usize> {
79    length_of(shape, dimension).map(|length| length.saturating_sub(1))
80}
81
82#[derive(Debug, Clone, Eq, PartialEq)]
83pub(crate) struct DimensionMappings<const D: usize> {
84    source_to_requested: [usize; D],
85    requested_to_source: [usize; D],
86}
87
88impl<const D: usize> DimensionMappings<D> {
89    pub(crate) fn no_op_mapping() -> DimensionMappings<D> {
90        DimensionMappings {
91            source_to_requested: std::array::from_fn(|d| d),
92            requested_to_source: std::array::from_fn(|d| d),
93        }
94    }
95
96    // Computes both mappings from from a shape in source order and a matching set of
97    // dimensions in an arbitary order.
98    // If the source order is x,y,z but the requested order is z,x,y then the mapping
99    // from source to requested is [1,2,0] (x becomes second, y becomes last, z becomes first) and
100    // from requested to source is [2,0,1] (z becones last, x becomes first, y becomes second).
101    pub(crate) fn new(
102        source: &[(Dimension, usize); D],
103        requested: &[Dimension; D],
104    ) -> Option<DimensionMappings<D>> {
105        let mut source_to_requested = [0; D];
106        let mut requested_to_source = [0; D];
107        for d in 0..D {
108            let dimension = source[d].0;
109            // happy path, requested dimension is in the same order as in source order
110            if requested[d] == dimension {
111                source_to_requested[d] = d;
112                requested_to_source[d] = d;
113            } else {
114                // If dimensions are in a different order, find the dimension with the
115                // matching dimension name for both mappings at this position in the order.
116                // Since both lists are the same length and we know our source order won't contain
117                // duplicates this also ensures the two lists have exactly the same set of names
118                // as otherwise one of these `find`s will fail.
119                let (n_in_requested, _) = requested
120                    .iter()
121                    .enumerate()
122                    .find(|(_, d)| **d == dimension)?;
123                source_to_requested[d] = n_in_requested;
124                let dimension = requested[d];
125                let (n_in_source, _) = source
126                    .iter()
127                    .enumerate()
128                    .find(|(_, (d, _))| *d == dimension)?;
129                requested_to_source[d] = n_in_source;
130            };
131        }
132        Some(DimensionMappings {
133            source_to_requested,
134            requested_to_source,
135        })
136    }
137
138    /// Reorders some indexes according to the dimension mapping to return the
139    /// indexes in the source order
140    #[inline]
141    pub(crate) fn map_dimensions_to_source(&self, indexes: &[usize; D]) -> [usize; D] {
142        // Our input is in requested order and we return indexes in the source order, so for each
143        // dimension to return (in source order) we're looking up which index from the input to
144        // use, just like for map_linear_data_layout_to_transposed.
145        std::array::from_fn(|d| indexes[self.source_to_requested[d]])
146    }
147
148    /// Reorders some shape according to the dimension mapping to return the
149    /// shape in the requested order
150    #[inline]
151    pub(crate) fn map_shape_to_requested(
152        &self,
153        source: &[(Dimension, usize); D],
154    ) -> [(Dimension, usize); D] {
155        // For each d we're returning, we're giving what the requested dth dimension is
156        // in the source, so we use requested_to_source for mapping. This is different to indexing
157        // or map_linear_data_layout_to_transposed because our output here is in requested order,
158        // not our input.
159        std::array::from_fn(|d| source[self.requested_to_source[d]])
160    }
161
162    #[inline]
163    pub(crate) fn map_linear_data_layout_to_transposed(
164        &self,
165        order: &[Dimension; D],
166    ) -> [Dimension; D] {
167        // In most simple cases the transformation for source -> requested is the same as
168        // requested -> source so the data layout maps the same way as the shape. However,
169        // in most complex 3D or higher cases, we can transpose a tensor in such a way that
170        // the data_layout does not map the same way. To conceptualise this, consider we first
171        // just used a TensorAccess to swap the indexes from ["batch", "row", "column"] to
172        // ["row", "column", "batch"]. This is a mapping of [0->2, 1->0, 2->1] (or if we only
173        // want to write where each dimension ends up, we can call this a mapping of [2, 0, 1]).
174        // TensorAccess doesn't change what each dimension refers to in memory, so the data_layout
175        // stays as ["batch", "row", "column"]. We can then wrap our TensorAccess in a TensorRename
176        // to emulate the behavior of a TensorTranspose. We rename ["row", "column", "batch"]
177        // to ["batch", "row", "column"] again, retaining our swapped dimensions as they are.
178        // Now row in our data_layout is called batch, column becomes row, and batch becomes
179        // column. Hence our data layout becames ["column", "batch", "row"]. If we compare our
180        // data layout to what we started with (["batch", "row", "column"]) we see that we have
181        // mapped it as [1, 2, 0], not [2, 0, 1]. As a more general explanation, the data layout
182        // we return here is mapping from what the data layout of the source is to the order
183        // we're now requesting in, so we use the reverse mapping to what we use for the
184        // view_shape. This ensures the data layout we return is correct, which can always be
185        // sanity checked by constructing a TensorAccess from what we return and verifying that
186        // the data is restored to memory order (assuming the original source was in the same
187        // endianess as Tensor).
188        std::array::from_fn(|d| order[self.source_to_requested[d]])
189    }
190}
191
192/**
193 * Returns true if the dimensions are all the same length. For 0 or 1 dimensions trivially returns
194 * true. For 2 dimensions, this corresponds to a square matrix, and for 3 dimensions, a cube shaped
195 * tensor, and so on.
196 */
197pub fn is_square<const D: usize>(shape: &[(Dimension, usize); D]) -> bool {
198    if D > 1 {
199        let first = shape[0].1;
200        #[allow(clippy::needless_range_loop)]
201        for d in 1..D {
202            if shape[d].1 != first {
203                return false;
204            }
205        }
206        true
207    } else {
208        true
209    }
210}
211
212/**
213 * Returns just the dimension names of the shape, in the same order.
214 */
215pub fn names_of<const D: usize>(shape: &[(Dimension, usize); D]) -> [Dimension; D] {
216    shape.map(|(dimension, _length)| dimension)
217}
218
219pub(crate) fn has_duplicates(shape: &[(Dimension, usize)]) -> bool {
220    for i in 1..shape.len() {
221        let name = shape[i - 1].0;
222        if shape[i..].iter().any(|d| d.0 == name) {
223            return true;
224        }
225    }
226    false
227}
228
229pub(crate) fn has_duplicates_names(dimensions: &[Dimension]) -> bool {
230    for i in 1..dimensions.len() {
231        let name = dimensions[i - 1];
232        if dimensions[i..].contains(&name) {
233            return true;
234        }
235    }
236    false
237}
238
239pub(crate) fn has_duplicates_extra_names(dimensions: &[(usize, Dimension)]) -> bool {
240    for i in 1..dimensions.len() {
241        let name = dimensions[i - 1].1;
242        if dimensions[i..].iter().any(|&d| d.1 == name) {
243            return true;
244        }
245    }
246    false
247}
248
#[test]
fn test_dimension_mappings() {
    let shape = [("x", 0), ("y", 0), ("z", 0)];

    let mapping = DimensionMappings::new(&shape, &["x", "y", "z"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [0, 1, 2],
            requested_to_source: [0, 1, 2],
        },
        mapping
    );
    // Requesting the source order is the same as not remapping at all.
    assert_eq!(DimensionMappings::no_op_mapping(), mapping);
    assert_eq!(
        [("x", 0), ("y", 0), ("z", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    let mapping = DimensionMappings::new(&shape, &["z", "y", "x"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [2, 1, 0],
            requested_to_source: [2, 1, 0],
        },
        mapping
    );
    assert_eq!(
        [("z", 0), ("y", 0), ("x", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    let mapping = DimensionMappings::new(&shape, &["z", "x", "y"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [1, 2, 0],
            requested_to_source: [2, 0, 1],
        },
        mapping
    );
    assert_eq!(
        [("z", 0), ("x", 0), ("y", 0)],
        mapping.map_shape_to_requested(&shape),
    );
    // Indexes given in the requested order [z, x, y] come back in the source order [x, y, z].
    assert_eq!([1, 2, 3], mapping.map_dimensions_to_source(&[3, 1, 2]));
    assert_eq!(
        ["b", "c", "a"],
        mapping.map_linear_data_layout_to_transposed(&["a", "b", "c"]),
    );

    let mapping = DimensionMappings::new(&shape, &["x", "z", "y"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [0, 2, 1],
            requested_to_source: [0, 2, 1],
        },
        mapping
    );
    assert_eq!(
        [("x", 0), ("z", 0), ("y", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    let mapping = DimensionMappings::new(&shape, &["y", "z", "x"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [2, 0, 1],
            requested_to_source: [1, 2, 0],
        },
        mapping
    );
    assert_eq!(
        [("y", 0), ("z", 0), ("x", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    // The requested names must be exactly the source names in some order.
    assert_eq!(None, DimensionMappings::new(&shape, &["x", "y", "w"]));
    assert_eq!(None, DimensionMappings::new(&shape, &["x", "x", "z"]));
    assert_eq!(None, DimensionMappings::new(&shape, &["y", "x", "y"]));
}
318
#[test]
fn test_is_square() {
    assert!(is_square(&[]));
    assert!(is_square(&[("x", 1)]));
    assert!(is_square(&[("x", 1), ("y", 1)]));
    assert!(is_square(&[("x", 4), ("y", 4)]));
    assert!(!is_square(&[("x", 4), ("y", 3)]));
    assert!(is_square(&[("x", 3), ("y", 3), ("z", 3)]));
    assert!(!is_square(&[("x", 3), ("y", 4), ("z", 3)]));
    assert!(!is_square(&[("x", 3), ("y", 3), ("z", 4)]));
    // Only the lengths are compared, including degenerate zero length dimensions.
    assert!(is_square(&[("x", 0), ("y", 0)]));
    assert!(!is_square(&[("x", 0), ("y", 1)]));
}
329
#[test]
fn test_duplicate_names() {
    assert!(has_duplicates_names(&["a", "b", "b", "c"]));
    assert!(!has_duplicates_names(&["a", "b", "c", "d"]));
    assert!(has_duplicates_names(&["a", "b", "a", "c"]));
    assert!(has_duplicates_names(&["a", "a", "a", "a"]));
    assert!(has_duplicates_names(&["a", "b", "c", "c"]));
    assert!(!has_duplicates_names(&[]));
    assert!(!has_duplicates_names(&["a"]));

    // Shapes only compare names, repeated lengths are fine.
    assert!(has_duplicates(&[("a", 1), ("b", 2), ("a", 3)]));
    assert!(!has_duplicates(&[("a", 1), ("b", 1)]));
    assert!(!has_duplicates(&[]));

    // Only the names are compared, the paired usize values may repeat freely.
    assert!(has_duplicates_extra_names(&[(0, "a"), (1, "b"), (2, "a")]));
    assert!(!has_duplicates_extra_names(&[(0, "a"), (0, "b")]));
}