1use crate::prelude_dev::*;
2
/// Generates the public surface for one reduction operation.
///
/// Given a backend trait `$OpReduceAPI` and six identifiers, the expansion
/// defines six free functions plus six inherent methods of the same names on
/// `TensorAny`:
///
/// * `$fn_all_f` / `$fn_all` reduce the whole tensor to one `B::TOut` value.
/// * `$fn_axes_f` / `$fn_axes` reduce over the requested axes and return a
///   dynamic-dimension (`IxD`) tensor; `AxesIndex::None` produces a tensor
///   with empty shape holding the full reduction.
/// * `$fn_f` / `$fn` forward to the `_all` variants.
///
/// The `_f` variants return `Result`; the others call `rstsr_unwrap` on it.
macro_rules! trait_reduction {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Reduces every element of `tensor` to a single value, reporting
        /// backend failures as `Err`.
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let view = tensor.view();
            let device = view.device();
            device.$fn_all(view.raw(), view.layout())
        }

        /// Reduces `tensor` over `axes`, reporting failures as `Err`.
        ///
        /// `axes` is converted into an `AxesIndex<isize>` first; a failed
        /// conversion is returned as the error.
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Result<Tensor<B::TOut, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            let axes = axes.try_into().map_err(Into::into)?;
            let view = tensor.view();
            let device = view.device();

            if matches!(axes, AxesIndex::None) {
                // No axes given: reduce everything and wrap the single value
                // in a tensor whose shape and stride are both empty.
                let scalar = device.$fn_all(view.raw(), view.layout())?;
                let storage = device.outof_cpu_vec(vec![scalar])?;
                let layout = Layout::new(vec![], vec![], 0)?;
                Tensor::new_f(storage, layout)
            } else {
                // Explicit axes: the backend returns both storage and layout.
                let (storage, layout) = device.$fn_axes(view.raw(), view.layout(), axes.as_ref())?;
                Tensor::new_f(storage, layout)
            }
        }

        /// Reduces every element of `tensor` to a single value; unwraps the
        /// result of the fallible variant via `rstsr_unwrap`.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let outcome = $fn_all_f(tensor);
            outcome.rstsr_unwrap()
        }

        /// Reduces `tensor` over `axes`; unwraps the result of the fallible
        /// variant via `rstsr_unwrap`.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Tensor<B::TOut, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            let outcome = $fn_axes_f(tensor, axes);
            outcome.rstsr_unwrap()
        }

        /// Alias of the fallible whole-tensor reduction.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the unwrapping whole-tensor reduction.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Method forms: each one forwards to the free function of the same
        // name, passing `self` as the tensor argument.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            /// Method form of the fallible whole-tensor reduction.
            pub fn $fn_all_f(&self) -> Result<B::TOut> {
                $fn_all_f(self)
            }

            /// Method form of the unwrapping whole-tensor reduction.
            pub fn $fn_all(&self) -> B::TOut {
                $fn_all(self)
            }

            /// Method form of the fallible reduction over `axes`.
            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
            ) -> Result<Tensor<B::TOut, B, IxD>>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes_f(self, axes)
            }

            /// Method form of the unwrapping reduction over `axes`.
            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<B::TOut, B, IxD>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes(self, axes)
            }

            /// Method form of the fallible alias.
            pub fn $fn_f(&self) -> Result<B::TOut> {
                $fn_f(self)
            }

            /// Method form of the unwrapping alias.
            pub fn $fn(&self) -> B::TOut {
                $fn(self)
            }
        }
    };
}
115
// Instantiates the reduction API (free functions plus `TensorAny` methods)
// once per backend reduction trait. The module is marked `#[rustfmt::skip]`,
// so the invocations below keep their hand-written one-per-line layout.
#[rustfmt::skip]
mod impl_trait_reduction {
    use super::*;
    // Argument order for every invocation: backend trait, then the generated
    // names `fn`, `fn_f`, `fn_axes`, `fn_axes_f`, `fn_all`, `fn_all_f`.
    trait_reduction!(OpSumAPI, sum, sum_f, sum_axes, sum_axes_f, sum_all, sum_all_f);
    trait_reduction!(OpMinAPI, min, min_f, min_axes, min_axes_f, min_all, min_all_f);
    trait_reduction!(OpMaxAPI, max, max_f, max_axes, max_axes_f, max_all, max_all_f);
    trait_reduction!(OpProdAPI, prod, prod_f, prod_axes, prod_axes_f, prod_all, prod_all_f);
    trait_reduction!(OpMeanAPI, mean, mean_f, mean_axes, mean_axes_f, mean_all, mean_all_f);
    trait_reduction!(OpVarAPI, var, var_f, var_axes, var_axes_f, var_all, var_all_f);
    trait_reduction!(OpStdAPI, std, std_f, std_axes, std_axes_f, std_all, std_all_f);
    trait_reduction!(OpL2NormAPI, l2_norm, l2_norm_f, l2_norm_axes, l2_norm_axes_f, l2_norm_all, l2_norm_all_f);
    trait_reduction!(OpArgMinAPI, argmin, argmin_f, argmin_axes, argmin_axes_f, argmin_all, argmin_all_f);
    trait_reduction!(OpArgMaxAPI, argmax, argmax_f, argmax_axes, argmax_axes_f, argmax_all, argmax_all_f);
    trait_reduction!(OpAllAPI, all, all_f, all_axes, all_axes_f, all_all, all_all_f);
    trait_reduction!(OpAnyAPI, any, any_f, any_axes, any_axes_f, any_all, any_all_f);
    trait_reduction!(OpCountNonZeroAPI, count_nonzero, count_nonzero_f, count_nonzero_axes, count_nonzero_axes_f, count_nonzero_all, count_nonzero_all_f);
}
133pub use impl_trait_reduction::*;
134
/// Generates the public surface for a reduction whose whole-tensor result is
/// a value of the tensor's dimension type `D`.
///
/// The expansion mirrors `trait_reduction!` (six free functions plus six
/// inherent `TensorAny` methods), with these differences visible in the
/// signatures:
///
/// * the `_all` / alias variants return `D` instead of `B::TOut`;
/// * the `_axes` variants return a tensor whose element type is `IxD`, and
///   therefore additionally require `B: DeviceAPI<IxD> +
///   DeviceCreationAnyAPI<IxD>`.
macro_rules! trait_reduction_arg {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        /// Runs the reduction over the whole tensor, reporting backend
        /// failures as `Err`.
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let view = tensor.view();
            let device = view.device();
            device.$fn_all(view.raw(), view.layout())
        }

        /// Runs the reduction over `axes`, reporting failures as `Err`.
        ///
        /// `axes` is converted into an `AxesIndex<isize>` first; a failed
        /// conversion is returned as the error.
        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Result<Tensor<IxD, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
        {
            let axes = axes.try_into().map_err(Into::into)?;
            let view = tensor.view();
            let device = view.device();

            if matches!(axes, AxesIndex::None) {
                // No axes given: take the whole-tensor result, convert it from
                // `D` with `into()`, and store it in a tensor with empty shape.
                let index = device.$fn_all(view.raw(), view.layout())?;
                let storage = device.outof_cpu_vec(vec![index.into()])?;
                let layout = Layout::new(vec![], vec![], 0)?;
                Tensor::new_f(storage, layout)
            } else {
                // Explicit axes: the backend returns both storage and layout.
                let (storage, layout) = device.$fn_axes(view.raw(), view.layout(), axes.as_ref())?;
                Tensor::new_f(storage, layout)
            }
        }

        /// Runs the reduction over the whole tensor; unwraps the result of
        /// the fallible variant via `rstsr_unwrap`.
        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let outcome = $fn_all_f(tensor);
            outcome.rstsr_unwrap()
        }

        /// Runs the reduction over `axes`; unwraps the result of the fallible
        /// variant via `rstsr_unwrap`.
        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ) -> Tensor<IxD, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
        {
            let outcome = $fn_axes_f(tensor, axes);
            outcome.rstsr_unwrap()
        }

        /// Alias of the fallible whole-tensor variant.
        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        /// Alias of the unwrapping whole-tensor variant.
        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        // Method forms: each one forwards to the free function of the same
        // name, passing `self` as the tensor argument.
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            /// Method form of the fallible whole-tensor variant.
            pub fn $fn_all_f(&self) -> Result<D> {
                $fn_all_f(self)
            }

            /// Method form of the unwrapping whole-tensor variant.
            pub fn $fn_all(&self) -> D {
                $fn_all(self)
            }

            /// Method form of the fallible variant over `axes`.
            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
            ) -> Result<Tensor<IxD, B, IxD>>
            where
                B: DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
            {
                $fn_axes_f(self, axes)
            }

            /// Method form of the unwrapping variant over `axes`.
            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<IxD, B, IxD>
            where
                B: DeviceAPI<IxD> + DeviceCreationAnyAPI<IxD>,
            {
                $fn_axes(self, axes)
            }

            /// Method form of the fallible alias.
            pub fn $fn_f(&self) -> Result<D> {
                $fn_f(self)
            }

            /// Method form of the unwrapping alias.
            pub fn $fn(&self) -> D {
                $fn(self)
            }
        }
    };
}
248
// Generates `unraveled_argmin*` free functions and `TensorAny` methods backed
// by `OpUnraveledArgMinAPI`; whole-tensor results are of the dimension type.
trait_reduction_arg!(
    OpUnraveledArgMinAPI,
    unraveled_argmin,
    unraveled_argmin_f,
    unraveled_argmin_axes,
    unraveled_argmin_axes_f,
    unraveled_argmin_all,
    unraveled_argmin_all_f
);
// Same expansion for `unraveled_argmax*`, backed by `OpUnraveledArgMaxAPI`.
trait_reduction_arg!(
    OpUnraveledArgMaxAPI,
    unraveled_argmax,
    unraveled_argmax_f,
    unraveled_argmax_axes,
    unraveled_argmax_axes_f,
    unraveled_argmax_all,
    unraveled_argmax_all_f
);
267
268pub trait TensorSumBoolAPI<B, D>
271where
272 D: DimAPI,
273 B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D>,
274{
275 fn sum_all_f(&self) -> Result<usize>;
276 fn sum_all(&self) -> usize {
277 self.sum_all_f().rstsr_unwrap()
278 }
279 fn sum_f(&self) -> Result<usize> {
280 self.sum_all_f()
281 }
282 fn sum(&self) -> usize {
283 self.sum_f().rstsr_unwrap()
284 }
285 fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Result<Tensor<usize, B, IxD>>;
286 fn sum_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Tensor<usize, B, IxD> {
287 self.sum_axes_f(axes).rstsr_unwrap()
288 }
289}
290
291impl<R, B, D> TensorSumBoolAPI<B, D> for TensorAny<R, bool, B, D>
292where
293 R: DataAPI<Data = <B as DeviceRawAPI<bool>>::Raw>,
294 D: DimAPI,
295 B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D> + DeviceCreationAnyAPI<usize>,
296{
297 fn sum_all_f(&self) -> Result<usize> {
298 self.device().sum_all(self.raw(), self.layout())
299 }
300
301 fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Result<Tensor<usize, B, IxD>> {
302 let axes = axes.try_into().map_err(Into::into)?;
303
304 match axes {
305 AxesIndex::None => {
306 let sum = self.device().sum_all(self.raw(), self.layout())?;
308 let storage = self.device().outof_cpu_vec(vec![sum])?;
309 let layout = Layout::new(vec![], vec![], 0)?;
310 Tensor::new_f(storage, layout)
311 },
312 _ => {
313 let (storage, layout) = self.device().sum_axes(self.raw(), self.layout(), axes.as_ref())?;
314 Tensor::new_f(storage, layout)
315 },
316 }
317 }
318}
319
320pub fn allclose_all_f<TA, TB, TE, B, DA, DB>(
325 tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
326 tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
327 isclose_args: impl Into<IsCloseArgs<TE>>,
328) -> Result<bool>
329where
330 DA: DimAPI,
331 DB: DimAPI,
332 B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
333 TE: 'static,
334{
335 let tensor_a = tensor_a.view();
336 let tensor_b = tensor_b.view();
337 let isclose_args = isclose_args.into();
338 let device = tensor_a.device();
339
340 rstsr_assert!(tensor_a.device().same_device(tensor_b.device()), DeviceMismatch)?;
342
343 let la = tensor_a.layout().to_dim::<IxD>()?;
346 let lb = tensor_b.layout().to_dim::<IxD>()?;
347 let default_order = device.default_order();
348 let (la_b, lb_b) = broadcast_layout(&la, &lb, default_order)?;
349
350 device.allclose_all(tensor_a.raw(), &la_b, tensor_b.raw(), &lb_b, &isclose_args)
351}
352
353pub fn allclose_all<TA, TB, TE, B, DA, DB>(
354 tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
355 tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
356 isclose_args: impl Into<IsCloseArgs<TE>>,
357) -> bool
358where
359 DA: DimAPI,
360 DB: DimAPI,
361 B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
362 TE: 'static,
363{
364 allclose_all_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
365}
366
367pub fn allclose_f<TA, TB, TE, B, DA, DB>(
368 tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
369 tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
370 isclose_args: impl Into<IsCloseArgs<TE>>,
371) -> Result<bool>
372where
373 DA: DimAPI,
374 DB: DimAPI,
375 B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
376 TE: 'static,
377{
378 allclose_all_f(tensor_a, tensor_b, isclose_args)
379}
380
381pub fn allclose<TA, TB, TE, B, DA, DB>(
382 tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
383 tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
384 isclose_args: impl Into<IsCloseArgs<TE>>,
385) -> bool
386where
387 DA: DimAPI,
388 DB: DimAPI,
389 B: DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
390 TE: 'static,
391{
392 allclose_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
393}
394
/// Expands to a call of the `allclose` function.
///
/// Two forms are accepted:
///
/// * three expressions: the two tensors followed by the closeness arguments;
/// * two expressions: the two tensors only, in which case `None` is passed
///   as the closeness arguments.
///
/// The expansion imports `rstsr::prelude::rstsr_funcs::allclose`, so that
/// path must resolve at the call site.
#[macro_export]
macro_rules! allclose {
    ($lhs:expr, $rhs:expr, $args:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($lhs, $rhs, $args)
    }};
    ($lhs:expr, $rhs:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($lhs, $rhs, None)
    }};
}
406
407#[cfg(test)]
410mod test {
411 use num::ToPrimitive;
412
413 use super::*;
414
415 #[test]
416 #[cfg(not(feature = "col_major"))]
417 fn test_sum_all_row_major() {
418 let a = arange((24, &DeviceCpuSerial::default()));
420 let s = sum_all(&a);
421 assert_eq!(s, 276);
422
423 let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
426 let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
427 let s = a.sum_all();
428 assert_eq!(s, 446586);
429
430 let s = a.sum_axes(None);
431 println!("{s:?}");
432 assert_eq!(s.to_scalar(), 446586);
433
434 let a = arange(24);
436 let s = sum_all(&a);
437 assert_eq!(s, 276);
438
439 let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
442 let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
443 let s = a.sum_all();
444 assert_eq!(s, 446586);
445
446 let s = a.sum_axes(None);
447 println!("{s:?}");
448 assert_eq!(s.to_scalar(), 446586);
449 }
450
451 #[test]
452 #[cfg(feature = "col_major")]
453 fn test_sum_all_col_major() {
454 let a = arange((24, &DeviceCpuSerial::default()));
456 let s = sum_all(&a);
457 assert_eq!(s, 276);
458
459 let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
464 let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
465 let s = a.sum_all();
466 assert_eq!(s, 403662);
467
468 let s = a.sum_axes(None);
469 println!("{s:?}");
470 assert_eq!(s.to_scalar(), 403662);
471
472 let a = arange(24);
474 let s = sum_all(&a);
475 assert_eq!(s, 276);
476
477 let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
478 let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
479 let s = a.sum_all();
480 assert_eq!(s, 403662);
481
482 let s = a.sum_axes(None);
483 println!("{s:?}");
484 assert_eq!(s.to_scalar(), 403662);
485 }
486
487 #[test]
488 fn test_sum_axes() {
489 #[cfg(not(feature = "col_major"))]
490 {
491 let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
495 let s = a.sum_axes([0, -2]);
496 println!("{s:?}");
497 assert_eq!(s[[0, 1]], 27270);
498 assert_eq!(s[[1, 2]], 154845);
499 assert_eq!(s[[3, 5]], 428220);
500
501 let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
503 let s = a.sum_axes([0, -2]);
504 println!("{s:?}");
505 assert_eq!(s[[0, 1]], 27270);
506 assert_eq!(s[[1, 2]], 154845);
507 assert_eq!(s[[3, 5]], 428220);
508 }
509 #[cfg(feature = "col_major")]
510 {
511 let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
516 let s = a.sum_axes([0, -2]);
517 println!("{s:?}");
518 assert_eq!(s[[0, 1]], 217620);
519 assert_eq!(s[[1, 2]], 218295);
520 assert_eq!(s[[3, 5]], 220185);
521
522 let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
524 let s = a.sum_axes([0, -2]);
525 println!("{s:?}");
526 assert_eq!(s[[0, 1]], 217620);
527 assert_eq!(s[[1, 2]], 218295);
528 assert_eq!(s[[3, 5]], 220185);
529 }
530 }
531
532 #[test]
533 fn test_min() {
534 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
536 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
537 println!("{a:}");
538 let m = a.min_axes(0);
539 assert_eq!(m.to_vec(), vec![2, 3, 1]);
540 let m = a.min_axes(1);
541 assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
542 let m = a.min_all();
543 assert_eq!(m, 1);
544
545 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
547 let a = asarray((&v, [4, 3].c()));
548 println!("{a:}");
549 let m = a.min_axes(0);
550 assert_eq!(m.to_vec(), vec![2, 3, 1]);
551 let m = a.min_axes(1);
552 assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
553 let m = a.min_all();
554 assert_eq!(m, 1);
555 }
556
557 #[test]
558 fn test_mean() {
559 #[cfg(not(feature = "col_major"))]
560 {
561 let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
563 let m = a.mean_all();
564 assert_eq!(m, 11.5);
565
566 let m = a.mean_axes((0, 2));
567 println!("{m:}");
568 assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);
569
570 let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
571 println!("{m:}");
572 assert_eq!(m.to_vec(), vec![18.0, 6.0]);
573
574 let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
576 let m = a.mean_all();
577 assert_eq!(m, 11.5);
578
579 let m = a.mean_axes((0, 2));
580 println!("{m:}");
581 assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);
582
583 let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
584 println!("{m:}");
585 assert_eq!(m.to_vec(), vec![18.0, 6.0]);
586 }
587 #[cfg(feature = "col_major")]
588 {
589 let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
591 let m = a.mean_all();
592 assert_eq!(m, 11.5);
593
594 let m = a.mean_axes((0, 2));
595 println!("{m:}");
596 assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);
597
598 let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
599 println!("{m:}");
600 assert_eq!(m.to_vec(), vec![15.0, 14.0]);
601
602 let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
604 let m = a.mean_all();
605 assert_eq!(m, 11.5);
606
607 let m = a.mean_axes((0, 2));
608 println!("{m:}");
609 assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);
610
611 let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
612 println!("{m:}");
613 assert_eq!(m.to_vec(), vec![15.0, 14.0]);
614 }
615 }
616
617 #[test]
618 fn test_var() {
619 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
621 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);
622
623 let m = a.var_all();
624 println!("{m:}");
625 assert!((m - 8.409722222222221).abs() < 1e-10);
626
627 let m = a.var_axes(0);
628 println!("{m:}");
629 assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));
630
631 let m = a.var_axes(1);
632 println!("{m:}");
633 assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));
634
635 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
637 let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);
638
639 let m = a.var_all();
640 println!("{m:}");
641 assert!((m - 8.409722222222221).abs() < 1e-10);
642
643 let m = a.var_axes(0);
644 println!("{m:}");
645 assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));
646
647 let m = a.var_axes(1);
648 println!("{m:}");
649 assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));
650 }
651
652 #[test]
653 fn test_std() {
654 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
656 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);
657
658 let m = a.std_all();
659 println!("{m:}");
660 assert!((m - 2.899952106884219).abs() < 1e-10);
661
662 let m = a.std_axes(0);
663 println!("{m:}");
664 assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));
665
666 let m = a.std_axes(1);
667 println!("{m:}");
668 assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));
669
670 let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
671 let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
672 let v = vr
673 .iter()
674 .zip(vi.iter())
675 .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
676 .collect::<Vec<_>>();
677 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
678
679 let m = a.std_all();
680 println!("{m:}");
681 assert!((m - 4.508479664907993).abs() < 1e-10);
682
683 let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
685 let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);
686
687 let m = a.std_all();
688 println!("{m:}");
689 assert!((m - 2.899952106884219).abs() < 1e-10);
690
691 let m = a.std_axes(0);
692 println!("{m:}");
693 assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));
694
695 let m = a.std_axes(1);
696 println!("{m:}");
697 assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));
698
699 let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
700 let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
701 let v = vr
702 .iter()
703 .zip(vi.iter())
704 .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
705 .collect::<Vec<_>>();
706 let a = asarray((&v, [4, 3].c()));
707
708 let m = a.std_all();
709 println!("{m:}");
710 assert!((m - 4.508479664907993).abs() < 1e-10);
711 }
712
713 #[test]
714 fn test_l2_norm() {
715 let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
717 let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
718 let v = vr
719 .iter()
720 .zip(vi.iter())
721 .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
722 .collect::<Vec<_>>();
723 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
724
725 let m = a.l2_norm_all();
726 println!("{m:}");
727 assert!((m - 33.21144381083123).abs() < 1e-10);
728
729 let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
731 let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
732 let v = vr
733 .iter()
734 .zip(vi.iter())
735 .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
736 .collect::<Vec<_>>();
737 let a = asarray((&v, [4, 3].c()));
738
739 let m = a.l2_norm_all();
740 println!("{m:}");
741 assert!((m - 33.21144381083123).abs() < 1e-10);
742 }
743
744 #[test]
745 #[cfg(feature = "rayon")]
746 fn test_large_std() {
747 #[cfg(not(feature = "col_major"))]
748 {
749 let a = linspace((0.0, 1.0, 1048576)).into_shape([16, 256, 256]);
756 let b = linspace((1.0, 2.0, 1048576)).into_shape([16, 256, 256]);
757 let c: Tensor<f64> = &a % &b;
758
759 let c_mean = c.mean_all();
760 println!("{c_mean:?}");
761 assert!((c_mean - 213.2503660477036) < 1e-6);
762
763 let c_std = c.std_all();
764 println!("{c_std:?}");
765 assert!((c_std - 148.88523481701804) < 1e-6);
766
767 let c_std_1 = c.std_axes((0, 1));
768 println!("{c_std_1}");
769 assert!(c_std_1[[0]] - 148.8763226818815 < 1e-6);
770 assert!(c_std_1[[255]] - 148.8941462322758 < 1e-6);
771
772 let c_std_2 = c.std_axes((1, 2));
773 println!("{c_std_2}");
774 assert!(c_std_2[[0]] - 4.763105902995575 < 1e-6);
775 assert!(c_std_2[[15]] - 9.093224903569157 < 1e-6);
776 }
777 #[cfg(feature = "col_major")]
778 {
779 let a = linspace((0.0, 1.0, 1048576)).into_shape([256, 256, 16]);
789 let b = linspace((1.0, 2.0, 1048576)).into_shape([256, 256, 16]);
790 let mut c: Tensor<f64> = zeros([256, 256, 16]);
791 for i in 0..16 {
792 c.i_mut((.., .., i)).assign(&a.i((.., .., i)) % &b.i((.., .., i)));
793 }
794
795 let c_mean = c.mean_all();
796 println!("{c_mean:?}");
797 assert!((c_mean - 213.25036604770355) < 1e-6);
798
799 let c_std = c.std_all();
800 println!("{c_std:?}");
801 assert!((c_std - 148.7419537312827) < 1e-6);
802
803 let c_std_1 = c.std_axes((1, 2));
804 println!("{c_std_1}");
805 assert!(c_std_1[[0]] - 148.75113653867191 < 1e-6);
806 assert!(c_std_1[[255]] - 148.7689445622776 < 1e-6);
807
808 let c_std_2 = c.std_axes((0, 1));
809 println!("{c_std_2}");
810 assert!(c_std_2[[0]] - 0.145530296246335 < 1e-6);
811 assert!(c_std_2[[15]] - 4.474611918106057 < 1e-6);
812 }
813 }
814
815 #[test]
816 fn test_unraveled_argmin() {
817 let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
819 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
820 println!("{a:}");
821 let m = a.unraveled_argmin_all();
827 println!("{m:?}");
828 assert_eq!(m, vec![1, 2]);
829
830 let m = a.unraveled_argmin_axes(-1);
831 println!("{m:?}");
832 let m_vec = m.raw();
833 assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1], vec![2]]);
834
835 let m = a.unraveled_argmin_axes(0);
836 println!("{m:?}");
837 let m_vec = m.raw();
838 assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1]]);
839 }
840
841 #[test]
842 fn test_argmin() {
843 let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
845 let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
846 println!("{a:}");
847 let m = a.argmin_all();
853 println!("{m:?}");
854 assert_eq!(m, 5);
855
856 let m = a.argmin_axes(-1);
857 println!("{m:?}");
858 let m_vec = m.raw();
859 assert_eq!(m_vec, &vec![2, 2, 1, 2]);
860
861 let m = a.argmin_axes(0);
862 println!("{m:?}");
863 let m_vec = m.raw();
864 assert_eq!(m_vec, &vec![2, 2, 1]);
865 }
866
867 #[test]
868 fn test_all() {
869 let a = asarray((vec![true, true, false, true, true, true], [2, 3].c()));
870 let a_all = a.all_axes(-2);
871 println!("{:?}", a_all);
872 assert_eq!(a_all.raw(), &[true, true, false]);
873 }
874
875 #[test]
876 fn test_allclose_cpu_serial() {
877 use rstsr_dtype_traits::IsCloseArgsBuilder;
878
879 let mut device = DeviceCpuSerial::default();
880 device.set_default_order(RowMajor);
881 let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
882 let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
883 let result = allclose(&a, &b, None);
884 println!("Allclose result: {result}");
885 assert!(result);
886 let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
887 let result = allclose(&a, &b, args);
888 println!("Allclose result with tight args: {result}");
889 assert!(!result);
890 }
891
892 #[test]
893 #[cfg(feature = "faer")]
894 fn test_allclose_faer() {
895 use rstsr_dtype_traits::IsCloseArgsBuilder;
896
897 let mut device = DeviceFaer::default();
898 device.set_default_order(RowMajor);
899 let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
900 let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
901 let result = allclose(&a, &b, None);
902 println!("Allclose result: {result}");
903 assert!(result);
904 let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
905 let result = allclose(&a, &b, args);
906 println!("Allclose result with tight args: {result}");
907 assert!(!result);
908 }
909}