1pub mod broadcast;
4pub mod expand_dims;
5pub mod flip;
6pub mod into_dim;
7pub mod reshape;
8pub mod reshape_assume_contig;
9pub mod squeeze;
10pub mod to_contig;
11pub mod to_layout;
12pub mod transpose;
13
14pub mod exports {
15 use super::*;
16
17 pub use broadcast::*;
18 pub use expand_dims::*;
19 pub use flip::*;
20 pub use into_dim::*;
21 pub use reshape::*;
22 pub use reshape_assume_contig::*;
23 pub use squeeze::*;
24 pub use to_contig::*;
25 pub use to_layout::*;
26 pub use transpose::*;
27}
28
29#[cfg(test)]
30mod test_reshape {
31 use crate::prelude_dev::*;
32
33 #[test]
34 fn test_playground() {
35 #[cfg(not(feature = "col_major"))]
36 {
37 let a1 = linspace((1.0, 24.0, 24));
38 let a2 = a1.to_shape([2, 3, 4]);
39 let default_order = a1.device().default_order();
40 println!("{a2:?}");
41 println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));
42
43 let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
44 println!("{v:?}");
45
46 let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
47 let b2 = b1.to_shape([24]);
48 println!("{b2:?}");
49 println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));
50
51 let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
52 println!("{v:?}");
53 }
54 #[cfg(feature = "col_major")]
55 {
56 let a1 = linspace((1.0, 24.0, 24));
57 let a2 = a1.to_shape([2, 3, 4]);
58 let default_order = a1.device().default_order();
59 println!("{a2:?}");
60 println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));
61 println!("a2[:, :, 0] =\n{:}", a2.i((.., .., 0)));
62 println!("a2[:, :, 1] =\n{:}", a2.i((.., .., 1)));
63 println!("a2[:, :, 2] =\n{:}", a2.i((.., .., 2)));
64 println!("a2[:, :, 3] =\n{:}", a2.i((.., .., 3)));
65
66 let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
67 println!("{v:?}");
68
69 let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
70 let b2 = b1.to_shape([24]);
71 println!("{b2:?}");
72 println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));
73
74 let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
75 println!("{v:?}");
76 }
77 }
78
79 #[test]
80 fn test_contig() {
81 #[cfg(not(feature = "col_major"))]
82 {
83 let layout_in = vec![2, 3, 4].c();
84 let default_order = RowMajor;
85 let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
86 assert_eq!(layout_out.unwrap(), vec![2, 3, 4].c());
87
88 let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
89 assert_eq!(layout_out.unwrap(), vec![3, 2, 4].c());
90
91 let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
92 assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].c());
93 }
94 #[cfg(feature = "col_major")]
95 {
96 let layout_in = vec![2, 3, 4].f();
97 let default_order = ColMajor;
98 let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
99 assert_eq!(layout_out.unwrap(), vec![2, 3, 4].f());
100
101 let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
102 assert_eq!(layout_out.unwrap(), vec![3, 2, 4].f());
103
104 let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
105 assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].f());
106 }
107 }
108
109 #[test]
110 fn test_partial_contig() {
111 #[cfg(not(feature = "col_major"))]
112 {
113 let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 3], 810).unwrap();
116 let default_order = RowMajor;
117
118 let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
119 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![15, 9, 2, 3]);
120 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![162, 18, 9, 3]);
121 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
122
123 let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
124 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![10, 27, 3]);
125 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![243, 9, 3]);
126 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
127
128 let layout_out = layout_reshapeable(&layout_in, &vec![1, 10, 1, 27, 3], default_order).unwrap();
130 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![1, 10, 1, 27, 3]);
131 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![2430, 243, 243, 9, 3]);
133
134 let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 2], 813).unwrap();
137
138 let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
139 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![15, 9, 2, 3]);
140 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![162, 18, 6, 2]);
141 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
142
143 let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
144 assert!(layout_out.is_none());
145 }
146 #[cfg(feature = "col_major")]
147 {
148 let layout_in = Layout::new(vec![6, 15, 9], vec![3, 18, 270], 810).unwrap();
149 let default_order = ColMajor;
150
151 let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap();
152 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 2, 9, 15]);
153 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 18, 162]);
154 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
155
156 let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 10], default_order).unwrap();
157 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 27, 10]);
158 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 243]);
159 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
160
161 let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 1, 10, 1], default_order).unwrap();
163 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 27, 1, 10, 1]);
164 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 243, 243, 2430]);
166
167 let layout_in = Layout::new(vec![6, 15, 9], vec![2, 18, 270], 813).unwrap();
170
171 let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap();
172 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 2, 9, 15]);
173 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![2, 6, 18, 162]);
174 assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());
175
176 let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
177 assert!(layout_out.is_none());
178 }
179 }
180
181 #[test]
182 fn test_minus_stride() {
183 #[cfg(not(feature = "col_major"))]
184 {
185 let layout_in = Layout::new(vec![9, 15, 6], vec![270, -18, -3], 1079).unwrap();
188 let default_order = RowMajor;
189
190 let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
191 assert!(layout_out.is_none());
192
193 let layout_out = layout_reshapeable(&layout_in, &vec![3, 3, 10, 9], default_order).unwrap();
194 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 3, 10, 9]);
195 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![810, 270, -27, -3]);
196 }
197 }
198
199 #[test]
200 fn test_broadcast_reshape() {
201 #[cfg(not(feature = "col_major"))]
202 {
203 let layout_in = unsafe { Layout::new_unchecked(vec![12, 16, 15, 18], vec![270, 0, 18, 1], 0) };
206 let default_order = RowMajor;
207
208 let layout_out = layout_reshapeable(&layout_in, &vec![4, 3, 4, 4, 9, 1, 30], default_order).unwrap();
209 assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![4, 3, 4, 4, 9, 1, 30]);
210 assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![810, 270, 0, 0, 30, 30, 1]);
211
212 let layout_out = layout_reshapeable(&layout_in, &vec![16, 12, 15, 18], default_order).unwrap();
213 assert!(layout_out.is_none());
214 }
215 }
216}
217
218#[cfg(test)]
219mod tests {
220 use crate::prelude_dev::*;
221
222 #[test]
223 fn test_to_shape_assume_contig() {
224 let a = linspace((2.5, 3.2, 16));
225 let b = a.to_shape_assume_contig_f([4, 4]).unwrap();
226 println!("{b:.3?}");
227 }
228
229 #[test]
230 fn test_expand_dims() {
231 let a: Tensor<f64, _> = zeros([4, 9, 8]);
232 let b = a.expand_dims(2);
233 assert_eq!(b.shape(), &[4, 9, 1, 8]);
234 let b = a.expand_dims([1, 3]);
235 assert_eq!(b.shape(), &[4, 1, 9, 1, 8]);
236 let b = a.expand_dims([1, -1]);
237 assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
238 let b = a.expand_dims([-1, -4, 1, 0]);
239 assert_eq!(b.shape(), &[1, 1, 4, 1, 9, 8, 1]);
240 }
241
242 #[test]
243 fn test_squeeze() {
244 let a: Tensor<f64, _> = zeros([4, 1, 9, 1, 8, 1]);
245 let b = a.squeeze(3);
246 assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
247 let b = a.squeeze([1, 3]);
248 assert_eq!(b.shape(), &[4, 9, 8, 1]);
249 let b = a.squeeze([1, -1]);
250 assert_eq!(b.shape(), &[4, 9, 1, 8]);
251 let b = a.squeeze_f(-7);
252 assert!(b.is_err());
253 }
254
255 #[test]
256 fn test_flip() {
257 let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
258 println!("{a:?}");
259
260 let b = a.flip(1);
261 println!("{b:?}");
262 assert_eq!(b.shape(), &[2, 3, 4]);
263 let c = a.flip([0, -1]);
264 println!("{c:?}");
265 assert_eq!(c.shape(), &[2, 3, 4]);
266 }
267
268 #[test]
269 fn test_swapaxes() {
270 let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
271 println!("{a:?}");
272
273 let b = a.swapaxes(0, 1);
274 println!("{b:?}");
275 assert_eq!(b.shape(), &[3, 2, 4]);
276 }
277
278 #[test]
279 fn test_to_shape() {
280 let a = linspace((0.0, 15.0, 16));
281 let mut a = a.to_shape([4, 4]);
282 a.layout = Layout::new(vec![2, 2], vec![2, 4], 0).unwrap();
283 println!("{a:?}");
284 let b = a.to_shape([2, 2]);
285 println!("{b:?}");
286
287 let c = a.to_shape([2, -1]);
288 println!("{c:?}");
289 assert_eq!(c.shape(), &[2, 2]);
290
291 let d = a.to_shape_f([3, -1]);
292 assert!(d.is_err());
293 }
294
295 #[test]
296 fn test_broadcast_to() {
297 #[cfg(not(feature = "col_major"))]
298 {
299 let a = linspace((0.0, 15.0, 16));
300 let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
301 let a = a.to_broadcast_f([6, 4, 3, 4]).unwrap();
302 println!("{a:?}");
303 assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([6, 4, 3, 4], [0, 4, 0, 1], 0) });
304 }
305 #[cfg(feature = "col_major")]
306 {
307 let a = linspace((0.0, 15.0, 16));
308 let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
309 let a = a.to_broadcast_f([4, 3, 4, 6]).unwrap();
310 println!("{a:?}");
311 assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([4, 3, 4, 6], [1, 0, 4, 0], 0) });
312 }
313 }
314
315 #[test]
316 fn test_to_layout() {
317 let a = linspace((0.0, 15.0, 16));
318 let a = a.change_shape([4, 4]);
319 let a = a.into_layout(Layout::new([2, 8], [12, 120], 8).unwrap());
320 println!("{a:?}");
321 }
322}