Skip to main content

28_column_statistics/
28_column_statistics.rs

//! Per-column statistics using axis reductions: mean, min, max, std dev.
//!
//! Run: cargo run --example 28_column_statistics
//!
//! A common PoC pattern: treating each column of a [rows, cols] tensor as
//! a feature, computing basic stats across all rows.
7
8use matten::Tensor;
9
10/// Manually compute column standard deviation from column means.
11/// std_dev_i = sqrt( mean( (col_i - mean_i)^2 ) )
12fn column_std(data: &Tensor, col_means: &Tensor) -> Tensor {
13    // Broadcast means over all rows, compute squared deviations
14    let deviations = data - col_means; // [rows, cols] - [cols] → [rows, cols]
15    let sq = &deviations * &deviations;
16    sq.mean_axis(0) // → [cols] mean of squared deviations
17        .as_slice()
18        .iter()
19        .map(|v| v.sqrt())
20        .collect::<Vec<f64>>()
21        .into()
22}
23
24fn main() {
25    // 4 rows, 3 columns (features)
26    let data = Tensor::new(
27        vec![
28            1.0, 10.0, 100.0, 2.0, 20.0, 200.0, 3.0, 30.0, 300.0, 4.0, 40.0, 400.0,
29        ],
30        &[4, 3],
31    );
32
33    let means = data.mean_axis(0);
34    let mins = data.min_axis(0);
35    let maxs = data.max_axis(0);
36    let stds = column_std(&data, &means);
37
38    println!("columns : 0      1      2");
39    println!("mean    : {:?}", means.as_slice());
40    println!("min     : {:?}", mins.as_slice());
41    println!("max     : {:?}", maxs.as_slice());
42    println!("std dev : {:?}", stds.as_slice());
43
44    assert_eq!(means.as_slice(), &[2.5, 25.0, 250.0]);
45    assert_eq!(mins.as_slice(), &[1.0, 10.0, 100.0]);
46    assert_eq!(maxs.as_slice(), &[4.0, 40.0, 400.0]);
47
48    // Std dev: sqrt of mean squared deviation from mean
49    let expected_std0 = {
50        // Column 0 values: 1,2,3,4. Mean=2.5. Deviations: -1.5,-0.5,0.5,1.5
51        let sq_devs = [
52            1.5f64.powi(2),
53            0.5f64.powi(2),
54            0.5f64.powi(2),
55            1.5f64.powi(2),
56        ];
57        (sq_devs.iter().sum::<f64>() / 4.0).sqrt()
58    };
59    assert!((stds.as_slice()[0] - expected_std0).abs() < 1e-10);
60
61    println!("done.");
62}