easy_ml/tensors/dimensions.rs
/*!
 * Utilities to manipulate Dimensions.
 *
 * Contains a number of utility functions for manipulating the dimensions of tensors.
 *
 * # Terminology
 *
 * Tensors in Easy ML have a **shape** of type `[(Dimension, usize); D]`, where D is a compile time
 * constant. This type defines a list of dimension names and the lengths along each dimension name.
 * A tensor with some shape will have data ranging from 0 to the length - 1 along each
 * dimension name.
 * Often we want to call methods which only take dimension names in some order, so those
 * **dimension names** have a type of `[Dimension; D]`. We also may want to index into a Tensor,
 * which is done by providing the **indexes** only, with a type of `[usize; D]`.
 *
 * Dimensions and dimension names in Easy ML APIs are treated like lists: the order of
 * each dimension does make a difference for equality definitions and mathematical operations,
 * and can be a factor in the order of iteration and indexing. However, many high level APIs
 * that are not directly involved with order or indexing require only a dimension name, and these
 * are usually less concerned with the order of the dimensions.
 */
22
23use crate::tensors::Dimension;
24
25/**
26 * Returns the product of the dimension lengths in the provided shape.
27 *
28 * This is equal to the number of elements that will be stored for these dimensions.
29 * A 0 dimensional tensor stores exactly 1 element, a 1 dimensional tensor stores N elements,
30 * a 2 dimensional tensor stores NxM elements and so on.
31 */
32pub fn elements<const D: usize>(shape: &[(Dimension, usize); D]) -> usize {
33 shape.iter().map(|d| d.1).product()
34}
35
36/**
37 * Finds the position of the dimension name in the shape.
38 *
39 * `None` is returned if the dimension name is not in the shape.
40 */
41pub fn position_of<const D: usize>(
42 shape: &[(Dimension, usize); D],
43 dimension: Dimension,
44) -> Option<usize> {
45 shape.iter().position(|(d, _)| d == &dimension)
46}
47
48/**
49 * Checks if the dimension name is in the shape.
50 */
51pub fn contains<const D: usize>(shape: &[(Dimension, usize); D], dimension: Dimension) -> bool {
52 shape.iter().any(|(d, _)| d == &dimension)
53}
54
55/**
56 * Returns the length of the dimension name provided, if one is present in the shape.
57 */
58pub fn length_of<const D: usize>(
59 shape: &[(Dimension, usize); D],
60 dimension: Dimension,
61) -> Option<usize> {
62 shape
63 .iter()
64 .find(|(d, _)| *d == dimension)
65 .map(|(_, length)| *length)
66}
67
68/**
69 * Returns the last index of the dimension name provided, if one is present in the shape.
70 *
71 * This is always 1 less than the length, the 'index' in this sense is based on what the
72 * shape is, not any implementation index. If for some reason a shape with a 0
73 * length was given, the last index will saturate at 0.
74 */
75pub fn last_index_of<const D: usize>(
76 shape: &[(Dimension, usize); D],
77 dimension: Dimension,
78) -> Option<usize> {
79 length_of(shape, dimension).map(|length| length.saturating_sub(1))
80}
81
82#[derive(Debug, Clone, Eq, PartialEq)]
83pub(crate) struct DimensionMappings<const D: usize> {
84 source_to_requested: [usize; D],
85 requested_to_source: [usize; D],
86}
87
88impl<const D: usize> DimensionMappings<D> {
89 pub(crate) fn no_op_mapping() -> DimensionMappings<D> {
90 DimensionMappings {
91 source_to_requested: std::array::from_fn(|d| d),
92 requested_to_source: std::array::from_fn(|d| d),
93 }
94 }
95
96 // Computes both mappings from from a shape in source order and a matching set of
97 // dimensions in an arbitary order.
98 // If the source order is x,y,z but the requested order is z,x,y then the mapping
99 // from source to requested is [1,2,0] (x becomes second, y becomes last, z becomes first) and
100 // from requested to source is [2,0,1] (z becones last, x becomes first, y becomes second).
101 pub(crate) fn new(
102 source: &[(Dimension, usize); D],
103 requested: &[Dimension; D],
104 ) -> Option<DimensionMappings<D>> {
105 let mut source_to_requested = [0; D];
106 let mut requested_to_source = [0; D];
107 for d in 0..D {
108 let dimension = source[d].0;
109 // happy path, requested dimension is in the same order as in source order
110 if requested[d] == dimension {
111 source_to_requested[d] = d;
112 requested_to_source[d] = d;
113 } else {
114 // If dimensions are in a different order, find the dimension with the
115 // matching dimension name for both mappings at this position in the order.
116 // Since both lists are the same length and we know our source order won't contain
117 // duplicates this also ensures the two lists have exactly the same set of names
118 // as otherwise one of these `find`s will fail.
119 let (n_in_requested, _) = requested
120 .iter()
121 .enumerate()
122 .find(|(_, d)| **d == dimension)?;
123 source_to_requested[d] = n_in_requested;
124 let dimension = requested[d];
125 let (n_in_source, _) = source
126 .iter()
127 .enumerate()
128 .find(|(_, (d, _))| *d == dimension)?;
129 requested_to_source[d] = n_in_source;
130 };
131 }
132 Some(DimensionMappings {
133 source_to_requested,
134 requested_to_source,
135 })
136 }
137
138 /// Reorders some indexes according to the dimension mapping to return the
139 /// indexes in the source order
140 #[inline]
141 pub(crate) fn map_dimensions_to_source(&self, indexes: &[usize; D]) -> [usize; D] {
142 // Our input is in requested order and we return indexes in the source order, so for each
143 // dimension to return (in source order) we're looking up which index from the input to
144 // use, just like for map_linear_data_layout_to_transposed.
145 std::array::from_fn(|d| indexes[self.source_to_requested[d]])
146 }
147
148 /// Reorders some shape according to the dimension mapping to return the
149 /// shape in the requested order
150 #[inline]
151 pub(crate) fn map_shape_to_requested(
152 &self,
153 source: &[(Dimension, usize); D],
154 ) -> [(Dimension, usize); D] {
155 // For each d we're returning, we're giving what the requested dth dimension is
156 // in the source, so we use requested_to_source for mapping. This is different to indexing
157 // or map_linear_data_layout_to_transposed because our output here is in requested order,
158 // not our input.
159 std::array::from_fn(|d| source[self.requested_to_source[d]])
160 }
161
162 #[inline]
163 pub(crate) fn map_linear_data_layout_to_transposed(
164 &self,
165 order: &[Dimension; D],
166 ) -> [Dimension; D] {
167 // In most simple cases the transformation for source -> requested is the same as
168 // requested -> source so the data layout maps the same way as the shape. However,
169 // in most complex 3D or higher cases, we can transpose a tensor in such a way that
170 // the data_layout does not map the same way. To conceptualise this, consider we first
171 // just used a TensorAccess to swap the indexes from ["batch", "row", "column"] to
172 // ["row", "column", "batch"]. This is a mapping of [0->2, 1->0, 2->1] (or if we only
173 // want to write where each dimension ends up, we can call this a mapping of [2, 0, 1]).
174 // TensorAccess doesn't change what each dimension refers to in memory, so the data_layout
175 // stays as ["batch", "row", "column"]. We can then wrap our TensorAccess in a TensorRename
176 // to emulate the behavior of a TensorTranspose. We rename ["row", "column", "batch"]
177 // to ["batch", "row", "column"] again, retaining our swapped dimensions as they are.
178 // Now row in our data_layout is called batch, column becomes row, and batch becomes
179 // column. Hence our data layout becames ["column", "batch", "row"]. If we compare our
180 // data layout to what we started with (["batch", "row", "column"]) we see that we have
181 // mapped it as [1, 2, 0], not [2, 0, 1]. As a more general explanation, the data layout
182 // we return here is mapping from what the data layout of the source is to the order
183 // we're now requesting in, so we use the reverse mapping to what we use for the
184 // view_shape. This ensures the data layout we return is correct, which can always be
185 // sanity checked by constructing a TensorAccess from what we return and verifying that
186 // the data is restored to memory order (assuming the original source was in the same
187 // endianess as Tensor).
188 std::array::from_fn(|d| order[self.source_to_requested[d]])
189 }
190}
191
192/**
193 * Returns true if the dimensions are all the same length. For 0 or 1 dimensions trivially returns
194 * true. For 2 dimensions, this corresponds to a square matrix, and for 3 dimensions, a cube shaped
195 * tensor, and so on.
196 */
197pub fn is_square<const D: usize>(shape: &[(Dimension, usize); D]) -> bool {
198 if D > 1 {
199 let first = shape[0].1;
200 #[allow(clippy::needless_range_loop)]
201 for d in 1..D {
202 if shape[d].1 != first {
203 return false;
204 }
205 }
206 true
207 } else {
208 true
209 }
210}
211
212/**
213 * Returns just the dimension names of the shape, in the same order.
214 */
215pub fn names_of<const D: usize>(shape: &[(Dimension, usize); D]) -> [Dimension; D] {
216 shape.map(|(dimension, _length)| dimension)
217}
218
219pub(crate) fn has_duplicates(shape: &[(Dimension, usize)]) -> bool {
220 for i in 1..shape.len() {
221 let name = shape[i - 1].0;
222 if shape[i..].iter().any(|d| d.0 == name) {
223 return true;
224 }
225 }
226 false
227}
228
229pub(crate) fn has_duplicates_names(dimensions: &[Dimension]) -> bool {
230 for i in 1..dimensions.len() {
231 let name = dimensions[i - 1];
232 if dimensions[i..].contains(&name) {
233 return true;
234 }
235 }
236 false
237}
238
239pub(crate) fn has_duplicates_extra_names(dimensions: &[(usize, Dimension)]) -> bool {
240 for i in 1..dimensions.len() {
241 let name = dimensions[i - 1].1;
242 if dimensions[i..].iter().any(|&d| d.1 == name) {
243 return true;
244 }
245 }
246 false
247}
248
#[test]
fn test_dimension_mappings() {
    let shape = [("x", 0), ("y", 0), ("z", 0)];

    let mapping = DimensionMappings::new(&shape, &["x", "y", "z"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [0, 1, 2],
            requested_to_source: [0, 1, 2],
        },
        mapping
    );
    assert_eq!(
        [("x", 0), ("y", 0), ("z", 0)],
        mapping.map_shape_to_requested(&shape),
    );
    // Requesting the source order is the same as the no-op mapping.
    assert_eq!(DimensionMappings::no_op_mapping(), mapping);

    let mapping = DimensionMappings::new(&shape, &["z", "y", "x"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [2, 1, 0],
            requested_to_source: [2, 1, 0],
        },
        mapping
    );
    assert_eq!(
        [("z", 0), ("y", 0), ("x", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    let mapping = DimensionMappings::new(&shape, &["z", "x", "y"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [1, 2, 0],
            requested_to_source: [2, 0, 1],
        },
        mapping
    );
    assert_eq!(
        [("z", 0), ("x", 0), ("y", 0)],
        mapping.map_shape_to_requested(&shape),
    );
    // Indexes given in requested order z,x,y come back in source order x,y,z.
    assert_eq!([0, 1, 3], mapping.map_dimensions_to_source(&[3, 0, 1]));

    let mapping = DimensionMappings::new(&shape, &["x", "z", "y"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [0, 2, 1],
            requested_to_source: [0, 2, 1],
        },
        mapping
    );
    assert_eq!(
        [("x", 0), ("z", 0), ("y", 0)],
        mapping.map_shape_to_requested(&shape),
    );

    let mapping = DimensionMappings::new(&shape, &["y", "z", "x"]).unwrap();
    assert_eq!(
        DimensionMappings {
            source_to_requested: [2, 0, 1],
            requested_to_source: [1, 2, 0],
        },
        mapping
    );
    assert_eq!(
        [("y", 0), ("z", 0), ("x", 0)],
        mapping.map_shape_to_requested(&shape),
    );
    // The data layout uses the reverse mapping to the shape.
    assert_eq!(
        ["z", "x", "y"],
        mapping.map_linear_data_layout_to_transposed(&["x", "y", "z"]),
    );

    // Requests which are not a reordering of the source dimension names have no mapping.
    assert_eq!(None, DimensionMappings::new(&shape, &["x", "y", "w"]));
    assert_eq!(None, DimensionMappings::new(&shape, &["x", "x", "z"]));
    assert_eq!(None, DimensionMappings::new(&shape, &["z", "z", "z"]));
}
318
#[test]
fn test_is_square() {
    assert!(is_square(&[]));
    assert!(is_square(&[("x", 1)]));
    assert!(is_square(&[("x", 1), ("y", 1)]));
    assert!(is_square(&[("x", 4), ("y", 4)]));
    assert!(!is_square(&[("x", 4), ("y", 3)]));
    assert!(is_square(&[("x", 3), ("y", 3), ("z", 3)]));
    assert!(!is_square(&[("x", 3), ("y", 4), ("z", 3)]));
    // Only the last dimension differs.
    assert!(!is_square(&[("x", 3), ("y", 3), ("z", 4)]));
}
329
#[test]
fn test_duplicate_names() {
    assert!(has_duplicates_names(&["a", "b", "b", "c"]));
    assert!(!has_duplicates_names(&["a", "b", "c", "d"]));
    assert!(has_duplicates_names(&["a", "b", "a", "c"]));
    assert!(has_duplicates_names(&["a", "a", "a", "a"]));
    assert!(has_duplicates_names(&["a", "b", "c", "c"]));
    assert!(!has_duplicates_names(&["a"]));
    assert!(!has_duplicates_names(&[]));

    // The shape and (index, name) variants only compare the names.
    assert!(has_duplicates(&[("a", 1), ("b", 2), ("a", 3)]));
    assert!(!has_duplicates(&[("a", 1), ("b", 1)]));
    assert!(has_duplicates_extra_names(&[(0, "a"), (1, "b"), (2, "a")]));
    assert!(!has_duplicates_extra_names(&[(0, "a"), (0, "b")]));
}