rstsr_core/tensor/
reduction.rs

1use crate::prelude_dev::*;
2
/// Generates the reduction API for one device-side reduction trait
/// (e.g. `OpSumAPI` -> `sum`, `sum_f`, `sum_axes`, `sum_axes_f`,
/// `sum_all`, `sum_all_f`, plus the matching `TensorAny` methods).
///
/// Naming convention: `_f` variants return `Result`, plain variants
/// unwrap; `*_all*` reduces every element to a scalar `B::TOut`;
/// `*_axes*` reduces along selected axes into a dynamic-dim tensor;
/// the bare `$fn`/`$fn_f` names are aliases of the all-element form.
macro_rules! trait_reduction {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Fallible reduction over all elements of the tensor view.
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            // dispatch to the device implementation with raw data + layout
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        /// Fallible reduction along `axes`; the result has dynamic dimension.
        ///
        /// An empty `axes` list means "reduce everything": the result is a
        /// 0-D tensor holding the scalar reduction.
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Result<Tensor<B::TOut, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            let axes = axes.try_into()?;
            let tensor = tensor.view();

            // special case for summing all axes
            if axes.as_ref().is_empty() {
                let sum = tensor.device().$fn_all(tensor.raw(), tensor.layout())?;
                let storage = tensor.device().outof_cpu_vec(vec![sum])?;
                // 0-D layout: empty shape/strides, offset 0
                let layout = Layout::new(vec![], vec![], 0)?;
                return Tensor::new_f(storage, layout);
            }

            let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
            Tensor::new_f(storage, layout)
        }

        /// Panicking variant of the all-element reduction.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        /// Panicking variant of the axes reduction.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Tensor<B::TOut, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        /// Alias of the fallible all-element reduction.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the panicking all-element reduction.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Inherent-method versions delegating to the free functions above.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<B::TOut> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> B::TOut {
                $fn_all(self)
            }

            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error = Error>,
            ) -> Result<Tensor<B::TOut, B, IxD>>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<B::TOut, B, IxD>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<B::TOut> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> B::TOut {
                $fn(self)
            }
        }
    };
}
113
// One `trait_reduction!` invocation per reduction op; the generated free
// functions and `TensorAny` methods are re-exported below.
#[rustfmt::skip]
mod impl_trait_reduction {
    use super::*;
    trait_reduction!(OpSumAPI, sum, sum_f, sum_axes, sum_axes_f, sum_all, sum_all_f);
    trait_reduction!(OpMinAPI, min, min_f, min_axes, min_axes_f, min_all, min_all_f);
    trait_reduction!(OpMaxAPI, max, max_f, max_axes, max_axes_f, max_all, max_all_f);
    trait_reduction!(OpProdAPI, prod, prod_f, prod_axes, prod_axes_f, prod_all, prod_all_f);
    trait_reduction!(OpMeanAPI, mean, mean_f, mean_axes, mean_axes_f, mean_all, mean_all_f);
    trait_reduction!(OpVarAPI, var, var_f, var_axes, var_axes_f, var_all, var_all_f);
    trait_reduction!(OpStdAPI, std, std_f, std_axes, std_axes_f, std_all, std_all_f);
    trait_reduction!(OpL2NormAPI, l2_norm, l2_norm_f, l2_norm_axes, l2_norm_axes_f, l2_norm_all, l2_norm_all_f);
    trait_reduction!(OpArgMinAPI, argmin, argmin_f, argmin_axes, argmin_axes_f, argmin_all, argmin_all_f);
    trait_reduction!(OpArgMaxAPI, argmax, argmax_f, argmax_axes, argmax_axes_f, argmax_all, argmax_all_f);
    trait_reduction!(OpAllAPI, all, all_f, all_axes, all_axes_f, all_all, all_all_f);
    trait_reduction!(OpAnyAPI, any, any_f, any_axes, any_axes_f, any_all, any_all_f);
    trait_reduction!(OpCountNonZeroAPI, count_nonzero, count_nonzero_f, count_nonzero_axes, count_nonzero_axes_f, count_nonzero_all, count_nonzero_all_f);
}
pub use impl_trait_reduction::*;
132
/// Like [`trait_reduction`], but for reductions whose result is an index:
/// the all-element form returns a multi-dimensional index of type `D`, and
/// the axes form returns a tensor of `IxD` indices.
///
/// NOTE(review): unlike `trait_reduction`, the axes variant here has no
/// special case for an empty axes list; presumably the device-side
/// implementation accepts an empty slice — TODO confirm.
macro_rules! trait_reduction_arg {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Fallible index-valued reduction over all elements.
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        /// Fallible index-valued reduction along `axes`.
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Result<Tensor<IxD, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD>,
        {
            let axes = axes.try_into()?;
            let tensor = tensor.view();

            let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
            Tensor::new_f(storage, layout)
        }

        /// Panicking variant of the all-element index reduction.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        /// Panicking variant of the axes index reduction.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Tensor<IxD, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        /// Alias of the fallible all-element index reduction.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the panicking all-element index reduction.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Inherent-method versions delegating to the free functions above.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<D> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> D {
                $fn_all(self)
            }

            pub fn $fn_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<IxD, B, IxD>>
            where
                B: DeviceAPI<IxD>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<IxD, B, IxD>
            where
                B: DeviceAPI<IxD>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<D> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> D {
                $fn(self)
            }
        }
    };
}
232
// Unraveled argmin/argmax: results are multi-dimensional indices rather
// than flattened positions.
trait_reduction_arg!(
    OpUnraveledArgMinAPI,
    unraveled_argmin,
    unraveled_argmin_f,
    unraveled_argmin_axes,
    unraveled_argmin_axes_f,
    unraveled_argmin_all,
    unraveled_argmin_all_f
);
trait_reduction_arg!(
    OpUnraveledArgMaxAPI,
    unraveled_argmax,
    unraveled_argmax_f,
    unraveled_argmax_axes,
    unraveled_argmax_axes_f,
    unraveled_argmax_all,
    unraveled_argmax_all_f
);
251
252/* #region sum (bool) */
253
/// Summation of boolean tensors into a `usize` (presumably the count of
/// `true` elements — semantics are defined by the device's `OpSumBoolAPI`
/// implementation; confirm against the backend).
///
/// `_f` variants return `Result`; plain variants unwrap. `sum`/`sum_f`
/// default to the all-element reduction.
pub trait TensorSumBoolAPI<B, D>
where
    D: DimAPI,
    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D>,
{
    /// Fallible reduction over all elements; required method.
    fn sum_all_f(&self) -> Result<usize>;
    /// Panicking variant of `sum_all_f`.
    fn sum_all(&self) -> usize {
        self.sum_all_f().rstsr_unwrap()
    }
    /// Alias of `sum_all_f`.
    fn sum_f(&self) -> Result<usize> {
        self.sum_all_f()
    }
    /// Panicking variant of `sum_f`.
    fn sum(&self) -> usize {
        self.sum_f().rstsr_unwrap()
    }
    /// Fallible reduction along the given axes; required method.
    /// An empty axes list reduces everything into a 0-D tensor.
    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<usize, B, IxD>>;
    /// Panicking variant of `sum_axes_f`.
    fn sum_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<usize, B, IxD> {
        self.sum_axes_f(axes).rstsr_unwrap()
    }
}
274
275impl<R, B, D> TensorSumBoolAPI<B, D> for TensorAny<R, bool, B, D>
276where
277    R: DataAPI<Data = <B as DeviceRawAPI<bool>>::Raw>,
278    D: DimAPI,
279    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D> + DeviceCreationAnyAPI<usize>,
280{
281    fn sum_all_f(&self) -> Result<usize> {
282        self.device().sum_all(self.raw(), self.layout())
283    }
284
285    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<usize, B, IxD>> {
286        let axes = axes.try_into()?;
287
288        // special case for summing all axes
289        if axes.as_ref().is_empty() {
290            let sum = self.device().sum_all(self.raw(), self.layout())?;
291            let storage = self.device().outof_cpu_vec(vec![sum])?;
292            let layout = Layout::new(vec![], vec![], 0)?;
293            return Tensor::new_f(storage, layout);
294        }
295
296        let (storage, layout) = self.device().sum_axes(self.raw(), self.layout(), axes.as_ref())?;
297        Tensor::new_f(storage, layout)
298    }
299}
300
301/* #endregion */
302
303/* #region allclose */
304
305pub fn allclose_all_f<TA, TB, TE, B, DA, DB>(
306    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
307    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
308    isclose_args: impl Into<IsCloseArgs<TE>>,
309) -> Result<bool>
310where
311    DA: DimAPI,
312    DB: DimAPI,
313    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
314    TE: 'static,
315{
316    let tensor_a = tensor_a.view();
317    let tensor_b = tensor_b.view();
318    let isclose_args = isclose_args.into();
319    let device = tensor_a.device();
320
321    // check device
322    rstsr_assert!(tensor_a.device().same_device(tensor_b.device()), DeviceMismatch)?;
323
324    // check and broadcast layout
325
326    let la = tensor_a.layout().to_dim::<IxD>()?;
327    let lb = tensor_b.layout().to_dim::<IxD>()?;
328    let default_order = device.default_order();
329    let (la_b, lb_b) = broadcast_layout(&la, &lb, default_order)?;
330
331    device.allclose_all(tensor_a.raw(), &la_b, tensor_b.raw(), &lb_b, &isclose_args)
332}
333
334pub fn allclose_all<TA, TB, TE, B, DA, DB>(
335    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
336    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
337    isclose_args: impl Into<IsCloseArgs<TE>>,
338) -> bool
339where
340    DA: DimAPI,
341    DB: DimAPI,
342    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
343    TE: 'static,
344{
345    allclose_all_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
346}
347
348pub fn allclose_f<TA, TB, TE, B, DA, DB>(
349    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
350    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
351    isclose_args: impl Into<IsCloseArgs<TE>>,
352) -> Result<bool>
353where
354    DA: DimAPI,
355    DB: DimAPI,
356    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
357    TE: 'static,
358{
359    allclose_all_f(tensor_a, tensor_b, isclose_args)
360}
361
362pub fn allclose<TA, TB, TE, B, DA, DB>(
363    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
364    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
365    isclose_args: impl Into<IsCloseArgs<TE>>,
366) -> bool
367where
368    DA: DimAPI,
369    DB: DimAPI,
370    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
371    TE: 'static,
372{
373    allclose_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
374}
375
/// Convenience macro: `allclose!(a, b)` or `allclose!(a, b, args)`.
///
/// Expands to a call of `rstsr::prelude::rstsr_funcs::allclose`; the
/// two-argument form passes `None` for the closeness arguments (defaults).
#[macro_export]
macro_rules! allclose {
    ($tensor_a:expr, $tensor_b:expr, $isclose_args:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, $isclose_args)
    }};
    ($tensor_a:expr, $tensor_b:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, None)
    }};
}
387
388/* #endregion */
389
390#[cfg(test)]
391mod test {
392    use num::ToPrimitive;
393
394    use super::*;
395
    /// `sum_all` / `sum_axes(())` on contiguous and strided views,
    /// for `DeviceCpuSerial` and the default device, in row-major order.
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_sum_all_row_major() {
        // DeviceCpuSerial
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // np.arange(3240).reshape(12, 15, 18)
        //   .swapaxes(-1, -2)[2:-3, 1:-4:2, -1:3:-2].sum()
        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        // empty axes tuple: reduce everything into a 0-D tensor
        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);

        // DeviceFaer
        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // np.arange(3240).reshape(12, 15, 18)
        //   .swapaxes(-1, -2)[2:-3, 1:-4:2, -1:3:-2].sum()
        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);
    }
431
    /// Column-major counterpart of `test_sum_all_row_major`; reference
    /// values computed with Julia (see comments).
    #[test]
    #[cfg(feature = "col_major")]
    fn test_sum_all_col_major() {
        // DeviceCpuSerial
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // a = reshape(range(0, 3239), (12, 15, 18));
        // a = permutedims(a, (1, 3, 2));
        // a = a[3:9, 2:2:15, 15:-2:4];
        // sum(a)
        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        // empty axes tuple: reduce everything into a 0-D tensor
        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);

        // DeviceFaer
        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);
    }
467
    /// `sum_axes` over two axes of a transposed 4-D tensor; reference
    /// values from numpy (row-major) and Julia (col-major).
    #[test]
    fn test_sum_axes() {
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.arange(3240).reshape(4, 6, 15, 9).transpose(2, 0, 3, 1)
            // a.sum(axis=(0, -2))
            // DeviceCpuSerial
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);

            // DeviceFaer
            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);
        }
        #[cfg(feature = "col_major")]
        {
            // a = reshape(range(0, 3239), (4, 6, 15, 9));
            // a = permutedims(a, (3, 1, 4, 2));
            // sum(a, dims=(1, 3))
            // DeviceCpuSerial
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);

            // DeviceFaer
            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);
        }
    }
512
    /// `min_axes` along each axis and `min_all` on a 4x3 matrix,
    /// for both devices.
    #[test]
    fn test_min() {
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);
    }
537
    /// `mean_all` and `mean_axes` (including on reversed/strided views),
    /// in both memory orders and on both devices.
    #[test]
    fn test_mean() {
        #[cfg(not(feature = "col_major"))]
        {
            // DeviceCpuSerial
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            // mean over a negatively-strided view
            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);

            // DeviceFaer
            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);
        }
        #[cfg(feature = "col_major")]
        {
            // DeviceCpuSerial
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);

            // DeviceFaer
            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);
        }
    }
597
    /// Population variance (`var_all`, `var_axes`) against numpy
    /// reference values, on both devices.
    #[test]
    fn test_var() {
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));
    }
632
    /// Standard deviation (`std_all`, `std_axes`) for real and complex
    /// inputs, on both devices; reference values from numpy.
    #[test]
    fn test_std() {
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        // complex input: std of Complex<f64> data
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);
    }
693
    /// `l2_norm_all` on complex input, on both devices.
    #[test]
    fn test_l2_norm() {
        // DeviceCpuSerial
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);

        // DeviceFaer
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);
    }
724
725    #[test]
726    #[cfg(feature = "rayon")]
727    fn test_large_std() {
728        #[cfg(not(feature = "col_major"))]
729        {
730            // a = np.linspace(0, 1, 1048576).reshape(16, 256, 256)
731            // b = np.linspace(1, 2, 1048576).reshape(16, 256, 256)
732            // c = a @ b
733            // print(c.mean(), c.std())
734            // print(c.std(axis=(0, 1))[[0, -1]])
735            // print(c.std(axis=(1, 2))[[0, -1]])
736            let a = linspace((0.0, 1.0, 1048576)).into_shape([16, 256, 256]);
737            let b = linspace((1.0, 2.0, 1048576)).into_shape([16, 256, 256]);
738            let c: Tensor<f64> = &a % &b;
739
740            let c_mean = c.mean_all();
741            println!("{c_mean:?}");
742            assert!((c_mean - 213.2503660477036) < 1e-6);
743
744            let c_std = c.std_all();
745            println!("{c_std:?}");
746            assert!((c_std - 148.88523481701804) < 1e-6);
747
748            let c_std_1 = c.std_axes((0, 1));
749            println!("{c_std_1}");
750            assert!(c_std_1[[0]] - 148.8763226818815 < 1e-6);
751            assert!(c_std_1[[255]] - 148.8941462322758 < 1e-6);
752
753            let c_std_2 = c.std_axes((1, 2));
754            println!("{c_std_2}");
755            assert!(c_std_2[[0]] - 4.763105902995575 < 1e-6);
756            assert!(c_std_2[[15]] - 9.093224903569157 < 1e-6);
757        }
758        #[cfg(feature = "col_major")]
759        {
760            // a = reshape(LinRange(0, 1, 1048576), (256, 256, 16));
761            // b = reshape(LinRange(1, 2, 1048576), (256, 256, 16));
762            // c = Array{Float64}(undef, 256, 256, 16);
763            // for i in 1:16
764            //     c[:, :, i] = a[:, :, i] * b[:, :, i]
765            // end
766            // mean(c), std(c)
767            // std(c, dims=(2, 3))
768            // std(c, dims=(1, 2))
769            let a = linspace((0.0, 1.0, 1048576)).into_shape([256, 256, 16]);
770            let b = linspace((1.0, 2.0, 1048576)).into_shape([256, 256, 16]);
771            let mut c: Tensor<f64> = zeros([256, 256, 16]);
772            for i in 0..16 {
773                c.i_mut((.., .., i)).assign(&a.i((.., .., i)) % &b.i((.., .., i)));
774            }
775
776            let c_mean = c.mean_all();
777            println!("{c_mean:?}");
778            assert!((c_mean - 213.25036604770355) < 1e-6);
779
780            let c_std = c.std_all();
781            println!("{c_std:?}");
782            assert!((c_std - 148.7419537312827) < 1e-6);
783
784            let c_std_1 = c.std_axes((1, 2));
785            println!("{c_std_1}");
786            assert!(c_std_1[[0]] - 148.75113653867191 < 1e-6);
787            assert!(c_std_1[[255]] - 148.7689445622776 < 1e-6);
788
789            let c_std_2 = c.std_axes((0, 1));
790            println!("{c_std_2}");
791            assert!(c_std_2[[0]] - 0.145530296246335 < 1e-6);
792            assert!(c_std_2[[15]] - 4.474611918106057 < 1e-6);
793        }
794    }
795
    /// `unraveled_argmin_all`/`_axes`: results are multi-dimensional
    /// indices (vectors), not flattened positions.
    #[test]
    fn test_unraveled_argmin() {
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        // [[ 8 4 2]
        //  [ 9 7 1]
        //  [ 2 1 8]
        //  [ 6 10 5]]

        // global minimum 1 first occurs at row 1, col 2
        let m = a.unraveled_argmin_all();
        println!("{m:?}");
        assert_eq!(m, vec![1, 2]);

        let m = a.unraveled_argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1], vec![2]]);

        let m = a.unraveled_argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1]]);
    }
821
    /// `argmin_all`/`argmin_axes`: flattened-position counterpart of
    /// `test_unraveled_argmin` on the same matrix.
    #[test]
    fn test_argmin() {
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        // [[ 8 4 2]
        //  [ 9 7 1]
        //  [ 2 1 8]
        //  [ 6 10 5]]

        // global minimum 1 first occurs at flat position 5 (row 1, col 2)
        let m = a.argmin_all();
        println!("{m:?}");
        assert_eq!(m, 5);

        let m = a.argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1, 2]);

        let m = a.argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1]);
    }
847
    /// `all_axes` with a negative axis on a 2x3 boolean tensor.
    #[test]
    fn test_all() {
        let a = asarray((vec![true, true, false, true, true, true], [2, 3].c()));
        // axis -2 == axis 0: reduce down the columns
        let a_all = a.all_axes(-2);
        println!("{:?}", a_all);
        assert_eq!(a_all.raw(), &[true, true, false]);
    }
855
    /// `allclose` across mixed dtypes (int vs f32) and mixed layouts
    /// (c vs f order) on `DeviceCpuSerial`, with default and tight
    /// tolerance arguments.
    #[test]
    fn test_allclose_cpu_serial() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        // default tolerances: close despite the 1e-5 perturbation
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        // tight tolerances: the perturbation is now detected
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }
872
    /// Same as `test_allclose_cpu_serial`, but on `DeviceFaer`.
    #[test]
    #[cfg(feature = "faer")]
    fn test_allclose_faer() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

        let mut device = DeviceFaer::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        // default tolerances: close despite the 1e-5 perturbation
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        // tight tolerances: the perturbation is now detected
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }
890}