Skip to main content

27_axis_reductions/
27_axis_reductions.rs

//! Axis-based reductions: `sum_axis`, `mean_axis`, `min_axis`, `max_axis`.
//!
//! Run: cargo run --example 27_axis_reductions
//!
//! Reducing along an axis removes that axis from the output shape: reducing a
//! `[3, 4]` tensor along axis 0 yields shape `[4]`, and along axis 1 yields `[3]`.
//! NaN propagates: if any element along a reduced axis is NaN, the
//! corresponding output element for that slice is NaN.
8
9use matten::Tensor;
10
11fn main() {
12    // Shape [3, 4]: 3 rows, 4 columns
13    let m = Tensor::new(
14        vec![
15            1.0, 2.0, 3.0, 4.0, // row 0
16            5.0, 6.0, 7.0, 8.0, // row 1
17            9.0, 10.0, 11.0, 12.0, // row 2
18        ],
19        &[3, 4],
20    );
21
22    // Reduce rows → column sums: shape [4]
23    let col_sums = m.sum_axis(0);
24    assert_eq!(col_sums.shape(), &[4]);
25    assert_eq!(col_sums.as_slice(), &[15.0, 18.0, 21.0, 24.0]);
26    println!("col sums  = {:?}", col_sums.as_slice());
27
28    // Reduce columns → row means: shape [3]
29    let row_means = m.mean_axis(1);
30    assert_eq!(row_means.shape(), &[3]);
31    assert_eq!(row_means.as_slice(), &[2.5, 6.5, 10.5]);
32    println!("row means = {:?}", row_means.as_slice());
33
34    // Column-wise min and max
35    let col_min = m.min_axis(0);
36    let col_max = m.max_axis(0);
37    assert_eq!(col_min.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
38    assert_eq!(col_max.as_slice(), &[9.0, 10.0, 11.0, 12.0]);
39    println!("col min   = {:?}", col_min.as_slice());
40    println!("col max   = {:?}", col_max.as_slice());
41
42    // NaN propagation: one NaN contaminates its column's min
43    let with_nan = Tensor::new(vec![1.0, f64::NAN, 3.0, 5.0, 6.0, 7.0], &[2, 3]);
44    let nan_col_min = with_nan.min_axis(0);
45    assert!(nan_col_min.as_slice()[1].is_nan()); // column 1 had a NaN
46    assert_eq!(nan_col_min.as_slice()[0], 1.0); // column 0 clean
47    println!("NaN propagation confirmed");
48
49    println!("done.");
50}