1use scirs2_core::ndarray::{Array1, Array2, Axis, NdFloat};
4use scirs2_core::numeric::FromPrimitive;
5
6pub fn mean_axis<'a, D>(array: &'a Array2<D>, axis: Axis) -> Array1<D>
8where
9 D: NdFloat + FromPrimitive + 'a,
10{
11 array.mean_axis(axis).unwrap()
12}
13
14pub fn var_axis<'a, D>(array: &'a Array2<D>, axis: Axis, ddof: usize) -> Array1<D>
16where
17 D: NdFloat + FromPrimitive + 'a,
18{
19 let mean = array.mean_axis(axis).unwrap();
20 let n = array.len_of(axis);
21
22 if axis == Axis(0) {
23 let mut var = Array1::zeros(array.ncols());
25 for j in 0..array.ncols() {
26 let col = array.column(j);
27 let m = mean[j];
28 let sum_sq: D = col.mapv(|x| (x - m).powi(2)).sum();
29 var[j] = sum_sq / D::from(n - ddof).unwrap();
30 }
31 var
32 } else {
33 let mut var = Array1::zeros(array.nrows());
35 for i in 0..array.nrows() {
36 let row = array.row(i);
37 let m = mean[i];
38 let sum_sq: D = row.mapv(|x| (x - m).powi(2)).sum();
39 var[i] = sum_sq / D::from(n - ddof).unwrap();
40 }
41 var
42 }
43}
44
45pub fn std_axis<'a, D>(array: &'a Array2<D>, axis: Axis, ddof: usize) -> Array1<D>
47where
48 D: NdFloat + FromPrimitive + 'a,
49{
50 var_axis(array, axis, ddof).mapv(|v| v.sqrt())
51}
52
53pub fn covariance<D>(x: &Array2<D>, ddof: usize) -> Array2<D>
55where
56 D: NdFloat + FromPrimitive,
57{
58 let n_samples = x.nrows();
59
60 let mean = x.mean_axis(Axis(0)).unwrap();
62 let centered = x - &mean;
63
64 let cov = centered.t().dot(¢ered) / D::from(n_samples - ddof).unwrap();
66 cov
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mean_axis() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let mean_rows = mean_axis(&x, Axis(0));
        assert_abs_diff_eq!(mean_rows[0], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean_rows[1], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean_rows[2], 6.0, epsilon = 1e-10);

        let mean_cols = mean_axis(&x, Axis(1));
        assert_abs_diff_eq!(mean_cols[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean_cols[1], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mean_cols[2], 8.0, epsilon = 1e-10);
    }

    #[test]
    fn test_var_axis() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Population variance of each column: deviations -3, 0, 3 -> 18 / 3.
        let var_rows = var_axis(&x, Axis(0), 0);
        assert_abs_diff_eq!(var_rows[0], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_rows[1], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_rows[2], 6.0, epsilon = 1e-10);

        // Sample variance of each row: deviations -1, 0, 1 -> 2 / (3 - 1).
        let var_cols = var_axis(&x, Axis(1), 1);
        assert_abs_diff_eq!(var_cols[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_cols[1], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_cols[2], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_std_axis() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Each column has squared deviations summing to 18; 18 / (3 - 1) = 9.
        let std_rows = std_axis(&x, Axis(0), 1);
        assert_eq!(std_rows.len(), 3);
        for &s in std_rows.iter() {
            assert_abs_diff_eq!(s, 3.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_covariance() {
        // The second column is exactly twice the first, so cov = [[1, 2], [2, 4]].
        let x = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]];

        let cov = covariance(&x, 1);
        assert_eq!(cov.dim(), (2, 2));
        assert_abs_diff_eq!(cov[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 1]], 4.0, epsilon = 1e-10);
    }
}