// 27_axis_reductions/27_axis_reductions.rs
use matten::Tensor;

11fn main() {
12 let m = Tensor::new(
14 vec![
15 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ],
19 &[3, 4],
20 );
21
22 let col_sums = m.sum_axis(0);
24 assert_eq!(col_sums.shape(), &[4]);
25 assert_eq!(col_sums.as_slice(), &[15.0, 18.0, 21.0, 24.0]);
26 println!("col sums = {:?}", col_sums.as_slice());
27
28 let row_means = m.mean_axis(1);
30 assert_eq!(row_means.shape(), &[3]);
31 assert_eq!(row_means.as_slice(), &[2.5, 6.5, 10.5]);
32 println!("row means = {:?}", row_means.as_slice());
33
34 let col_min = m.min_axis(0);
36 let col_max = m.max_axis(0);
37 assert_eq!(col_min.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
38 assert_eq!(col_max.as_slice(), &[9.0, 10.0, 11.0, 12.0]);
39 println!("col min = {:?}", col_min.as_slice());
40 println!("col max = {:?}", col_max.as_slice());
41
42 let with_nan = Tensor::new(vec![1.0, f64::NAN, 3.0, 5.0, 6.0, 7.0], &[2, 3]);
44 let nan_col_min = with_nan.min_axis(0);
45 assert!(nan_col_min.as_slice()[1].is_nan()); assert_eq!(nan_col_min.as_slice()[0], 1.0); println!("NaN propagation confirmed");
48
49 println!("done.");
50}