Skip to main content

nabled_ml/
stats.rs

1//! Statistical utilities over ndarray matrices.
2
3use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use ndarray::{Array1, Array2, ArrayView2, Axis};
7use num_complex::Complex64;
8
9/// Error type for matrix statistics.
10#[derive(Debug, Clone, Copy, PartialEq)]
11pub enum StatsError {
12    /// Matrix is empty.
13    EmptyMatrix,
14    /// Matrix needs at least two rows.
15    InsufficientSamples,
16    /// Numerical instability detected.
17    NumericalInstability,
18}
19
20impl fmt::Display for StatsError {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        match self {
23            StatsError::EmptyMatrix => write!(f, "Matrix cannot be empty"),
24            StatsError::InsufficientSamples => {
25                write!(f, "At least two observations are required")
26            }
27            StatsError::NumericalInstability => write!(f, "Numerical instability detected"),
28        }
29    }
30}
31
32impl std::error::Error for StatsError {}
33
34fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
35    T::from_usize(value).unwrap_or(T::max_value())
36}
37
38fn complex_is_finite(value: Complex64) -> bool { value.re.is_finite() && value.im.is_finite() }
39
40fn column_means_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
41    matrix.mean_axis(Axis(0)).unwrap_or_else(|| Array1::zeros(matrix.ncols()))
42}
43
44/// Compute column means.
45#[must_use]
46pub fn column_means<T: NabledReal>(matrix: &Array2<T>) -> Array1<T> {
47    column_means_impl(&matrix.view())
48}
49
50/// Compute column means from a matrix view.
51#[must_use]
52pub fn column_means_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
53    column_means_impl(matrix)
54}
55
56fn center_columns_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
57    let means = column_means_impl(matrix);
58    let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
59    for row in 0..matrix.nrows() {
60        for col in 0..matrix.ncols() {
61            centered[[row, col]] = matrix[[row, col]] - means[col];
62        }
63    }
64    centered
65}
66
67/// Center columns by subtracting their means.
68#[must_use]
69pub fn center_columns<T: NabledReal>(matrix: &Array2<T>) -> Array2<T> {
70    center_columns_impl(&matrix.view())
71}
72
73/// Center columns by subtracting their means from a matrix view.
74#[must_use]
75pub fn center_columns_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
76    center_columns_impl(matrix)
77}
78
79fn covariance_matrix_impl<T: NabledReal>(
80    matrix: &ArrayView2<'_, T>,
81) -> Result<Array2<T>, StatsError> {
82    if matrix.is_empty() {
83        return Err(StatsError::EmptyMatrix);
84    }
85    if matrix.nrows() < 2 {
86        return Err(StatsError::InsufficientSamples);
87    }
88
89    let centered = center_columns_impl(matrix);
90    let covariance: Array2<T> =
91        centered.t().dot(&centered) / usize_to_scalar::<T>(matrix.nrows() - 1);
92
93    if covariance.iter().any(|value| !value.is_finite()) {
94        return Err(StatsError::NumericalInstability);
95    }
96
97    Ok(covariance)
98}
99
100/// Compute sample covariance matrix.
101///
102/// # Errors
103/// Returns an error for empty input or fewer than two samples.
104pub fn covariance_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
105    covariance_matrix_impl(&matrix.view())
106}
107
108/// Compute sample covariance matrix from a matrix view.
109///
110/// # Errors
111/// Returns an error for empty input or fewer than two samples.
112pub fn covariance_matrix_view<T: NabledReal>(
113    matrix: &ArrayView2<'_, T>,
114) -> Result<Array2<T>, StatsError> {
115    covariance_matrix_impl(matrix)
116}
117
118fn correlation_matrix_impl<T: NabledReal>(
119    matrix: &ArrayView2<'_, T>,
120) -> Result<Array2<T>, StatsError> {
121    let covariance = covariance_matrix_impl(matrix)?;
122    let n = covariance.nrows();
123    let mut correlation = Array2::<T>::zeros((n, n));
124
125    for i in 0..n {
126        let sigma_i = covariance[[i, i]].sqrt();
127        for j in 0..n {
128            let sigma_j = covariance[[j, j]].sqrt();
129            let denom = (sigma_i * sigma_j).max(T::epsilon());
130            correlation[[i, j]] = covariance[[i, j]] / denom;
131        }
132    }
133
134    Ok(correlation)
135}
136
137/// Compute correlation matrix.
138///
139/// # Errors
140/// Returns an error if covariance computation fails.
141pub fn correlation_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
142    correlation_matrix_impl(&matrix.view())
143}
144
145/// Compute correlation matrix from a matrix view.
146///
147/// # Errors
148/// Returns an error if covariance computation fails.
149pub fn correlation_matrix_view<T: NabledReal>(
150    matrix: &ArrayView2<'_, T>,
151) -> Result<Array2<T>, StatsError> {
152    correlation_matrix_impl(matrix)
153}
154
155fn column_means_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
156    if matrix.nrows() == 0 {
157        return Array1::zeros(matrix.ncols());
158    }
159
160    let mut means = Array1::<Complex64>::zeros(matrix.ncols());
161    for col in 0..matrix.ncols() {
162        let mut sum = Complex64::new(0.0, 0.0);
163        for row in 0..matrix.nrows() {
164            sum += matrix[[row, col]];
165        }
166        means[col] = sum / usize_to_scalar::<f64>(matrix.nrows());
167    }
168    means
169}
170
171/// Compute complex column means.
172#[must_use]
173pub fn column_means_complex(matrix: &Array2<Complex64>) -> Array1<Complex64> {
174    column_means_complex_impl(&matrix.view())
175}
176
177/// Compute complex column means from a matrix view.
178#[must_use]
179pub fn column_means_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
180    column_means_complex_impl(matrix)
181}
182
183fn center_columns_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
184    let means = column_means_complex_impl(matrix);
185    let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
186    for row in 0..matrix.nrows() {
187        for col in 0..matrix.ncols() {
188            centered[[row, col]] = matrix[[row, col]] - means[col];
189        }
190    }
191    centered
192}
193
194/// Center complex columns by subtracting their means.
195#[must_use]
196pub fn center_columns_complex(matrix: &Array2<Complex64>) -> Array2<Complex64> {
197    center_columns_complex_impl(&matrix.view())
198}
199
200/// Center complex columns by subtracting their means from a matrix view.
201#[must_use]
202pub fn center_columns_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
203    center_columns_complex_impl(matrix)
204}
205
206fn covariance_matrix_complex_impl(
207    matrix: &ArrayView2<'_, Complex64>,
208) -> Result<Array2<Complex64>, StatsError> {
209    if matrix.is_empty() {
210        return Err(StatsError::EmptyMatrix);
211    }
212    if matrix.nrows() < 2 {
213        return Err(StatsError::InsufficientSamples);
214    }
215
216    let centered = center_columns_complex_impl(matrix);
217    let conjugate_transpose = centered.t().mapv(|value| value.conj());
218    let covariance: Array2<Complex64> =
219        conjugate_transpose.dot(&centered) / usize_to_scalar::<f64>(matrix.nrows() - 1);
220
221    if covariance.iter().any(|value| !complex_is_finite(*value)) {
222        return Err(StatsError::NumericalInstability);
223    }
224
225    Ok(covariance)
226}
227
228/// Compute sample covariance matrix for complex observations.
229///
230/// # Errors
231/// Returns an error for empty input or fewer than two samples.
232pub fn covariance_matrix_complex(
233    matrix: &Array2<Complex64>,
234) -> Result<Array2<Complex64>, StatsError> {
235    covariance_matrix_complex_impl(&matrix.view())
236}
237
238/// Compute sample covariance matrix for complex observations from a matrix view.
239///
240/// # Errors
241/// Returns an error for empty input or fewer than two samples.
242pub fn covariance_matrix_complex_view(
243    matrix: &ArrayView2<'_, Complex64>,
244) -> Result<Array2<Complex64>, StatsError> {
245    covariance_matrix_complex_impl(matrix)
246}
247
248fn correlation_matrix_complex_impl(
249    matrix: &ArrayView2<'_, Complex64>,
250) -> Result<Array2<Complex64>, StatsError> {
251    let covariance = covariance_matrix_complex_impl(matrix)?;
252    let n = covariance.nrows();
253    let mut correlation = Array2::<Complex64>::zeros((n, n));
254
255    for i in 0..n {
256        let sigma_i = covariance[[i, i]].re.max(0.0).sqrt();
257        for j in 0..n {
258            let sigma_j = covariance[[j, j]].re.max(0.0).sqrt();
259            let denom = (sigma_i * sigma_j).max(f64::EPSILON);
260            correlation[[i, j]] = covariance[[i, j]] / denom;
261        }
262    }
263
264    if correlation.iter().any(|value| !complex_is_finite(*value)) {
265        return Err(StatsError::NumericalInstability);
266    }
267
268    Ok(correlation)
269}
270
271/// Compute correlation matrix for complex observations.
272///
273/// # Errors
274/// Returns an error if covariance computation fails.
275pub fn correlation_matrix_complex(
276    matrix: &Array2<Complex64>,
277) -> Result<Array2<Complex64>, StatsError> {
278    correlation_matrix_complex_impl(&matrix.view())
279}
280
281/// Compute correlation matrix for complex observations from a matrix view.
282///
283/// # Errors
284/// Returns an error if covariance computation fails.
285pub fn correlation_matrix_complex_view(
286    matrix: &ArrayView2<'_, Complex64>,
287) -> Result<Array2<Complex64>, StatsError> {
288    correlation_matrix_complex_impl(matrix)
289}
290
291#[cfg(test)]
292mod tests {
293    use ndarray::Array2;
294    use num_complex::Complex64;
295
296    use super::*;
297
298    #[test]
299    fn covariance_and_correlation_are_well_formed() {
300        let matrix =
301            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
302                .unwrap();
303        let covariance = covariance_matrix(&matrix).unwrap();
304        let correlation = correlation_matrix(&matrix).unwrap();
305        assert_eq!(covariance.dim(), (2, 2));
306        assert_eq!(correlation.dim(), (2, 2));
307    }
308
309    #[test]
310    fn stats_rejects_empty_and_insufficient_inputs() {
311        let empty = Array2::<f64>::zeros((0, 0));
312        assert!(matches!(covariance_matrix(&empty), Err(StatsError::EmptyMatrix)));
313
314        let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
315        assert!(matches!(covariance_matrix(&one_row), Err(StatsError::InsufficientSamples)));
316    }
317
318    #[test]
319    fn center_columns_zeroes_means() {
320        let matrix =
321            Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
322        let centered = center_columns(&matrix);
323        let means = column_means(&centered);
324        assert!(means.iter().all(|value| num_traits::Float::abs(*value) < 1e-12));
325    }
326
327    #[test]
328    fn column_means_handles_empty_input() {
329        let matrix = Array2::<f64>::zeros((0, 3));
330        let means = column_means(&matrix);
331        assert_eq!(means.len(), 3);
332        assert!(means.iter().all(|value| *value == 0.0));
333    }
334
335    #[test]
336    fn covariance_reports_numerical_instability() {
337        let matrix = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
338        let result = covariance_matrix(&matrix);
339        assert!(matches!(result, Err(StatsError::NumericalInstability)));
340    }
341
342    #[test]
343    fn correlation_handles_zero_variance_column() {
344        let matrix =
345            Array2::from_shape_vec((3, 2), vec![1.0_f64, 10.0, 1.0, 20.0, 1.0, 30.0]).unwrap();
346        let correlation = correlation_matrix(&matrix).unwrap();
347        assert!(correlation[[0, 0]].is_finite());
348        assert!(correlation[[0, 1]].is_finite());
349        assert!(correlation[[1, 0]].is_finite());
350        assert!(correlation[[1, 1]].is_finite());
351    }
352
353    #[test]
354    fn view_variants_match_owned() {
355        let matrix =
356            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
357                .unwrap();
358        let means_owned = column_means(&matrix);
359        let means_view = column_means_view(&matrix.view());
360        let centered_owned = center_columns(&matrix);
361        let centered_view = center_columns_view(&matrix.view());
362        let covariance_owned = covariance_matrix(&matrix).unwrap();
363        let covariance_view = covariance_matrix_view(&matrix.view()).unwrap();
364        let correlation_owned = correlation_matrix(&matrix).unwrap();
365        let correlation_view = correlation_matrix_view(&matrix.view()).unwrap();
366
367        for i in 0..means_owned.len() {
368            assert!((means_owned[i] - means_view[i]).abs() < 1e-12);
369        }
370        for i in 0..matrix.nrows() {
371            for j in 0..matrix.ncols() {
372                assert!((centered_owned[[i, j]] - centered_view[[i, j]]).abs() < 1e-12);
373            }
374        }
375        for i in 0..2 {
376            for j in 0..2 {
377                assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).abs() < 1e-12);
378                assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).abs() < 1e-12);
379            }
380        }
381    }
382
383    #[test]
384    fn complex_covariance_and_correlation_are_well_formed() {
385        let matrix = Array2::from_shape_vec((4, 2), vec![
386            Complex64::new(1.0, 0.0),
387            Complex64::new(3.0, 1.0),
388            Complex64::new(2.0, -1.0),
389            Complex64::new(2.0, 0.5),
390            Complex64::new(3.0, 0.2),
391            Complex64::new(1.0, -0.3),
392            Complex64::new(4.0, 0.7),
393            Complex64::new(0.0, 0.0),
394        ])
395        .unwrap();
396
397        let covariance = covariance_matrix_complex(&matrix).unwrap();
398        let correlation = correlation_matrix_complex(&matrix).unwrap();
399        assert_eq!(covariance.dim(), (2, 2));
400        assert_eq!(correlation.dim(), (2, 2));
401    }
402
403    #[test]
404    fn complex_view_variants_match_owned() {
405        let matrix = Array2::from_shape_vec((3, 2), vec![
406            Complex64::new(1.0, 1.0),
407            Complex64::new(2.0, -1.0),
408            Complex64::new(2.0, 2.0),
409            Complex64::new(3.0, 0.0),
410            Complex64::new(3.0, -2.0),
411            Complex64::new(4.0, 1.0),
412        ])
413        .unwrap();
414
415        let means_owned = column_means_complex(&matrix);
416        let means_view = column_means_complex_view(&matrix.view());
417        let centered_owned = center_columns_complex(&matrix);
418        let centered_view = center_columns_complex_view(&matrix.view());
419        let covariance_owned = covariance_matrix_complex(&matrix).unwrap();
420        let covariance_view = covariance_matrix_complex_view(&matrix.view()).unwrap();
421        let correlation_owned = correlation_matrix_complex(&matrix).unwrap();
422        let correlation_view = correlation_matrix_complex_view(&matrix.view()).unwrap();
423
424        for i in 0..means_owned.len() {
425            assert!((means_owned[i] - means_view[i]).norm() < 1e-12);
426        }
427        for i in 0..matrix.nrows() {
428            for j in 0..matrix.ncols() {
429                assert!((centered_owned[[i, j]] - centered_view[[i, j]]).norm() < 1e-12);
430            }
431        }
432        for i in 0..2 {
433            for j in 0..2 {
434                assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).norm() < 1e-12);
435                assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).norm() < 1e-12);
436            }
437        }
438    }
439}