math/tensor/
tensor_shape.rs1use crate::tensor::{AxisIndex, Unitless};
2use num::ToPrimitive;
3use std::{collections::HashSet, iter::FromIterator};
4
5#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
11pub struct TensorShape {
12 pub dims_strides: Vec<(Unitless, Unitless)>,
13}
14
15impl TensorShape {
16 pub fn dims(&self) -> Vec<Unitless> {
17 self.dims_strides.iter().map(|(dim, _)| *dim).collect()
18 }
19
20 pub fn strides(&self) -> Vec<Unitless> {
21 self.dims_strides
22 .iter()
23 .map(|(_, stride)| *stride)
24 .collect()
25 }
26
27 pub fn ndim(&self) -> usize {
28 self.dims_strides.len()
29 }
30
31 pub fn num_elements(&self) -> usize {
32 if self.dims_strides.len() > 0 {
33 self.dims_strides
34 .iter()
35 .fold(1, |acc, &(d, _)| acc * d as usize)
36 } else {
37 0
38 }
39 }
40
41 pub fn to_transposed(&self, axes: Vec<AxisIndex>) -> TensorShape {
42 assert_eq!(
43 axes.len(),
44 self.dims_strides.len(),
45 "length of axes ({}) != length of dims_strides ({})",
46 axes.len(),
47 self.dims_strides.len()
48 );
49 assert_eq!(
50 HashSet::<AxisIndex>::from_iter(axes.clone().into_iter()).len(),
51 self.dims_strides.len(),
52 "all axes must be distinct"
53 );
54 let dims_strides =
55 axes.into_iter().map(|i| self.dims_strides[i]).collect();
56 TensorShape {
57 dims_strides,
58 }
59 }
60}
61
62pub trait HasTensorShape {
63 fn shape(&self) -> &TensorShape;
64}
65
/// Implements `From<$t>` for `TensorShape`, where `$t` is a sequence of
/// integer axis extents that offers `.iter()` (a `Vec`, an array, or a
/// reference to one).
///
/// The generated conversion keeps the extents in the given order and assigns
/// row-major strides: the last axis gets stride `1`, and each earlier axis
/// gets the product of all later extents.
///
/// The generated `from` panics when an extent does not fit in `i64`, or when
/// the running product of extents overflows `i64`. Overflow is checked
/// explicitly so release builds fail loudly instead of wrapping.
macro_rules! impl_from_for_tensor_shape {
    ($t:ty) => {
        impl From<$t> for TensorShape {
            fn from(shape: $t) -> Self {
                // Convert the extents once, up front, so the stride product
                // below works on checked `i64` values rather than `as` casts.
                let dims: Vec<Unitless> = shape
                    .iter()
                    .map(|len| {
                        len.to_i64()
                            .expect("tensor dimension does not fit in i64")
                    })
                    .collect();

                // Walk the axes from last to first, emitting the running
                // product of the extents seen so far.
                let mut strides: Vec<Unitless> = dims
                    .iter()
                    .rev()
                    .scan(1i64, |extent, &len| {
                        let stride = *extent;
                        *extent = extent
                            .checked_mul(len)
                            .expect("tensor stride computation overflowed i64");
                        Some(stride)
                    })
                    .collect();
                strides.reverse();

                TensorShape {
                    dims_strides: dims.into_iter().zip(strides).collect(),
                }
            }
        }
    };
}
92
93impl_from_for_tensor_shape!(Vec<i32>);
94impl_from_for_tensor_shape!(Vec<u32>);
95impl_from_for_tensor_shape!(Vec<i64>);
96impl_from_for_tensor_shape!(Vec<u64>);
97impl_from_for_tensor_shape!(Vec<isize>);
98impl_from_for_tensor_shape!(Vec<usize>);
99
100impl_from_for_tensor_shape!(&Vec<i32>);
101impl_from_for_tensor_shape!(&Vec<u32>);
102impl_from_for_tensor_shape!(&Vec<i64>);
103impl_from_for_tensor_shape!(&Vec<u64>);
104impl_from_for_tensor_shape!(&Vec<isize>);
105impl_from_for_tensor_shape!(&Vec<usize>);
106
107impl_from_for_tensor_shape!([i32; 1]);
109impl_from_for_tensor_shape!([i32; 2]);
110impl_from_for_tensor_shape!([i32; 3]);
111impl_from_for_tensor_shape!([i32; 4]);
112impl_from_for_tensor_shape!([i32; 5]);
113impl_from_for_tensor_shape!([i32; 6]);
114impl_from_for_tensor_shape!([i32; 7]);
115impl_from_for_tensor_shape!([i32; 8]);
116
117impl_from_for_tensor_shape!([u32; 1]);
118impl_from_for_tensor_shape!([u32; 2]);
119impl_from_for_tensor_shape!([u32; 3]);
120impl_from_for_tensor_shape!([u32; 4]);
121impl_from_for_tensor_shape!([u32; 5]);
122impl_from_for_tensor_shape!([u32; 6]);
123impl_from_for_tensor_shape!([u32; 7]);
124impl_from_for_tensor_shape!([u32; 8]);
125
126impl_from_for_tensor_shape!([i64; 1]);
127impl_from_for_tensor_shape!([i64; 2]);
128impl_from_for_tensor_shape!([i64; 3]);
129impl_from_for_tensor_shape!([i64; 4]);
130impl_from_for_tensor_shape!([i64; 5]);
131impl_from_for_tensor_shape!([i64; 6]);
132impl_from_for_tensor_shape!([i64; 7]);
133impl_from_for_tensor_shape!([i64; 8]);
134
135impl_from_for_tensor_shape!([u64; 1]);
136impl_from_for_tensor_shape!([u64; 2]);
137impl_from_for_tensor_shape!([u64; 3]);
138impl_from_for_tensor_shape!([u64; 4]);
139impl_from_for_tensor_shape!([u64; 5]);
140impl_from_for_tensor_shape!([u64; 6]);
141impl_from_for_tensor_shape!([u64; 7]);
142impl_from_for_tensor_shape!([u64; 8]);
143
144impl_from_for_tensor_shape!([isize; 1]);
145impl_from_for_tensor_shape!([isize; 2]);
146impl_from_for_tensor_shape!([isize; 3]);
147impl_from_for_tensor_shape!([isize; 4]);
148impl_from_for_tensor_shape!([isize; 5]);
149impl_from_for_tensor_shape!([isize; 6]);
150impl_from_for_tensor_shape!([isize; 7]);
151impl_from_for_tensor_shape!([isize; 8]);
152
153impl_from_for_tensor_shape!([usize; 1]);
154impl_from_for_tensor_shape!([usize; 2]);
155impl_from_for_tensor_shape!([usize; 3]);
156impl_from_for_tensor_shape!([usize; 4]);
157impl_from_for_tensor_shape!([usize; 5]);
158impl_from_for_tensor_shape!([usize; 6]);
159impl_from_for_tensor_shape!([usize; 7]);
160impl_from_for_tensor_shape!([usize; 8]);
161
162impl_from_for_tensor_shape!(&[isize; 1]);
163impl_from_for_tensor_shape!(&[isize; 2]);
164impl_from_for_tensor_shape!(&[isize; 3]);
165impl_from_for_tensor_shape!(&[isize; 4]);
166impl_from_for_tensor_shape!(&[isize; 5]);
167impl_from_for_tensor_shape!(&[isize; 6]);
168impl_from_for_tensor_shape!(&[isize; 7]);
169impl_from_for_tensor_shape!(&[isize; 8]);
170
171impl_from_for_tensor_shape!(&[usize; 1]);
172impl_from_for_tensor_shape!(&[usize; 2]);
173impl_from_for_tensor_shape!(&[usize; 3]);
174impl_from_for_tensor_shape!(&[usize; 4]);
175impl_from_for_tensor_shape!(&[usize; 5]);
176impl_from_for_tensor_shape!(&[usize; 6]);
177impl_from_for_tensor_shape!(&[usize; 7]);
178impl_from_for_tensor_shape!(&[usize; 8]);
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Unitless;

    /// Converts `input` into a `TensorShape` and checks that it carries the
    /// `(dim, stride)` pairs expected for extents `[3, 2, 5]`.
    fn assert_3_2_5_layout<S>(input: S)
    where
        TensorShape: From<S>,
    {
        let expected: Vec<(Unitless, Unitless)> = vec![(3, 10), (2, 5), (5, 1)];
        assert_eq!(TensorShape::from(input).dims_strides, expected);
    }

    #[test]
    fn test_tensor_shape() {
        // Three axes: extents are kept in order, last axis has stride 1.
        let cube = TensorShape::from([2, 4, 3]);
        assert_eq!(cube.dims(), vec![2, 4, 3]);
        assert_eq!(cube.strides(), vec![12, 3, 1]);
        assert_eq!(cube.ndim(), 3);

        // No axes at all: every accessor reports emptiness.
        let rank_zero = TensorShape::from(Vec::<Unitless>::new());
        assert_eq!(rank_zero.dims(), Vec::<Unitless>::new());
        assert_eq!(rank_zero.strides(), Vec::<Unitless>::new());
        assert_eq!(rank_zero.ndim(), 0);
    }

    #[test]
    fn test_tensor_shape_from_trait() {
        // Owned vectors.
        assert_3_2_5_layout(vec![3i32, 2, 5]);
        assert_3_2_5_layout(vec![3u32, 2, 5]);
        assert_3_2_5_layout(vec![3i64, 2, 5]);
        assert_3_2_5_layout(vec![3u64, 2, 5]);
        assert_3_2_5_layout(vec![3isize, 2, 5]);
        assert_3_2_5_layout(vec![3usize, 2, 5]);

        // Borrowed vectors.
        assert_3_2_5_layout(&vec![3i32, 2, 5]);
        assert_3_2_5_layout(&vec![3u32, 2, 5]);
        assert_3_2_5_layout(&vec![3i64, 2, 5]);
        assert_3_2_5_layout(&vec![3u64, 2, 5]);
        assert_3_2_5_layout(&vec![3isize, 2, 5]);
        assert_3_2_5_layout(&vec![3usize, 2, 5]);

        // Owned arrays.
        assert_3_2_5_layout([3i32, 2, 5]);
        assert_3_2_5_layout([3u32, 2, 5]);
        assert_3_2_5_layout([3i64, 2, 5]);
        assert_3_2_5_layout([3u64, 2, 5]);
        assert_3_2_5_layout([3isize, 2, 5]);
        assert_3_2_5_layout([3usize, 2, 5]);

        // Borrowed arrays.
        assert_3_2_5_layout(&[3isize, 2, 5]);
        assert_3_2_5_layout(&[3usize, 2, 5]);
    }
}