use crate::prelude_dev::*;

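// Generates the full reduction API for a given reduction-op trait. For each
// reduction it emits a family of free functions and inherent methods:
//   - `*_all` / `*_all_f`: reduce over all elements to a scalar;
//   - `*_axes` / `*_axes_f`: reduce over the given axes, returning a tensor;
//   - bare `$fn` / `*_f`: aliases for the all-element reduction.
// The `_f` variants are fallible (return `Result`); the others unwrap.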
macro_rules! trait_reduction {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Result<Tensor<B::TOut, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            let axes = axes.try_into()?;
            let tensor = tensor.view();

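            // An empty axes list means "reduce over everything": compute the
            // full reduction and wrap the scalar in a zero-dimensional tensor.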
            if axes.as_ref().is_empty() {
                let sum = tensor.device().$fn_all(tensor.raw(), tensor.layout())?;
                let storage = tensor.device().outof_cpu_vec(vec![sum])?;
                let layout = Layout::new(vec![], vec![], 0)?;
                return Tensor::new_f(storage, layout);
            }

            let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
            Tensor::new_f(storage, layout)
        }

        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Tensor<B::TOut, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<B::TOut>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> B::TOut
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

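        // The same functions are exposed as inherent methods, so reductions
        // can be called as `tensor.sum_all()` as well as `sum_all(&tensor)`.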
        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<B::TOut> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> B::TOut {
                $fn_all(self)
            }

            pub fn $fn_axes_f(
                &self,
                axes: impl TryInto<AxesIndex<isize>, Error = Error>,
            ) -> Result<Tensor<B::TOut, B, IxD>>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<B::TOut, B, IxD>
            where
                B: DeviceCreationAnyAPI<B::TOut>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<B::TOut> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> B::TOut {
                $fn(self)
            }
        }
    };
}

#[rustfmt::skip]
mod impl_trait_reduction {
    use super::*;
    trait_reduction!(OpSumAPI, sum, sum_f, sum_axes, sum_axes_f, sum_all, sum_all_f);
    trait_reduction!(OpMinAPI, min, min_f, min_axes, min_axes_f, min_all, min_all_f);
    trait_reduction!(OpMaxAPI, max, max_f, max_axes, max_axes_f, max_all, max_all_f);
    trait_reduction!(OpProdAPI, prod, prod_f, prod_axes, prod_axes_f, prod_all, prod_all_f);
    trait_reduction!(OpMeanAPI, mean, mean_f, mean_axes, mean_axes_f, mean_all, mean_all_f);
    trait_reduction!(OpVarAPI, var, var_f, var_axes, var_axes_f, var_all, var_all_f);
    trait_reduction!(OpStdAPI, std, std_f, std_axes, std_axes_f, std_all, std_all_f);
    trait_reduction!(OpL2NormAPI, l2_norm, l2_norm_f, l2_norm_axes, l2_norm_axes_f, l2_norm_all, l2_norm_all_f);
    trait_reduction!(OpArgMinAPI, argmin, argmin_f, argmin_axes, argmin_axes_f, argmin_all, argmin_all_f);
    trait_reduction!(OpArgMaxAPI, argmax, argmax_f, argmax_axes, argmax_axes_f, argmax_all, argmax_all_f);
    trait_reduction!(OpAllAPI, all, all_f, all_axes, all_axes_f, all_all, all_all_f);
    trait_reduction!(OpAnyAPI, any, any_f, any_axes, any_axes_f, any_all, any_all_f);
    trait_reduction!(OpCountNonZeroAPI, count_nonzero, count_nonzero_f, count_nonzero_axes, count_nonzero_axes_f, count_nonzero_all, count_nonzero_all_f);
}
pub use impl_trait_reduction::*;
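
// Usage sketch for the generated API (values assume the default row-major
// order; see the tests below for exact behavior):
//     let a: Tensor<usize> = arange(6).into_shape([2, 3]);
//     let total = a.sum_all();      // 15
//     let rows = a.sum_axes(-1);    // [3, 12]
//     let zero_d = a.sum_axes(());  // 0-D tensor; `.to_scalar()` gives 15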
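// Same pattern as `trait_reduction!`, but for reductions that return
// unraveled (multi-dimensional) indices: the all-element variants yield an
// index of the tensor's own dimension type `D`, while the per-axes variants
// yield tensors with dynamic-dimension index elements (`IxD`). Unlike
// `trait_reduction!`, there is no empty-axes fallback here.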
macro_rules! trait_reduction_arg {
    ($OpReduceAPI: ident, $fn: ident, $fn_f: ident, $fn_axes: ident, $fn_axes_f: ident, $fn_all: ident, $fn_all_f: ident) => {
        pub fn $fn_all_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            let tensor = tensor.view();
            tensor.device().$fn_all(tensor.raw(), tensor.layout())
        }

        pub fn $fn_axes_f<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Result<Tensor<IxD, B, IxD>>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD>,
        {
            let axes = axes.try_into()?;
            let tensor = tensor.view();

            let (storage, layout) = tensor.device().$fn_axes(tensor.raw(), tensor.layout(), axes.as_ref())?;
            Tensor::new_f(storage, layout)
        }

        pub fn $fn_all<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor).rstsr_unwrap()
        }

        pub fn $fn_axes<T, B, D>(
            tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>,
            axes: impl TryInto<AxesIndex<isize>, Error = Error>,
        ) -> Tensor<IxD, B, IxD>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D> + DeviceAPI<IxD>,
        {
            $fn_axes_f(tensor, axes).rstsr_unwrap()
        }

        pub fn $fn_f<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> Result<D>
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all_f(tensor)
        }

        pub fn $fn<T, B, D>(tensor: impl TensorViewAPI<Type = T, Backend = B, Dim = D>) -> D
        where
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            $fn_all(tensor)
        }

        impl<R, T, B, D> TensorAny<R, T, B, D>
        where
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: $OpReduceAPI<T, D>,
        {
            pub fn $fn_all_f(&self) -> Result<D> {
                $fn_all_f(self)
            }

            pub fn $fn_all(&self) -> D {
                $fn_all(self)
            }

            pub fn $fn_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<IxD, B, IxD>>
            where
                B: DeviceAPI<IxD>,
            {
                $fn_axes_f(self, axes)
            }

            pub fn $fn_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<IxD, B, IxD>
            where
                B: DeviceAPI<IxD>,
            {
                $fn_axes(self, axes)
            }

            pub fn $fn_f(&self) -> Result<D> {
                $fn_f(self)
            }

            pub fn $fn(&self) -> D {
                $fn(self)
            }
        }
    };
}

trait_reduction_arg!(
    OpUnraveledArgMinAPI,
    unraveled_argmin,
    unraveled_argmin_f,
    unraveled_argmin_axes,
    unraveled_argmin_axes_f,
    unraveled_argmin_all,
    unraveled_argmin_all_f
);
trait_reduction_arg!(
    OpUnraveledArgMaxAPI,
    unraveled_argmax,
    unraveled_argmax_f,
    unraveled_argmax_axes,
    unraveled_argmax_axes_f,
    unraveled_argmax_all,
    unraveled_argmax_all_f
);

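// Summing a `bool` tensor counts its `true` elements, so the boolean case
// gets its own trait with `usize` results instead of the generic `B::TOut`.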
pub trait TensorSumBoolAPI<B, D>
where
    D: DimAPI,
    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D>,
{
    fn sum_all_f(&self) -> Result<usize>;
    fn sum_all(&self) -> usize {
        self.sum_all_f().rstsr_unwrap()
    }
    fn sum_f(&self) -> Result<usize> {
        self.sum_all_f()
    }
    fn sum(&self) -> usize {
        self.sum_f().rstsr_unwrap()
    }
    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<usize, B, IxD>>;
    fn sum_axes(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Tensor<usize, B, IxD> {
        self.sum_axes_f(axes).rstsr_unwrap()
    }
}

impl<R, B, D> TensorSumBoolAPI<B, D> for TensorAny<R, bool, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<bool>>::Raw>,
    D: DimAPI,
    B: DeviceAPI<bool> + DeviceAPI<usize> + OpSumBoolAPI<D> + DeviceCreationAnyAPI<usize>,
{
    fn sum_all_f(&self) -> Result<usize> {
        self.device().sum_all(self.raw(), self.layout())
    }

    fn sum_axes_f(&self, axes: impl TryInto<AxesIndex<isize>, Error = Error>) -> Result<Tensor<usize, B, IxD>> {
        let axes = axes.try_into()?;

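        // As in the generic reduction macro, an empty axes list reduces over
        // all elements and wraps the result in a zero-dimensional tensor.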
        if axes.as_ref().is_empty() {
            let sum = self.device().sum_all(self.raw(), self.layout())?;
            let storage = self.device().outof_cpu_vec(vec![sum])?;
            let layout = Layout::new(vec![], vec![], 0)?;
            return Tensor::new_f(storage, layout);
        }

        let (storage, layout) = self.device().sum_axes(self.raw(), self.layout(), axes.as_ref())?;
        Tensor::new_f(storage, layout)
    }
}

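// Elementwise closeness between two tensors: the layouts are broadcast
// against each other, and both tensors must live on the same device.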
pub fn allclose_all_f<TA, TB, TE, B, DA, DB>(
    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    isclose_args: impl Into<IsCloseArgs<TE>>,
) -> Result<bool>
where
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
    TE: 'static,
{
    let tensor_a = tensor_a.view();
    let tensor_b = tensor_b.view();
    let isclose_args = isclose_args.into();
    let device = tensor_a.device();

    rstsr_assert!(tensor_a.device().same_device(tensor_b.device()), DeviceMismatch)?;

    let la = tensor_a.layout().to_dim::<IxD>()?;
    let lb = tensor_b.layout().to_dim::<IxD>()?;
    let default_order = device.default_order();
    let (la_b, lb_b) = broadcast_layout(&la, &lb, default_order)?;

    device.allclose_all(tensor_a.raw(), &la_b, tensor_b.raw(), &lb_b, &isclose_args)
}

pub fn allclose_all<TA, TB, TE, B, DA, DB>(
    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    isclose_args: impl Into<IsCloseArgs<TE>>,
) -> bool
where
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
    TE: 'static,
{
    allclose_all_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
}

pub fn allclose_f<TA, TB, TE, B, DA, DB>(
    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    isclose_args: impl Into<IsCloseArgs<TE>>,
) -> Result<bool>
where
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
    TE: 'static,
{
    allclose_all_f(tensor_a, tensor_b, isclose_args)
}

pub fn allclose<TA, TB, TE, B, DA, DB>(
    tensor_a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    tensor_b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    isclose_args: impl Into<IsCloseArgs<TE>>,
) -> bool
where
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool> + OpAllCloseAPI<TA, TB, TE, IxD>,
    TE: 'static,
{
    allclose_f(tensor_a, tensor_b, isclose_args).rstsr_unwrap()
}

#[macro_export]
macro_rules! allclose {
    ($tensor_a:expr, $tensor_b:expr, $isclose_args:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, $isclose_args)
    }};
    ($tensor_a:expr, $tensor_b:expr) => {{
        use rstsr::prelude::rstsr_funcs::allclose;
        allclose($tensor_a, $tensor_b, None)
    }};
}
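
// Usage sketch (tolerance arguments are illustrative):
//     assert!(allclose!(&a, &b));       // default tolerances
//     assert!(allclose!(&a, &b, args)); // explicit `IsCloseArgs`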

#[cfg(test)]
mod test {
    use num::ToPrimitive;

    use super::*;

    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_sum_all_row_major() {
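        // Sum of 0..24 is 276; the sliced cases below exercise reduction over
        // non-contiguous (swapped-axes, negative-stride) views.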
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);

        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 446586);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 446586);
    }

    #[test]
    #[cfg(feature = "col_major")]
    fn test_sum_all_col_major() {
        let a = arange((24, &DeviceCpuSerial::default()));
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned = arange((3240, &DeviceCpuSerial::default())).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);

        let a = arange(24);
        let s = sum_all(&a);
        assert_eq!(s, 276);

        let a_owned: Tensor<usize> = arange(3240).into_shape([12, 15, 18]).into_swapaxes(-1, -2);
        let a = a_owned.i((slice!(2, -3), slice!(1, -4, 2), slice!(-1, 3, -2)));
        let s = a.sum_all();
        assert_eq!(s, 403662);

        let s = a.sum_axes(());
        println!("{s:?}");
        assert_eq!(s.to_scalar(), 403662);
    }

    #[test]
    fn test_sum_axes() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);

            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 27270);
            assert_eq!(s[[1, 2]], 154845);
            assert_eq!(s[[3, 5]], 428220);
        }
        #[cfg(feature = "col_major")]
        {
            let a = arange((3240, &DeviceCpuSerial::default())).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);

            let a: Tensor<usize> = arange(3240).into_shape([4, 6, 15, 9]).into_transpose([2, 0, 3, 1]);
            let s = a.sum_axes([0, -2]);
            println!("{s:?}");
            assert_eq!(s[[0, 1]], 217620);
            assert_eq!(s[[1, 2]], 218295);
            assert_eq!(s[[3, 5]], 220185);
        }
    }

    #[test]
    fn test_min() {
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);

        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c()));
        println!("{a:}");
        let m = a.min_axes(0);
        assert_eq!(m.to_vec(), vec![2, 3, 1]);
        let m = a.min_axes(1);
        assert_eq!(m.to_vec(), vec![2, 3, 1, 5]);
        let m = a.min_all();
        assert_eq!(m, 1);
    }

    #[test]
    fn test_mean() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);

            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![7.5, 11.5, 15.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![18.0, 6.0]);
        }
        #[cfg(feature = "col_major")]
        {
            let a = arange((24.0, &DeviceCpuSerial::default())).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);

            let a: Tensor<f64> = arange(24.0).into_shape((2, 3, 4));
            let m = a.mean_all();
            assert_eq!(m, 11.5);

            let m = a.mean_axes((0, 2));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![9.5, 11.5, 13.5]);

            let m = a.i((slice!(None, None, -1), .., slice!(None, None, -2))).mean_axes((-1, 1));
            println!("{m:}");
            assert_eq!(m.to_vec(), vec![15.0, 14.0]);
        }
    }

    #[test]
    fn test_var() {
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));

        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.var_all();
        println!("{m:}");
        assert!((m - 8.409722222222221).abs() < 1e-10);

        let m = a.var_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![7.1875, 8.1875, 5.6875])));

        let m = a.var_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![6.22222222, 6.22222222, 9.55555556, 4.66666667])));
    }

    #[test]
    fn test_std() {
        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);

        let v = vec![8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let a = asarray((&v, [4, 3].c())).mapv(|x| x as f64);

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 2.899952106884219).abs() < 1e-10);

        let m = a.std_axes(0);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.68095132, 2.86138079, 2.384848])));

        let m = a.std_axes(1);
        println!("{m:}");
        assert!(allclose_f64(&m, &asarray(vec![2.49443826, 2.49443826, 3.09120617, 2.1602469])));

        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.std_all();
        println!("{m:}");
        assert!((m - 4.508479664907993).abs() < 1e-10);
    }

    #[test]
    fn test_l2_norm() {
        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);

        let vr = [8, 4, 2, 9, 3, 7, 2, 8, 1, 6, 10, 5];
        let vi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let v = vr
            .iter()
            .zip(vi.iter())
            .map(|(r, i)| num::Complex::new(r.to_f64().unwrap(), i.to_f64().unwrap()))
            .collect::<Vec<_>>();
        let a = asarray((&v, [4, 3].c()));

        let m = a.l2_norm_all();
        println!("{m:}");
        assert!((m - 33.21144381083123).abs() < 1e-10);
    }

    #[test]
    #[cfg(feature = "rayon")]
    fn test_large_std() {
        #[cfg(not(feature = "col_major"))]
        {
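            // `%` is rstsr's matrix-multiplication operator, so `c` holds
            // batched 256 x 256 matrix products over the 16 leading indices.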
            let a = linspace((0.0, 1.0, 1048576)).into_shape([16, 256, 256]);
            let b = linspace((1.0, 2.0, 1048576)).into_shape([16, 256, 256]);
            let c: Tensor<f64> = &a % &b;

            let c_mean = c.mean_all();
            println!("{c_mean:?}");
            assert!((c_mean - 213.2503660477036).abs() < 1e-6);

            let c_std = c.std_all();
            println!("{c_std:?}");
            assert!((c_std - 148.88523481701804).abs() < 1e-6);

            let c_std_1 = c.std_axes((0, 1));
            println!("{c_std_1}");
            assert!((c_std_1[[0]] - 148.8763226818815).abs() < 1e-6);
            assert!((c_std_1[[255]] - 148.8941462322758).abs() < 1e-6);

            let c_std_2 = c.std_axes((1, 2));
            println!("{c_std_2}");
            assert!((c_std_2[[0]] - 4.763105902995575).abs() < 1e-6);
            assert!((c_std_2[[15]] - 9.093224903569157).abs() < 1e-6);
        }
        #[cfg(feature = "col_major")]
        {
            let a = linspace((0.0, 1.0, 1048576)).into_shape([256, 256, 16]);
            let b = linspace((1.0, 2.0, 1048576)).into_shape([256, 256, 16]);
            let mut c: Tensor<f64> = zeros([256, 256, 16]);
            for i in 0..16 {
                c.i_mut((.., .., i)).assign(&a.i((.., .., i)) % &b.i((.., .., i)));
            }

            let c_mean = c.mean_all();
            println!("{c_mean:?}");
            assert!((c_mean - 213.25036604770355).abs() < 1e-6);

            let c_std = c.std_all();
            println!("{c_std:?}");
            assert!((c_std - 148.7419537312827).abs() < 1e-6);

            let c_std_1 = c.std_axes((1, 2));
            println!("{c_std_1}");
            assert!((c_std_1[[0]] - 148.75113653867191).abs() < 1e-6);
            assert!((c_std_1[[255]] - 148.7689445622776).abs() < 1e-6);

            let c_std_2 = c.std_axes((0, 1));
            println!("{c_std_2}");
            assert!((c_std_2[[0]] - 0.145530296246335).abs() < 1e-6);
            assert!((c_std_2[[15]] - 4.474611918106057).abs() < 1e-6);
        }
    }

    #[test]
    fn test_unraveled_argmin() {
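        // The minimum value 1 first occurs at flat index 5, which unravels to
        // [1, 2] in the [4, 3] row-major layout.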
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        let m = a.unraveled_argmin_all();
        println!("{m:?}");
        assert_eq!(m, vec![1, 2]);

        let m = a.unraveled_argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1], vec![2]]);

        let m = a.unraveled_argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![vec![2], vec![2], vec![1]]);
    }

    #[test]
    fn test_argmin() {
        let v = vec![8, 4, 2, 9, 7, 1, 2, 1, 8, 6, 10, 5];
        let a = asarray((&v, [4, 3].c(), &DeviceCpuSerial::default()));
        println!("{a:}");
        let m = a.argmin_all();
        println!("{m:?}");
        assert_eq!(m, 5);

        let m = a.argmin_axes(-1);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1, 2]);

        let m = a.argmin_axes(0);
        println!("{m:?}");
        let m_vec = m.raw();
        assert_eq!(m_vec, &vec![2, 2, 1]);
    }

    #[test]
    fn test_all() {
        let a = asarray((vec![true, true, false, true, true, true], [2, 3].c()));
        let a_all = a.all_axes(-2);
        println!("{a_all:?}");
        assert_eq!(a_all.raw(), &[true, true, false]);
    }

    #[test]
    fn test_allclose_cpu_serial() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

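        // `a` is C-ordered and `b` is F-ordered; comparison follows the
        // logical layout, so the two agree elementwise to within about 1e-5.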
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }

    #[test]
    #[cfg(feature = "faer")]
    fn test_allclose_faer() {
        use rstsr_dtype_traits::IsCloseArgsBuilder;

        let mut device = DeviceFaer::default();
        device.set_default_order(RowMajor);
        let a = asarray((vec![1, 2, 3, 4], [2, 2].c(), &device));
        let b = asarray((vec![1.0f32, 3.0, 2.0, 4.00001], [2, 2].f(), &device));
        let result = allclose(&a, &b, None);
        println!("Allclose result: {result}");
        assert!(result);
        let args = IsCloseArgsBuilder::default().atol(1e-8).rtol(1e-8).build().unwrap();
        let result = allclose(&a, &b, args);
        println!("Allclose result with tight args: {result}");
        assert!(!result);
    }
}