// rstsr_core/tensor/reduction.rs

1use crate::prelude_dev::*;
2
/// Generates the reduction API (free functions + inherent methods) for
/// scalar-output reduction traits such as `OpSumAPI` or `OpMaxAPI`.
///
/// For a reduction named `$fn`, this emits:
/// - `$fn_all_f` / `$fn_all`: reduce over all axes to a scalar `B::TOut`
///   (fallible / panicking variants);
/// - `$fn_axes_f` / `$fn_axes`: reduce over user-selected axes, returning a
///   dynamic-dimension tensor;
/// - `$fn_f` / `$fn`: aliases of the reduce-all variants;
/// - the same six entry points as inherent methods on `TensorAny`.
macro_rules! trait_reduction {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Reduce over all axes; returns the scalar result (fallible).
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            // The backend device performs the actual reduction.
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        /// Reduce over the axes in `axes`; returns a dynamic-dim tensor (fallible).
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Result<Tensor<B::TOut, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            let axes = axes.try_into().map_err(Into::into)?;
            let tensor = tensor.view();

            match axes {
                AxesIndex::None => {
                    // Special case: reduce all axes, then wrap the scalar into
                    // a 0-dimensional tensor on the same device.
                    let sum = tensor.device().$fn_all(tensor.raw(), tensor.layout())?;
                    let storage = tensor.device().outof_cpu_vec(vec![sum])?;
                    let layout = Layout::new(vec![], vec![], 0)?;
                    Tensor::new_f(storage, layout)
                },
                _ => {
                    let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
                    Tensor::new_f(storage, layout)
                },
            }
        }

        /// Reduce over all axes; panics on error.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        /// Reduce over the axes in `axes`; panics on error.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Tensor<B::TOut, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        /// Alias of the reduce-all (fallible) variant.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the reduce-all (panicking) variant.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Inherent-method counterparts, delegating to the free functions above.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<B::TOut> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> B::TOut {
                $fn_all(self)
            }

            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
            ) -> Result<Tensor<B::TOut, B, IxD>>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<B::TOut, B, IxD>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<B::TOut> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> B::TOut {
                $fn(self)
            }
        }
    };
}
115
// Instantiate the reduction API once per reduction operation. The inner
// module exists so a single `#[rustfmt::skip]` covers all invocations; its
// contents are re-exported below.
#[rustfmt::skip]
mod impl_trait_reduction {
    use super::*;
    trait_reduction!(OpSumAPI, sum, sum_f, sum_axes, sum_axes_f, sum_all, sum_all_f);
    trait_reduction!(OpMinAPI, min, min_f, min_axes, min_axes_f, min_all, min_all_f);
    trait_reduction!(OpMaxAPI, max, max_f, max_axes, max_axes_f, max_all, max_all_f);
    trait_reduction!(OpProdAPI, prod, prod_f, prod_axes, prod_axes_f, prod_all, prod_all_f);
    trait_reduction!(OpMeanAPI, mean, mean_f, mean_axes, mean_axes_f, mean_all, mean_all_f);
    trait_reduction!(OpVarAPI, var, var_f, var_axes, var_axes_f, var_all, var_all_f);
    trait_reduction!(OpStdAPI, std, std_f, std_axes, std_axes_f, std_all, std_all_f);
    trait_reduction!(OpL2NormAPI, l2_norm, l2_norm_f, l2_norm_axes, l2_norm_axes_f, l2_norm_all, l2_norm_all_f);
    trait_reduction!(OpArgMinAPI, argmin, argmin_f, argmin_axes, argmin_axes_f, argmin_all, argmin_all_f);
    trait_reduction!(OpArgMaxAPI, argmax, argmax_f, argmax_axes, argmax_axes_f, argmax_all, argmax_all_f);
    trait_reduction!(OpAllAPI, all, all_f, all_axes, all_axes_f, all_all, all_all_f);
    trait_reduction!(OpAnyAPI, any, any_f, any_axes, any_axes_f, any_all, any_all_f);
    trait_reduction!(OpCountNonZeroAPI, count_nonzero, count_nonzero_f, count_nonzero_axes, count_nonzero_axes_f, count_nonzero_all, count_nonzero_all_f);
}
pub use impl_trait_reduction::*;
134
/// Generates the reduction API for reductions whose result is an *index*
/// (unraveled argmin/argmax) rather than a value of the element type.
///
/// Mirrors `trait_reduction`, except that the reduce-all result is a
/// multi-dimensional index of type `D`, and per-axis reduction yields a
/// tensor whose elements are `IxD` indices.
macro_rules! trait_reduction_arg {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Reduce over all axes; returns the unraveled index (fallible).
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        /// Reduce over the axes in `axes`; returns a tensor of `IxD` indices (fallible).
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Result<Tensor<IxD, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
        {
            let axes = axes.try_into().map_err(Into::into)?;
            let tensor = tensor.view();

            match axes {
                AxesIndex::None => {
                    // special case for reducing all axes
                    // (the scalar index `D` is converted into `IxD` and wrapped
                    // in a 0-dimensional tensor)
                    let arg = tensor.device().$fn_all(tensor.raw(), tensor.layout())?;
                    let storage = tensor.device().outof_cpu_vec(vec![arg.into()])?;
                    let layout = Layout::new(vec![], vec![], 0)?;
                    Tensor::new_f(storage, layout)
                },
                _ => {
                    let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
                    Tensor::new_f(storage, layout)
                },
            }
        }

        /// Reduce over all axes; panics on error.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        /// Reduce over the axes in `axes`; panics on error.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Tensor<IxD, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        /// Alias of the reduce-all (fallible) variant.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the reduce-all (panicking) variant.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Inherent-method counterparts, delegating to the free functions above.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<D> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> D {
                $fn_all(self)
            }

            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
            ) -> Result<Tensor<IxD, B, IxD>>
            where
                B: DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<IxD, B, IxD>
            where
                B: DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<D> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> D {
                $fn(self)
            }
        }
    };
}
248
// Unraveled argmin/argmax: reductions whose result is a multi-dimensional
// index (`D` for reduce-all; `IxD` elements for per-axis reduction).
trait_reduction_arg!(
    OpUnraveledArgMinAPI,
    unraveled_argmin,
    unraveled_argmin_f,
    unraveled_argmin_axes,
    unraveled_argmin_axes_f,
    unraveled_argmin_all,
    unraveled_argmin_all_f
);
trait_reduction_arg!(
    OpUnraveledArgMaxAPI,
    unraveled_argmax,
    unraveled_argmax_f,
    unraveled_argmax_axes,
    unraveled_argmax_axes_f,
    unraveled_argmax_all,
    unraveled_argmax_all_f
);
267
268/* #region sum (bool) */
269
/// Summation for boolean tensors, producing a `usize` result (presumably the
/// count of `true` elements — semantics come from the backend's
/// `OpSumBoolAPI`). Kept separate from the generic reduction API because the
/// output type (`usize`) differs from the element type (`bool`).
///
/// NOTE: the default methods delegate `sum` -> `sum_f` -> `sum_all_f`; the
/// chain is observable to implementors that override intermediate methods.
pub trait TensorSumBoolAPI<B, D>
where
    D: DimAPI,
    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D>,
{
    /// Sum over all axes (fallible). Required method.
    fn sum_all_f(&self) -> Result<usize>;
    /// Sum over all axes; panics on error.
    fn sum_all(&self) -> usize {
        self.sum_all_f().rstsr_unwrap()
    }
    /// Alias of the reduce-all (fallible) variant.
    fn sum_f(&self) -> Result<usize> {
        self.sum_all_f()
    }
    /// Alias of the reduce-all variant; panics on error.
    fn sum(&self) -> usize {
        self.sum_f().rstsr_unwrap()
    }
    /// Sum along the given axes, yielding a dynamic-dim `usize` tensor (fallible). Required method.
    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Result<Tensor<usize, B, IxD>>;
    /// Sum along the given axes; panics on error.
    fn sum_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<usize, B, IxD> {
        self.sum_axes_f(axes).rstsr_unwrap()
    }
}
290
291impl<R, B, D> TensorSumBoolAPI<B, D> for TensorAny<R, bool, B, D>
292where
293    R: DataAPI<Data = <B as DeviceRawAPI<bool>>::Raw>,
294    D: DimAPI,
295    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D> + DeviceCreationAnyAPI<usize>,
296{
297    fn sum_all_f(&self) -> Result<usize> {
298        self.device().sum_all(self.raw(), self.layout())
299    }
300
301    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Result<Tensor<usize, B, IxD>> {
302        let axes = axes.try_into().map_err(Into::into)?;
303
304        match axes {
305            AxesIndex::None => {
306                // special case for summing all axes
307                let sum = self.device().sum_all(self.raw(), self.layout())?;
308                let storage = self.device().outof_cpu_vec(vec![sum])?;
309                let layout = Layout::new(vec![], vec![], 0)?;
310                Tensor::new_f(storage, layout)
311            },
312            _ => {
313                let (storage, layout) = self.device().sum_axes(self.raw(), self.layout(), axes.as_ref())?;
314                Tensor::new_f(storage, layout)
315            },
316        }
317    }
318}
319
320/* #endregion */
321
322/* #region allclose */
323
324pub fn allclose_all_f<TA, TB, TE, B, DA, DB>(
325    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
326    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
327    isclose_args: impl Into<IsCloseArgs<TE>>,
328) -> Result<bool>
329where
330    DA: DimAPI,
331    DB: DimAPI,
332    B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
333    TE: 'static,
334{
335    let tensor_a = tensor_a.view();
336    let tensor_b = tensor_b.view();
337    let isclose_args = isclose_args.into();
338    let device = tensor_a.device();
339
340    // check device
341    rstsr_assert!(tensor_a.device().same_device(tensor_b.device()), DeviceMismatch)?;
342
343    // check and broadcast layout
344
345    let la = tensor_a.layout().to_dim::<IxD>()?;
346    let lb = tensor_b.layout().to_dim::<IxD>()?;
347    let default_order = device.default_order();
348    let (la_b, lb_b) = broadcast_layout(&la, &lb, default_order)?;
349
350    device.allclose_all(tensor_a.raw(), &la_b, tensor_b.raw(), &lb_b, &isclose_args)
351}
352
353pub fn allclose_all<TA, TB, TE, B, DA, DB>(
354    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
355    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
356    isclose_args: impl Into<IsCloseArgs<TE>>,
357) -> bool
358where
359    DA: DimAPI,
360    DB: DimAPI,
361    B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
362    TE: 'static,
363{
364    allclose_all_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
365}
366
367pub fn allclose_f<TA, TB, TE, B, DA, DB>(
368    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
369    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
370    isclose_args: impl Into<IsCloseArgs<TE>>,
371) -> Result<bool>
372where
373    DA: DimAPI,
374    DB: DimAPI,
375    B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
376    TE: 'static,
377{
378    allclose_all_f(tensor_a, tensor_b, isclose_args)
379}
380
381pub fn allclose<TA, TB, TE, B, DA, DB>(
382    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
383    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
384    isclose_args: impl Into<IsCloseArgs<TE>>,
385) -> bool
386where
387    DA: DimAPI,
388    DB: DimAPI,
389    B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
390    TE: 'static,
391{
392    allclose_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
393}
394
/// Convenience macro around the `allclose` function: `allclose!(a, b)` uses
/// default closeness arguments (`None`); `allclose!(a, b, args)` forwards
/// `args`. Resolves the function through `rstsr::prelude::rstsr_funcs`, so
/// the macro is usable from downstream crates via `#[macro_export]`.
#[macro_export]
macro_rules! allclose {
    ($tensor_a:expr, $tensor_b:expr, $isclose_args:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, $isclose_args)
    }};
    ($tensor_a:expr, $tensor_b:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, None)
    }};
}
406
407/* #endregion */
408
409#[cfg(test)]
410mod test {
411    use num::ToPrimitive;
412
413    use super::*;
414
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_sum_all_row_major() {
        // Row-major sums over contiguous and strided/sliced views; reference
        // values computed with NumPy (see inline scripts).
        // DeviceCpuSerial
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // np.arange(3240).reshape(12, 15, 18)
        //   .swapaxes(-1, -2)[2:-3, 1:-4:2, -1:3:-2].sum()
        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        // `sum_axes(None)` also reduces all axes, but returns a 0-dim tensor.
        let s = a.sum_axes(None);
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);

        // DeviceFaer
        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // np.arange(3240).reshape(12, 15, 18)
        //   .swapaxes(-1, -2)[2:-3, 1:-4:2, -1:3:-2].sum()
        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        let s = a.sum_axes(None);
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);
    }
450
    #[test]
    #[cfg(feature = "col_major")]
    fn test_sum_all_col_major() {
        // Column-major variant of the test above; reference values computed
        // with Julia (see inline scripts).
        // DeviceCpuSerial
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        // a = reshape(range(0, 3239), (12, 15, 18));
        // a = permutedims(a, (1, 3, 2));
        // a = a[3:9, 2:2:15, 15:-2:4];
        // sum(a)
        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        let s = a.sum_axes(None);
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);

        // DeviceFaer
        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        let s = a.sum_axes(None);
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);
    }
486
    #[test]
    fn test_sum_axes() {
        // Axis-subset sum over a permuted 4-D tensor; spot-checks individual
        // output elements against NumPy / Julia references.
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.arange(3240).reshape(4, 6, 15, 9).transpose(2, 0, 3, 1)
            // a.sum(axis=(0, -2))
            // DeviceCpuSerial
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);

            // DeviceFaer
            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);
        }
        #[cfg(feature = "col_major")]
        {
            // a = reshape(range(0, 3239), (4, 6, 15, 9));
            // a = permutedims(a, (3, 1, 4, 2));
            // sum(a, dims=(1, 3))
            // DeviceCpuSerial
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);

            // DeviceFaer
            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);
        }
    }
531
    #[test]
    fn test_min() {
        // Minimum over columns (axis 0), rows (axis 1), and all elements of a
        // 4x3 matrix, on both backends.
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);
    }
556
    #[test]
    fn test_mean() {
        // Mean over all elements, over axis subsets, and over reversed/strided
        // views; expected values from NumPy (row-major) / Julia (col-major).
        #[cfg(not(feature = "col_major"))]
        {
            // DeviceCpuSerial
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);

            // DeviceFaer
            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);
        }
        #[cfg(feature = "col_major")]
        {
            // DeviceCpuSerial
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);

            // DeviceFaer
            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);
        }
    }
616
    #[test]
    fn test_var() {
        // Variance over all elements and per axis (expected values match
        // NumPy's default ddof=0 — TODO confirm against backend definition).
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));
    }
651
    #[test]
    fn test_std() {
        // Standard deviation for real and complex tensors; the complex std is
        // a real scalar (compared directly against an f64 reference).
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        // Complex input: real/imag parts zipped into Complex<f64>.
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);

        // DeviceFaer
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);
    }
712
    #[test]
    fn test_l2_norm() {
        // L2 (Frobenius) norm of a complex matrix; the result is a real scalar.
        // DeviceCpuSerial
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);

        // DeviceFaer
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);
    }
743
744    #[test]
745    #[cfg(feature = "rayon")]
746    fn test_large_std() {
747        #[cfg(not(feature = "col_major"))]
748        {
749            // a = np.linspace(0, 1, 1048576).reshape(16, 256, 256)
750            // b = np.linspace(1, 2, 1048576).reshape(16, 256, 256)
751            // c = a @ b
752            // print(c.mean(), c.std())
753            // print(c.std(axis=(0, 1))[[0, -1]])
754            // print(c.std(axis=(1, 2))[[0, -1]])
755            let a = linspace((0.0, 1.0, 1048576)).into_shape([16, 256, 256]);
756            let b = linspace((1.0, 2.0, 1048576)).into_shape([16, 256, 256]);
757            let c: Tensor<f64> = &a % &b;
758
759            let c_mean = c.mean_all();
760            println!("{c_mean:?}");
761            assert!((c_mean - 213.2503660477036) < 1e-6);
762
763            let c_std = c.std_all();
764            println!("{c_std:?}");
765            assert!((c_std - 148.88523481701804) < 1e-6);
766
767            let c_std_1 = c.std_axes((0, 1));
768            println!("{c_std_1}");
769            assert!(c_std_1[[0]] - 148.8763226818815 < 1e-6);
770            assert!(c_std_1[[255]] - 148.8941462322758 < 1e-6);
771
772            let c_std_2 = c.std_axes((1, 2));
773            println!("{c_std_2}");
774            assert!(c_std_2[[0]] - 4.763105902995575 < 1e-6);
775            assert!(c_std_2[[15]] - 9.093224903569157 < 1e-6);
776        }
777        #[cfg(feature = "col_major")]
778        {
779            // a = reshape(LinRange(0, 1, 1048576), (256, 256, 16));
780            // b = reshape(LinRange(1, 2, 1048576), (256, 256, 16));
781            // c = Array{Float64}(undef, 256, 256, 16);
782            // for i in 1:16
783            //     c[:, :, i] = a[:, :, i] * b[:, :, i]
784            // end
785            // mean(c), std(c)
786            // std(c, dims=(2, 3))
787            // std(c, dims=(1, 2))
788            let a = linspace((0.0, 1.0, 1048576)).into_shape([256, 256, 16]);
789            let b = linspace((1.0, 2.0, 1048576)).into_shape([256, 256, 16]);
790            let mut c: Tensor<f64> = zeros([256, 256, 16]);
791            for i in 0..16 {
792                c.i_mut((.., .., i)).assign(&a.i((.., .., i)) % &b.i((.., .., i)));
793            }
794
795            let c_mean = c.mean_all();
796            println!("{c_mean:?}");
797            assert!((c_mean - 213.25036604770355) < 1e-6);
798
799            let c_std = c.std_all();
800            println!("{c_std:?}");
801            assert!((c_std - 148.7419537312827) < 1e-6);
802
803            let c_std_1 = c.std_axes((1, 2));
804            println!("{c_std_1}");
805            assert!(c_std_1[[0]] - 148.75113653867191 < 1e-6);
806            assert!(c_std_1[[255]] - 148.7689445622776 < 1e-6);
807
808            let c_std_2 = c.std_axes((0, 1));
809            println!("{c_std_2}");
810            assert!(c_std_2[[0]] - 0.145530296246335 < 1e-6);
811            assert!(c_std_2[[15]] - 4.474611918106057 < 1e-6);
812        }
813    }
814
    #[test]
    fn test_unraveled_argmin() {
        // `unraveled_argmin` reports multi-dimensional indices (contrast with
        // the flat offsets returned by `argmin` in `test_argmin` below).
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        // [[ 8 4 2]
        //  [ 9 7 1]
        //  [ 2 1 8]
        //  [ 6 10 5]]

        // Global minimum `1` first occurs at row 1, column 2.
        let m = a.unraveled_argmin_all();
        println!("{m:?}");
        assert_eq!(m, vec![1, 2]);

        let m = a.unraveled_argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1], vec![2]]);

        let m = a.unraveled_argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1]]);
    }
840
    #[test]
    fn test_argmin() {
        // `argmin` reports positions as flat/per-axis offsets rather than
        // multi-dimensional indices.
        // DeviceCpuSerial
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        // [[ 8 4 2]
        //  [ 9 7 1]
        //  [ 2 1 8]
        //  [ 6 10 5]]

        // Global minimum `1` sits at flat offset 5 (= row 1, col 2).
        let m = a.argmin_all();
        println!("{m:?}");
        assert_eq!(m, 5);

        let m = a.argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1, 2]);

        let m = a.argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1]);
    }
866
867    #[test]
868    fn test_all() {
869        let a = asarray((vec![true, true, false, true, true, true], [2, 3].c()));
870        let a_all = a.all_axes(-2);
871        println!("{:?}", a_all);
872        assert_eq!(a_all.raw(), &[true, true, false]);
873    }
874
    #[test]
    fn test_allclose_cpu_serial() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

        // Mixed dtype (int vs f32) and mixed layout (C vs F) comparison:
        // close under default tolerances, not under 1e-8 tolerances.
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }
891
    #[test]
    #[cfg(feature = "faer")]
    fn test_allclose_faer() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

        // Same scenario as `test_allclose_cpu_serial`, on the Faer backend.
        let mut device = DeviceFaer::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }
909}