math/tensor/
tensor_shape.rs

1use crate::tensor::{AxisIndex, Unitless};
2use num::ToPrimitive;
3use std::{collections::HashSet, iter::FromIterator};
4
5/// The shape of an N-dimensional tensor has a size for each dimension, with an
6/// associated stride, e.g., a row-major 3 x 5 matrix will have a stride of 5
7/// for the dimension of size 3 and a stride of 1 for the dimension of size 5,
8/// and the resulting `dims_strides` is `[(3, 5), (5, 1)]`. Index 0 of
9/// `dims_strides` always refers to the leftmost dimension.
10#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
11pub struct TensorShape {
12    pub dims_strides: Vec<(Unitless, Unitless)>,
13}
14
15impl TensorShape {
16    pub fn dims(&self) -> Vec<Unitless> {
17        self.dims_strides.iter().map(|(dim, _)| *dim).collect()
18    }
19
20    pub fn strides(&self) -> Vec<Unitless> {
21        self.dims_strides
22            .iter()
23            .map(|(_, stride)| *stride)
24            .collect()
25    }
26
27    pub fn ndim(&self) -> usize {
28        self.dims_strides.len()
29    }
30
31    pub fn num_elements(&self) -> usize {
32        if self.dims_strides.len() > 0 {
33            self.dims_strides
34                .iter()
35                .fold(1, |acc, &(d, _)| acc * d as usize)
36        } else {
37            0
38        }
39    }
40
41    pub fn to_transposed(&self, axes: Vec<AxisIndex>) -> TensorShape {
42        assert_eq!(
43            axes.len(),
44            self.dims_strides.len(),
45            "length of axes ({}) != length of dims_strides ({})",
46            axes.len(),
47            self.dims_strides.len()
48        );
49        assert_eq!(
50            HashSet::<AxisIndex>::from_iter(axes.clone().into_iter()).len(),
51            self.dims_strides.len(),
52            "all axes must be distinct"
53        );
54        let dims_strides =
55            axes.into_iter().map(|i| self.dims_strides[i]).collect();
56        TensorShape {
57            dims_strides,
58        }
59    }
60}
61
62pub trait HasTensorShape {
63    fn shape(&self) -> &TensorShape;
64}
65
/// Implements `From<$t> for TensorShape`, where `$t` is a container of
/// integers that exposes `.iter()` (a `Vec`, an array, or a reference to
/// either) and whose elements provide `to_i64()`.
///
/// The resulting shape is row-major: the rightmost dimension has stride 1,
/// and every other stride is the stride to its right times the size to its
/// right.
///
/// # Panics
///
/// The generated `from` panics if a dimension does not fit in an `i64`, or if
/// a stride overflows `i64`.
macro_rules! impl_from_for_tensor_shape {
    ($t:ty) => {
        impl From<$t> for TensorShape {
            fn from(shape: $t) -> Self {
                // Convert each dimension exactly once, with a checked
                // conversion, and derive the strides from these same values.
                let dims: Vec<Unitless> = shape
                    .iter()
                    .map(|d| d.to_i64().expect("tensor dimension does not fit in i64"))
                    .collect();

                // Default to row-major order: walk right-to-left, starting
                // from stride 1 for the rightmost dimension. The multiply is
                // checked so an oversized shape panics instead of silently
                // wrapping in release builds.
                let mut strides: Vec<Unitless> = vec![1; dims.len()];
                for i in (1..dims.len()).rev() {
                    strides[i - 1] = strides[i]
                        .checked_mul(dims[i])
                        .expect("tensor stride overflows i64");
                }

                TensorShape {
                    dims_strides: dims.into_iter().zip(strides).collect(),
                }
            }
        }
    };
}

93impl_from_for_tensor_shape!(Vec<i32>);
94impl_from_for_tensor_shape!(Vec<u32>);
95impl_from_for_tensor_shape!(Vec<i64>);
96impl_from_for_tensor_shape!(Vec<u64>);
97impl_from_for_tensor_shape!(Vec<isize>);
98impl_from_for_tensor_shape!(Vec<usize>);
99
100impl_from_for_tensor_shape!(&Vec<i32>);
101impl_from_for_tensor_shape!(&Vec<u32>);
102impl_from_for_tensor_shape!(&Vec<i64>);
103impl_from_for_tensor_shape!(&Vec<u64>);
104impl_from_for_tensor_shape!(&Vec<isize>);
105impl_from_for_tensor_shape!(&Vec<usize>);
106
107// implementing for small fixed-size arrays for ergonomic reasons
108impl_from_for_tensor_shape!([i32; 1]);
109impl_from_for_tensor_shape!([i32; 2]);
110impl_from_for_tensor_shape!([i32; 3]);
111impl_from_for_tensor_shape!([i32; 4]);
112impl_from_for_tensor_shape!([i32; 5]);
113impl_from_for_tensor_shape!([i32; 6]);
114impl_from_for_tensor_shape!([i32; 7]);
115impl_from_for_tensor_shape!([i32; 8]);
116
117impl_from_for_tensor_shape!([u32; 1]);
118impl_from_for_tensor_shape!([u32; 2]);
119impl_from_for_tensor_shape!([u32; 3]);
120impl_from_for_tensor_shape!([u32; 4]);
121impl_from_for_tensor_shape!([u32; 5]);
122impl_from_for_tensor_shape!([u32; 6]);
123impl_from_for_tensor_shape!([u32; 7]);
124impl_from_for_tensor_shape!([u32; 8]);
125
126impl_from_for_tensor_shape!([i64; 1]);
127impl_from_for_tensor_shape!([i64; 2]);
128impl_from_for_tensor_shape!([i64; 3]);
129impl_from_for_tensor_shape!([i64; 4]);
130impl_from_for_tensor_shape!([i64; 5]);
131impl_from_for_tensor_shape!([i64; 6]);
132impl_from_for_tensor_shape!([i64; 7]);
133impl_from_for_tensor_shape!([i64; 8]);
134
135impl_from_for_tensor_shape!([u64; 1]);
136impl_from_for_tensor_shape!([u64; 2]);
137impl_from_for_tensor_shape!([u64; 3]);
138impl_from_for_tensor_shape!([u64; 4]);
139impl_from_for_tensor_shape!([u64; 5]);
140impl_from_for_tensor_shape!([u64; 6]);
141impl_from_for_tensor_shape!([u64; 7]);
142impl_from_for_tensor_shape!([u64; 8]);
143
144impl_from_for_tensor_shape!([isize; 1]);
145impl_from_for_tensor_shape!([isize; 2]);
146impl_from_for_tensor_shape!([isize; 3]);
147impl_from_for_tensor_shape!([isize; 4]);
148impl_from_for_tensor_shape!([isize; 5]);
149impl_from_for_tensor_shape!([isize; 6]);
150impl_from_for_tensor_shape!([isize; 7]);
151impl_from_for_tensor_shape!([isize; 8]);
152
153impl_from_for_tensor_shape!([usize; 1]);
154impl_from_for_tensor_shape!([usize; 2]);
155impl_from_for_tensor_shape!([usize; 3]);
156impl_from_for_tensor_shape!([usize; 4]);
157impl_from_for_tensor_shape!([usize; 5]);
158impl_from_for_tensor_shape!([usize; 6]);
159impl_from_for_tensor_shape!([usize; 7]);
160impl_from_for_tensor_shape!([usize; 8]);
161
162impl_from_for_tensor_shape!(&[isize; 1]);
163impl_from_for_tensor_shape!(&[isize; 2]);
164impl_from_for_tensor_shape!(&[isize; 3]);
165impl_from_for_tensor_shape!(&[isize; 4]);
166impl_from_for_tensor_shape!(&[isize; 5]);
167impl_from_for_tensor_shape!(&[isize; 6]);
168impl_from_for_tensor_shape!(&[isize; 7]);
169impl_from_for_tensor_shape!(&[isize; 8]);
170
171impl_from_for_tensor_shape!(&[usize; 1]);
172impl_from_for_tensor_shape!(&[usize; 2]);
173impl_from_for_tensor_shape!(&[usize; 3]);
174impl_from_for_tensor_shape!(&[usize; 4]);
175impl_from_for_tensor_shape!(&[usize; 5]);
176impl_from_for_tensor_shape!(&[usize; 6]);
177impl_from_for_tensor_shape!(&[usize; 7]);
178impl_from_for_tensor_shape!(&[usize; 8]);
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Unitless;

    #[test]
    fn test_tensor_shape() {
        {
            let shape = TensorShape::from([2, 4, 3]);
            assert_eq!(shape.dims(), vec![2, 4, 3]);
            assert_eq!(shape.strides(), vec![12, 3, 1]);
            assert_eq!(shape.ndim(), 3);
        }
        {
            let empty_shape = TensorShape::from(Vec::<Unitless>::new());
            assert_eq!(empty_shape.dims(), vec![]);
            assert_eq!(empty_shape.strides(), vec![]);
            assert_eq!(empty_shape.ndim(), 0);
        }
    }

    #[test]
    fn test_num_elements() {
        assert_eq!(TensorShape::from([2, 4, 3]).num_elements(), 24);
        assert_eq!(TensorShape::from([7]).num_elements(), 7);
    }

    #[test]
    fn test_to_transposed() {
        let shape = TensorShape::from([2, 4, 3]);
        let transposed = shape.to_transposed(vec![2, 0, 1]);
        // Sizes and strides move together.
        assert_eq!(transposed.dims_strides, vec![(3, 1), (2, 12), (4, 3)]);
        // Applying the inverse permutation restores the original shape.
        assert_eq!(transposed.to_transposed(vec![1, 2, 0]), shape);
    }

    #[test]
    #[should_panic]
    fn test_to_transposed_rejects_wrong_axis_count() {
        let _ = TensorShape::from([2, 4]).to_transposed(vec![0]);
    }

    #[test]
    #[should_panic]
    fn test_to_transposed_rejects_repeated_axes() {
        let _ = TensorShape::from([2, 4]).to_transposed(vec![0, 0]);
    }

    #[test]
    #[should_panic]
    fn test_to_transposed_rejects_out_of_range_axis() {
        let _ = TensorShape::from([2, 4]).to_transposed(vec![0, 5]);
    }

    #[test]
    fn test_tensor_shape_from_trait() {
        macro_rules! check_from_iter {
            ($iter:expr) => {
                let tensor_shape = TensorShape::from($iter);
                assert_eq!(tensor_shape.dims_strides, vec![
                    (3, 10),
                    (2, 5),
                    (5, 1)
                ]);
            };
        }
        check_from_iter!(vec![3i32, 2, 5]);
        check_from_iter!(vec![3u32, 2, 5]);
        check_from_iter!(vec![3i64, 2, 5]);
        check_from_iter!(vec![3u64, 2, 5]);
        check_from_iter!(vec![3isize, 2, 5]);
        check_from_iter!(vec![3usize, 2, 5]);
        check_from_iter!(&vec![3i32, 2, 5]);
        check_from_iter!(&vec![3u32, 2, 5]);
        check_from_iter!(&vec![3i64, 2, 5]);
        check_from_iter!(&vec![3u64, 2, 5]);
        check_from_iter!(&vec![3isize, 2, 5]);
        check_from_iter!(&vec![3usize, 2, 5]);
        check_from_iter!([3i32, 2, 5]);
        check_from_iter!([3u32, 2, 5]);
        check_from_iter!([3i64, 2, 5]);
        check_from_iter!([3u64, 2, 5]);
        check_from_iter!([3isize, 2, 5]);
        check_from_iter!([3usize, 2, 5]);
        check_from_iter!(&[3isize, 2, 5]);
        check_from_iter!(&[3usize, 2, 5]);
    }
}