rstsr_core/tensor/manuplication/
mod.rs

1//! This module handles tensor data manipulation.
2
3pub mod broadcast;
4pub mod expand_dims;
5pub mod flip;
6pub mod into_dim;
7pub mod reshape;
8pub mod reshape_assume_contig;
9pub mod squeeze;
10pub mod to_contig;
11pub mod to_layout;
12pub mod transpose;
13
14pub mod exports {
15    use super::*;
16
17    pub use broadcast::*;
18    pub use expand_dims::*;
19    pub use flip::*;
20    pub use into_dim::*;
21    pub use reshape::*;
22    pub use reshape_assume_contig::*;
23    pub use squeeze::*;
24    pub use to_contig::*;
25    pub use to_layout::*;
26    pub use transpose::*;
27}
28
/// Tests for `layout_reshapeable`: whether a layout can be reshaped without copying,
/// and which strides/offset the reshaped layout gets.
///
/// Each test has a row-major branch and, where present, a column-major mirror selected by
/// the `col_major` feature.
#[cfg(test)]
mod test_reshape {
    use crate::prelude_dev::*;

    /// Exploratory test: prints reshape results and whether the reshaped tensor shares its
    /// data pointer with the source. Apart from the `unwrap`s it asserts nothing.
    #[test]
    fn test_playground() {
        #[cfg(not(feature = "col_major"))]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let a2 = a1.to_shape([2, 3, 4]);
            let default_order = a1.device().default_order();
            println!("{a2:?}");
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));

            let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
            println!("{v:?}");

            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
            println!("{v:?}");
        }
        #[cfg(feature = "col_major")]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let a2 = a1.to_shape([2, 3, 4]);
            let default_order = a1.device().default_order();
            println!("{a2:?}");
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));
            println!("a2[:, :, 0] =\n{:}", a2.i((.., .., 0)));
            println!("a2[:, :, 1] =\n{:}", a2.i((.., .., 1)));
            println!("a2[:, :, 2] =\n{:}", a2.i((.., .., 2)));
            println!("a2[:, :, 3] =\n{:}", a2.i((.., .., 3)));

            let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
            println!("{v:?}");

            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
            println!("{v:?}");
        }
    }

    /// A fully contiguous layout (in the default order) can be reshaped to any compatible
    /// shape, and the result is again contiguous in that order.
    #[test]
    fn test_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            let layout_in = vec![2, 3, 4].c();
            let default_order = RowMajor;
            let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![2, 3, 4].c());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![3, 2, 4].c());

            let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].c());
        }
        #[cfg(feature = "col_major")]
        {
            let layout_in = vec![2, 3, 4].f();
            let default_order = ColMajor;
            let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![2, 3, 4].f());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![3, 2, 4].f());

            let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].f());
        }
    }

    /// Strided (sliced) layouts: reshaping succeeds only if the target shape does not need to
    /// merge axes whose strides are not compatible.
    #[test]
    fn test_partial_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            let default_order = RowMajor;

            // a = np.zeros((12, 15, 18)); a[3:, :, ::3]
            // Every axis is uniformly spaced with an element stride of 3
            // (3 * 6 == 18, 18 * 15 == 270), so the view can be regrouped freely.
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 3], 810).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![15, 9, 2, 3]);
            assert_eq!(layout_out.stride(), &vec![162, 18, 9, 3]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![10, 27, 3]);
            assert_eq!(layout_out.stride(), &vec![243, 9, 3]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            // Insert some new length-1 axes. Strides follow the C-contiguous pattern, although
            // a zero stride would also be valid for the length-1 axes.
            let layout_out = layout_reshapeable(&layout_in, &vec![1, 10, 1, 27, 3], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![1, 10, 1, 27, 3]);
            assert_eq!(layout_out.stride(), &vec![2430, 243, 243, 9, 3]);

            // a = np.zeros((12, 15, 18)); a[3:, :, 3:15:2]
            // The last two axes cannot be merged (2 * 6 != 18).
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 2], 813).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![15, 9, 2, 3]);
            assert_eq!(layout_out.stride(), &vec![162, 18, 6, 2]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            // The axis of size 27 would have to span the non-mergeable boundary between the
            // last two input axes.
            let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert!(layout_out.is_none());
        }
        #[cfg(feature = "col_major")]
        {
            let default_order = ColMajor;

            // a = np.zeros((18, 15, 12), order="F"); a[::3, :, 3:]
            // Column-major mirror of the row-major case: every axis is uniformly spaced with an
            // element stride of 3 (3 * 6 == 18, 18 * 15 == 270).
            let layout_in = Layout::new(vec![6, 15, 9], vec![3, 18, 270], 810).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![3, 2, 9, 15]);
            assert_eq!(layout_out.stride(), &vec![3, 9, 18, 162]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 10], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![3, 27, 10]);
            assert_eq!(layout_out.stride(), &vec![3, 9, 243]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            // Insert some new length-1 axes. Strides follow the F-contiguous pattern, although
            // a zero stride would also be valid for the length-1 axes.
            let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 1, 10, 1], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![3, 27, 1, 10, 1]);
            assert_eq!(layout_out.stride(), &vec![3, 9, 243, 243, 2430]);

            // a = np.zeros((18, 15, 12), order="F"); a[3:15:2, :, 3:]
            // The first two axes cannot be merged (2 * 6 != 18).
            let layout_in = Layout::new(vec![6, 15, 9], vec![2, 18, 270], 813).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![3, 2, 9, 15]);
            assert_eq!(layout_out.stride(), &vec![2, 6, 18, 162]);
            assert_eq!(layout_out.offset(), layout_in.offset());

            // Mirror of the row-major `[10, 27, 3]` case: the axis of size 27 would have to span
            // the non-mergeable boundary between the first two input axes.
            let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 10], default_order).unwrap();
            assert!(layout_out.is_none());
        }
    }

    /// Negative strides: axes that are contiguous with each other in reverse can still be
    /// merged and re-split. NOTE(review): no column-major counterpart yet.
    #[test]
    fn test_minus_stride() {
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.zeros((12, 15, 18)); a[3:, ::-1, ::-3]
            // The last two axes are mutually contiguous (-3 * 6 == -18) and can be merged;
            // the first axis cannot be merged with them (-18 * 15 != 270).
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, -18, -3], 1079).unwrap();
            let default_order = RowMajor;

            // The target shape has no split matching the boundary between the first axis and
            // the merged trailing axes.
            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
            assert!(layout_out.is_none());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 3, 10, 9], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![3, 3, 10, 9]);
            assert_eq!(layout_out.stride(), &vec![810, 270, -27, -3]);
        }
    }

    /// Broadcast (zero-stride) axes: they may be split among themselves, but must not be mixed
    /// with non-broadcast axes. NOTE(review): no column-major counterpart yet.
    #[test]
    fn test_broadcast_reshape() {
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.zeros((12, 15, 18))
            // b = np.broadcast_to(a[:, None], (12, 16, 15, 18))
            // SAFETY: this zero-stride layout is only inspected by `layout_reshapeable`;
            // it is never used to access memory.
            let layout_in = unsafe { Layout::new_unchecked(vec![12, 16, 15, 18], vec![270, 0, 18, 1], 0) };
            let default_order = RowMajor;

            let layout_out = layout_reshapeable(&layout_in, &vec![4, 3, 4, 4, 9, 1, 30], default_order).unwrap().unwrap();
            assert_eq!(layout_out.shape(), &vec![4, 3, 4, 4, 9, 1, 30]);
            assert_eq!(layout_out.stride(), &vec![810, 270, 0, 0, 30, 30, 1]);

            // Swapping the sizes of the first two axes would mix the broadcast axis with a
            // real one.
            let layout_out = layout_reshapeable(&layout_in, &vec![16, 12, 15, 18], default_order).unwrap();
            assert!(layout_out.is_none());
        }
    }
}
217
/// Tests of the tensor-level manipulation API (reshape, expand/squeeze, flip, swapaxes,
/// broadcast, layout conversion).
#[cfg(test)]
mod tests {
    use crate::prelude_dev::*;

    #[test]
    fn test_to_shape_assume_contig() {
        let a = linspace((2.5, 3.2, 16));
        let b = a.to_shape_assume_contig_f([4, 4]).unwrap();
        println!("{b:.3?}");
        assert_eq!(b.shape(), &[4, 4]);
    }

    #[test]
    fn test_expand_dims() {
        let a: Tensor<f64, _> = zeros([4, 9, 8]);
        // Axes index positions in the output tensor; negative axes count from its end.
        let b = a.expand_dims(2);
        assert_eq!(b.shape(), &[4, 9, 1, 8]);
        let b = a.expand_dims([1, 3]);
        assert_eq!(b.shape(), &[4, 1, 9, 1, 8]);
        let b = a.expand_dims([1, -1]);
        assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
        let b = a.expand_dims([-1, -4, 1, 0]);
        assert_eq!(b.shape(), &[1, 1, 4, 1, 9, 8, 1]);
    }

    #[test]
    fn test_squeeze() {
        let a: Tensor<f64, _> = zeros([4, 1, 9, 1, 8, 1]);
        let b = a.squeeze(3);
        assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
        let b = a.squeeze([1, 3]);
        assert_eq!(b.shape(), &[4, 9, 8, 1]);
        let b = a.squeeze([1, -1]);
        assert_eq!(b.shape(), &[4, 9, 1, 8]);
        // Axis -7 is out of range for this 6-D tensor, so the fallible variant errors.
        let b = a.squeeze_f(-7);
        assert!(b.is_err());
    }

    #[test]
    fn test_flip() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        let b = a.flip(1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[2, 3, 4]);
        let c = a.flip([0, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_swapaxes() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        let b = a.swapaxes(0, 1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[3, 2, 4]);
    }

    #[test]
    fn test_to_shape() {
        let a = linspace((0.0, 15.0, 16));
        let mut a = a.to_shape([4, 4]);
        // Replace the layout with a non-contiguous 2x2 view (strides [2, 4]) so that the
        // reshapes below operate on strided data.
        a.layout = Layout::new(vec![2, 2], vec![2, 4], 0).unwrap();
        println!("{a:?}");
        let b = a.to_shape([2, 2]);
        println!("{b:?}");
        assert_eq!(b.shape(), &[2, 2]);

        // `-1` infers the remaining axis length.
        let c = a.to_shape([2, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 2]);

        // 4 elements cannot be split into 3 rows.
        let d = a.to_shape_f([3, -1]);
        assert!(d.is_err());
    }

    #[test]
    fn test_broadcast_to() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 15.0, 16));
            let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = a.to_broadcast_f([6, 4, 3, 4]).unwrap();
            println!("{a:?}");
            // SAFETY: the expected layout is only compared for equality, never used to access memory.
            assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([6, 4, 3, 4], [0, 4, 0, 1], 0) });
        }
        #[cfg(feature = "col_major")]
        {
            let a = linspace((0.0, 15.0, 16));
            let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = a.to_broadcast_f([4, 3, 4, 6]).unwrap();
            println!("{a:?}");
            // SAFETY: the expected layout is only compared for equality, never used to access memory.
            assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([4, 3, 4, 6], [1, 0, 4, 0], 0) });
        }
    }

    #[test]
    fn test_to_layout() {
        let a = linspace((0.0, 15.0, 16));
        let a = a.change_shape([4, 4]);
        let a = a.into_layout(Layout::new([2, 8], [12, 120], 8).unwrap());
        println!("{a:?}");
        assert_eq!(a.shape(), &[2, 8]);
    }
}