Skip to main content

nabled_ml/
stats.rs

1//! Statistical utilities over ndarray matrices.
2
3use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, DataMut, Ix1, Ix2};
7use num_complex::Complex64;
8
/// Error type for matrix statistics.
///
/// Every fallible function in this module reports failures through this type. It implements
/// [`fmt::Display`] and [`std::error::Error`], so it can be boxed or propagated with `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StatsError {
    /// Matrix is empty (it has zero rows or zero columns).
    EmptyMatrix,
    /// Matrix needs at least two rows (observations).
    InsufficientSamples,
    /// Input or output shapes are incompatible, or a parameter is invalid; the payload
    /// describes the problem.
    InvalidInput(String),
    /// A non-finite value (NaN or infinity) was produced during the computation.
    NumericalInstability,
}

impl fmt::Display for StatsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StatsError::EmptyMatrix => write!(f, "Matrix cannot be empty"),
            StatsError::InsufficientSamples => {
                write!(f, "At least two observations are required")
            }
            StatsError::InvalidInput(message) => write!(f, "Invalid input: {message}"),
            StatsError::NumericalInstability => write!(f, "Numerical instability detected"),
        }
    }
}

impl std::error::Error for StatsError {}
36
37fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
38    T::from_usize(value).unwrap_or(T::max_value())
39}
40
41fn complex_is_finite(value: Complex64) -> bool { value.re.is_finite() && value.im.is_finite() }
42
43fn validate_vector_output_len<T, S>(
44    output: &ArrayBase<S, Ix1>,
45    expected_len: usize,
46    name: &str,
47) -> Result<(), StatsError>
48where
49    S: DataMut<Elem = T>,
50{
51    if output.len() != expected_len {
52        return Err(StatsError::InvalidInput(format!(
53            "{name} output length must match expected length {expected_len}",
54        )));
55    }
56    Ok(())
57}
58
59fn validate_matrix_output_shape<T, S>(
60    output: &ArrayBase<S, Ix2>,
61    expected_rows: usize,
62    expected_cols: usize,
63    name: &str,
64) -> Result<(), StatsError>
65where
66    S: DataMut<Elem = T>,
67{
68    if output.nrows() != expected_rows || output.ncols() != expected_cols {
69        return Err(StatsError::InvalidInput(format!(
70            "{name} output shape must be ({expected_rows}, {expected_cols})",
71        )));
72    }
73    Ok(())
74}
75
76fn column_means_into_impl<T, S>(
77    matrix: &ArrayView2<'_, T>,
78    output: &mut ArrayBase<S, Ix1>,
79) -> Result<(), StatsError>
80where
81    T: NabledReal,
82    S: DataMut<Elem = T>,
83{
84    validate_vector_output_len(output, matrix.ncols(), "column_means")?;
85
86    if matrix.nrows() == 0 {
87        output.fill(T::zero());
88        return Ok(());
89    }
90
91    let denom = usize_to_scalar::<T>(matrix.nrows());
92    for col in 0..matrix.ncols() {
93        let mut sum = T::zero();
94        for row in 0..matrix.nrows() {
95            sum += matrix[[row, col]];
96        }
97        output[col] = sum / denom;
98    }
99
100    Ok(())
101}
102
103fn column_means_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
104    matrix.mean_axis(Axis(0)).unwrap_or_else(|| Array1::zeros(matrix.ncols()))
105}
106
107/// Compute column means.
108#[must_use]
109pub fn column_means<T: NabledReal>(matrix: &Array2<T>) -> Array1<T> {
110    column_means_impl(&matrix.view())
111}
112
113/// Compute column means from a matrix view.
114#[must_use]
115pub fn column_means_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
116    column_means_impl(matrix)
117}
118
119/// Compute column means into caller-provided output.
120///
121/// # Errors
122/// Returns an error if `output` does not match the column count.
123pub fn column_means_into<T, S>(
124    matrix: &Array2<T>,
125    output: &mut ArrayBase<S, Ix1>,
126) -> Result<(), StatsError>
127where
128    T: NabledReal,
129    S: DataMut<Elem = T>,
130{
131    column_means_into_impl(&matrix.view(), output)
132}
133
134/// Compute column means from a matrix view into caller-provided output.
135///
136/// # Errors
137/// Returns an error if `output` does not match the column count.
138pub fn column_means_view_into<T, S>(
139    matrix: &ArrayView2<'_, T>,
140    output: &mut ArrayBase<S, Ix1>,
141) -> Result<(), StatsError>
142where
143    T: NabledReal,
144    S: DataMut<Elem = T>,
145{
146    column_means_into_impl(matrix, output)
147}
148
149fn center_columns_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
150    let means = column_means_impl(matrix);
151    let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
152    for row in 0..matrix.nrows() {
153        for col in 0..matrix.ncols() {
154            centered[[row, col]] = matrix[[row, col]] - means[col];
155        }
156    }
157    centered
158}
159
160/// Center columns by subtracting their means.
161#[must_use]
162pub fn center_columns<T: NabledReal>(matrix: &Array2<T>) -> Array2<T> {
163    center_columns_impl(&matrix.view())
164}
165
166/// Center columns by subtracting their means from a matrix view.
167#[must_use]
168pub fn center_columns_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
169    center_columns_impl(matrix)
170}
171
172fn center_columns_into_impl<T, S>(
173    matrix: &ArrayView2<'_, T>,
174    output: &mut ArrayBase<S, Ix2>,
175) -> Result<(), StatsError>
176where
177    T: NabledReal,
178    S: DataMut<Elem = T>,
179{
180    validate_matrix_output_shape(output, matrix.nrows(), matrix.ncols(), "center_columns")?;
181
182    let mut means = Array1::<T>::zeros(matrix.ncols());
183    column_means_into_impl(matrix, &mut means)?;
184
185    for row in 0..matrix.nrows() {
186        for col in 0..matrix.ncols() {
187            output[[row, col]] = matrix[[row, col]] - means[col];
188        }
189    }
190
191    Ok(())
192}
193
194/// Center columns by subtracting their means into caller-provided output.
195///
196/// # Errors
197/// Returns an error if `output` does not match the input shape.
198pub fn center_columns_into<T, S>(
199    matrix: &Array2<T>,
200    output: &mut ArrayBase<S, Ix2>,
201) -> Result<(), StatsError>
202where
203    T: NabledReal,
204    S: DataMut<Elem = T>,
205{
206    center_columns_into_impl(&matrix.view(), output)
207}
208
209/// Center columns by subtracting their means from a matrix view into caller-provided output.
210///
211/// # Errors
212/// Returns an error if `output` does not match the input shape.
213pub fn center_columns_view_into<T, S>(
214    matrix: &ArrayView2<'_, T>,
215    output: &mut ArrayBase<S, Ix2>,
216) -> Result<(), StatsError>
217where
218    T: NabledReal,
219    S: DataMut<Elem = T>,
220{
221    center_columns_into_impl(matrix, output)
222}
223
224fn covariance_matrix_impl<T: NabledReal>(
225    matrix: &ArrayView2<'_, T>,
226) -> Result<Array2<T>, StatsError> {
227    if matrix.is_empty() {
228        return Err(StatsError::EmptyMatrix);
229    }
230    if matrix.nrows() < 2 {
231        return Err(StatsError::InsufficientSamples);
232    }
233
234    let centered = center_columns_impl(matrix);
235    let covariance: Array2<T> =
236        centered.t().dot(&centered) / usize_to_scalar::<T>(matrix.nrows() - 1);
237
238    if covariance.iter().any(|value| !value.is_finite()) {
239        return Err(StatsError::NumericalInstability);
240    }
241
242    Ok(covariance)
243}
244
245/// Compute sample covariance matrix.
246///
247/// # Errors
248/// Returns an error for empty input or fewer than two samples.
249pub fn covariance_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
250    covariance_matrix_impl(&matrix.view())
251}
252
253/// Compute sample covariance matrix from a matrix view.
254///
255/// # Errors
256/// Returns an error for empty input or fewer than two samples.
257pub fn covariance_matrix_view<T: NabledReal>(
258    matrix: &ArrayView2<'_, T>,
259) -> Result<Array2<T>, StatsError> {
260    covariance_matrix_impl(matrix)
261}
262
263fn covariance_matrix_into_impl<T, S>(
264    matrix: &ArrayView2<'_, T>,
265    output: &mut ArrayBase<S, Ix2>,
266) -> Result<(), StatsError>
267where
268    T: NabledReal,
269    S: DataMut<Elem = T>,
270{
271    if matrix.is_empty() {
272        return Err(StatsError::EmptyMatrix);
273    }
274    if matrix.nrows() < 2 {
275        return Err(StatsError::InsufficientSamples);
276    }
277    validate_matrix_output_shape(output, matrix.ncols(), matrix.ncols(), "covariance_matrix")?;
278
279    let mut means = Array1::<T>::zeros(matrix.ncols());
280    column_means_into_impl(matrix, &mut means)?;
281    let denom = usize_to_scalar::<T>(matrix.nrows() - 1);
282
283    for i in 0..matrix.ncols() {
284        for j in i..matrix.ncols() {
285            let mut sum = T::zero();
286            for row in 0..matrix.nrows() {
287                let left = matrix[[row, i]] - means[i];
288                let right = matrix[[row, j]] - means[j];
289                sum += left * right;
290            }
291            let value = sum / denom;
292            output[[i, j]] = value;
293            if i != j {
294                output[[j, i]] = value;
295            }
296        }
297    }
298
299    if output.iter().any(|value| !value.is_finite()) {
300        return Err(StatsError::NumericalInstability);
301    }
302
303    Ok(())
304}
305
306/// Compute sample covariance matrix into caller-provided output.
307///
308/// # Errors
309/// Returns an error for empty input, fewer than two samples, or incompatible output shape.
310pub fn covariance_matrix_into<T, S>(
311    matrix: &Array2<T>,
312    output: &mut ArrayBase<S, Ix2>,
313) -> Result<(), StatsError>
314where
315    T: NabledReal,
316    S: DataMut<Elem = T>,
317{
318    covariance_matrix_into_impl(&matrix.view(), output)
319}
320
321/// Compute sample covariance matrix from a matrix view into caller-provided output.
322///
323/// # Errors
324/// Returns an error for empty input, fewer than two samples, or incompatible output shape.
325pub fn covariance_matrix_view_into<T, S>(
326    matrix: &ArrayView2<'_, T>,
327    output: &mut ArrayBase<S, Ix2>,
328) -> Result<(), StatsError>
329where
330    T: NabledReal,
331    S: DataMut<Elem = T>,
332{
333    covariance_matrix_into_impl(matrix, output)
334}
335
336fn correlation_matrix_impl<T: NabledReal>(
337    matrix: &ArrayView2<'_, T>,
338) -> Result<Array2<T>, StatsError> {
339    let covariance = covariance_matrix_impl(matrix)?;
340    let n = covariance.nrows();
341    let mut correlation = Array2::<T>::zeros((n, n));
342
343    for i in 0..n {
344        let sigma_i = covariance[[i, i]].sqrt();
345        for j in 0..n {
346            let sigma_j = covariance[[j, j]].sqrt();
347            let denom = (sigma_i * sigma_j).max(T::epsilon());
348            correlation[[i, j]] = covariance[[i, j]] / denom;
349        }
350    }
351
352    Ok(correlation)
353}
354
355/// Compute correlation matrix.
356///
357/// # Errors
358/// Returns an error if covariance computation fails.
359pub fn correlation_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
360    correlation_matrix_impl(&matrix.view())
361}
362
363/// Compute correlation matrix from a matrix view.
364///
365/// # Errors
366/// Returns an error if covariance computation fails.
367pub fn correlation_matrix_view<T: NabledReal>(
368    matrix: &ArrayView2<'_, T>,
369) -> Result<Array2<T>, StatsError> {
370    correlation_matrix_impl(matrix)
371}
372
373fn correlation_matrix_into_impl<T, S>(
374    matrix: &ArrayView2<'_, T>,
375    output: &mut ArrayBase<S, Ix2>,
376) -> Result<(), StatsError>
377where
378    T: NabledReal,
379    S: DataMut<Elem = T>,
380{
381    covariance_matrix_into_impl(matrix, output)?;
382    let mut sigmas = Array1::<T>::zeros(output.nrows());
383    for i in 0..output.nrows() {
384        sigmas[i] = output[[i, i]].sqrt();
385    }
386
387    for i in 0..output.nrows() {
388        let sigma_i = sigmas[i];
389        for j in 0..output.ncols() {
390            let sigma_j = sigmas[j];
391            let denom = (sigma_i * sigma_j).max(T::epsilon());
392            output[[i, j]] /= denom;
393        }
394    }
395
396    if output.iter().any(|value| !value.is_finite()) {
397        return Err(StatsError::NumericalInstability);
398    }
399
400    Ok(())
401}
402
403/// Compute correlation matrix into caller-provided output.
404///
405/// # Errors
406/// Returns an error if covariance computation fails or `output` shape is incompatible.
407pub fn correlation_matrix_into<T, S>(
408    matrix: &Array2<T>,
409    output: &mut ArrayBase<S, Ix2>,
410) -> Result<(), StatsError>
411where
412    T: NabledReal,
413    S: DataMut<Elem = T>,
414{
415    correlation_matrix_into_impl(&matrix.view(), output)
416}
417
418/// Compute correlation matrix from a matrix view into caller-provided output.
419///
420/// # Errors
421/// Returns an error if covariance computation fails or `output` shape is incompatible.
422pub fn correlation_matrix_view_into<T, S>(
423    matrix: &ArrayView2<'_, T>,
424    output: &mut ArrayBase<S, Ix2>,
425) -> Result<(), StatsError>
426where
427    T: NabledReal,
428    S: DataMut<Elem = T>,
429{
430    correlation_matrix_into_impl(matrix, output)
431}
432
433fn column_means_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
434    if matrix.nrows() == 0 {
435        return Array1::zeros(matrix.ncols());
436    }
437
438    let mut means = Array1::<Complex64>::zeros(matrix.ncols());
439    for col in 0..matrix.ncols() {
440        let mut sum = Complex64::new(0.0, 0.0);
441        for row in 0..matrix.nrows() {
442            sum += matrix[[row, col]];
443        }
444        means[col] = sum / usize_to_scalar::<f64>(matrix.nrows());
445    }
446    means
447}
448
449/// Compute complex column means.
450#[must_use]
451pub fn column_means_complex(matrix: &Array2<Complex64>) -> Array1<Complex64> {
452    column_means_complex_impl(&matrix.view())
453}
454
455/// Compute complex column means from a matrix view.
456#[must_use]
457pub fn column_means_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
458    column_means_complex_impl(matrix)
459}
460
461fn column_means_complex_into_impl<S>(
462    matrix: &ArrayView2<'_, Complex64>,
463    output: &mut ArrayBase<S, Ix1>,
464) -> Result<(), StatsError>
465where
466    S: DataMut<Elem = Complex64>,
467{
468    validate_vector_output_len(output, matrix.ncols(), "column_means_complex")?;
469
470    if matrix.nrows() == 0 {
471        output.fill(Complex64::new(0.0, 0.0));
472        return Ok(());
473    }
474
475    let denom = usize_to_scalar::<f64>(matrix.nrows());
476    for col in 0..matrix.ncols() {
477        let mut sum = Complex64::new(0.0, 0.0);
478        for row in 0..matrix.nrows() {
479            sum += matrix[[row, col]];
480        }
481        output[col] = sum / denom;
482    }
483
484    Ok(())
485}
486
487/// Compute complex column means into caller-provided output.
488///
489/// # Errors
490/// Returns an error if `output` does not match the column count.
491pub fn column_means_complex_into<S>(
492    matrix: &Array2<Complex64>,
493    output: &mut ArrayBase<S, Ix1>,
494) -> Result<(), StatsError>
495where
496    S: DataMut<Elem = Complex64>,
497{
498    column_means_complex_into_impl(&matrix.view(), output)
499}
500
501/// Compute complex column means from a matrix view into caller-provided output.
502///
503/// # Errors
504/// Returns an error if `output` does not match the column count.
505pub fn column_means_complex_view_into<S>(
506    matrix: &ArrayView2<'_, Complex64>,
507    output: &mut ArrayBase<S, Ix1>,
508) -> Result<(), StatsError>
509where
510    S: DataMut<Elem = Complex64>,
511{
512    column_means_complex_into_impl(matrix, output)
513}
514
515fn center_columns_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
516    let means = column_means_complex_impl(matrix);
517    let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
518    for row in 0..matrix.nrows() {
519        for col in 0..matrix.ncols() {
520            centered[[row, col]] = matrix[[row, col]] - means[col];
521        }
522    }
523    centered
524}
525
526/// Center complex columns by subtracting their means.
527#[must_use]
528pub fn center_columns_complex(matrix: &Array2<Complex64>) -> Array2<Complex64> {
529    center_columns_complex_impl(&matrix.view())
530}
531
532/// Center complex columns by subtracting their means from a matrix view.
533#[must_use]
534pub fn center_columns_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
535    center_columns_complex_impl(matrix)
536}
537
538fn center_columns_complex_into_impl<S>(
539    matrix: &ArrayView2<'_, Complex64>,
540    output: &mut ArrayBase<S, Ix2>,
541) -> Result<(), StatsError>
542where
543    S: DataMut<Elem = Complex64>,
544{
545    validate_matrix_output_shape(output, matrix.nrows(), matrix.ncols(), "center_columns_complex")?;
546
547    let mut means = Array1::<Complex64>::zeros(matrix.ncols());
548    column_means_complex_into_impl(matrix, &mut means)?;
549
550    for row in 0..matrix.nrows() {
551        for col in 0..matrix.ncols() {
552            output[[row, col]] = matrix[[row, col]] - means[col];
553        }
554    }
555
556    Ok(())
557}
558
559/// Center complex columns by subtracting their means into caller-provided output.
560///
561/// # Errors
562/// Returns an error if `output` does not match the input shape.
563pub fn center_columns_complex_into<S>(
564    matrix: &Array2<Complex64>,
565    output: &mut ArrayBase<S, Ix2>,
566) -> Result<(), StatsError>
567where
568    S: DataMut<Elem = Complex64>,
569{
570    center_columns_complex_into_impl(&matrix.view(), output)
571}
572
573/// Center complex columns by subtracting their means from a matrix view into caller-provided
574/// output.
575///
576/// # Errors
577/// Returns an error if `output` does not match the input shape.
578pub fn center_columns_complex_view_into<S>(
579    matrix: &ArrayView2<'_, Complex64>,
580    output: &mut ArrayBase<S, Ix2>,
581) -> Result<(), StatsError>
582where
583    S: DataMut<Elem = Complex64>,
584{
585    center_columns_complex_into_impl(matrix, output)
586}
587
588fn covariance_matrix_complex_impl(
589    matrix: &ArrayView2<'_, Complex64>,
590) -> Result<Array2<Complex64>, StatsError> {
591    if matrix.is_empty() {
592        return Err(StatsError::EmptyMatrix);
593    }
594    if matrix.nrows() < 2 {
595        return Err(StatsError::InsufficientSamples);
596    }
597
598    let centered = center_columns_complex_impl(matrix);
599    let conjugate_transpose = centered.t().mapv(|value| value.conj());
600    let covariance: Array2<Complex64> =
601        conjugate_transpose.dot(&centered) / usize_to_scalar::<f64>(matrix.nrows() - 1);
602
603    if covariance.iter().any(|value| !complex_is_finite(*value)) {
604        return Err(StatsError::NumericalInstability);
605    }
606
607    Ok(covariance)
608}
609
610/// Compute sample covariance matrix for complex observations.
611///
612/// # Errors
613/// Returns an error for empty input or fewer than two samples.
614pub fn covariance_matrix_complex(
615    matrix: &Array2<Complex64>,
616) -> Result<Array2<Complex64>, StatsError> {
617    covariance_matrix_complex_impl(&matrix.view())
618}
619
620/// Compute sample covariance matrix for complex observations from a matrix view.
621///
622/// # Errors
623/// Returns an error for empty input or fewer than two samples.
624pub fn covariance_matrix_complex_view(
625    matrix: &ArrayView2<'_, Complex64>,
626) -> Result<Array2<Complex64>, StatsError> {
627    covariance_matrix_complex_impl(matrix)
628}
629
630fn covariance_matrix_complex_into_impl<S>(
631    matrix: &ArrayView2<'_, Complex64>,
632    output: &mut ArrayBase<S, Ix2>,
633) -> Result<(), StatsError>
634where
635    S: DataMut<Elem = Complex64>,
636{
637    if matrix.is_empty() {
638        return Err(StatsError::EmptyMatrix);
639    }
640    if matrix.nrows() < 2 {
641        return Err(StatsError::InsufficientSamples);
642    }
643    validate_matrix_output_shape(
644        output,
645        matrix.ncols(),
646        matrix.ncols(),
647        "covariance_matrix_complex",
648    )?;
649
650    let mut means = Array1::<Complex64>::zeros(matrix.ncols());
651    column_means_complex_into_impl(matrix, &mut means)?;
652    let denom = usize_to_scalar::<f64>(matrix.nrows() - 1);
653
654    for i in 0..matrix.ncols() {
655        for j in i..matrix.ncols() {
656            let mut sum = Complex64::new(0.0, 0.0);
657            for row in 0..matrix.nrows() {
658                let left = matrix[[row, i]] - means[i];
659                let right = matrix[[row, j]] - means[j];
660                sum += left.conj() * right;
661            }
662            let value = sum / denom;
663            output[[i, j]] = value;
664            if i != j {
665                output[[j, i]] = value.conj();
666            }
667        }
668    }
669
670    if output.iter().any(|value| !complex_is_finite(*value)) {
671        return Err(StatsError::NumericalInstability);
672    }
673
674    Ok(())
675}
676
677/// Compute sample covariance matrix for complex observations into caller-provided output.
678///
679/// # Errors
680/// Returns an error for empty input, fewer than two samples, or incompatible output shape.
681pub fn covariance_matrix_complex_into<S>(
682    matrix: &Array2<Complex64>,
683    output: &mut ArrayBase<S, Ix2>,
684) -> Result<(), StatsError>
685where
686    S: DataMut<Elem = Complex64>,
687{
688    covariance_matrix_complex_into_impl(&matrix.view(), output)
689}
690
691/// Compute sample covariance matrix for complex observations from a matrix view into
692/// caller-provided output.
693///
694/// # Errors
695/// Returns an error for empty input, fewer than two samples, or incompatible output shape.
696pub fn covariance_matrix_complex_view_into<S>(
697    matrix: &ArrayView2<'_, Complex64>,
698    output: &mut ArrayBase<S, Ix2>,
699) -> Result<(), StatsError>
700where
701    S: DataMut<Elem = Complex64>,
702{
703    covariance_matrix_complex_into_impl(matrix, output)
704}
705
706fn correlation_matrix_complex_impl(
707    matrix: &ArrayView2<'_, Complex64>,
708) -> Result<Array2<Complex64>, StatsError> {
709    let covariance = covariance_matrix_complex_impl(matrix)?;
710    let n = covariance.nrows();
711    let mut correlation = Array2::<Complex64>::zeros((n, n));
712
713    for i in 0..n {
714        let sigma_i = covariance[[i, i]].re.max(0.0).sqrt();
715        for j in 0..n {
716            let sigma_j = covariance[[j, j]].re.max(0.0).sqrt();
717            let denom = (sigma_i * sigma_j).max(f64::EPSILON);
718            correlation[[i, j]] = covariance[[i, j]] / denom;
719        }
720    }
721
722    if correlation.iter().any(|value| !complex_is_finite(*value)) {
723        return Err(StatsError::NumericalInstability);
724    }
725
726    Ok(correlation)
727}
728
729/// Compute correlation matrix for complex observations.
730///
731/// # Errors
732/// Returns an error if covariance computation fails.
733pub fn correlation_matrix_complex(
734    matrix: &Array2<Complex64>,
735) -> Result<Array2<Complex64>, StatsError> {
736    correlation_matrix_complex_impl(&matrix.view())
737}
738
739/// Compute correlation matrix for complex observations from a matrix view.
740///
741/// # Errors
742/// Returns an error if covariance computation fails.
743pub fn correlation_matrix_complex_view(
744    matrix: &ArrayView2<'_, Complex64>,
745) -> Result<Array2<Complex64>, StatsError> {
746    correlation_matrix_complex_impl(matrix)
747}
748
749fn correlation_matrix_complex_into_impl<S>(
750    matrix: &ArrayView2<'_, Complex64>,
751    output: &mut ArrayBase<S, Ix2>,
752) -> Result<(), StatsError>
753where
754    S: DataMut<Elem = Complex64>,
755{
756    covariance_matrix_complex_into_impl(matrix, output)?;
757    let mut sigmas = Array1::<f64>::zeros(output.nrows());
758    for i in 0..output.nrows() {
759        sigmas[i] = output[[i, i]].re.max(0.0).sqrt();
760    }
761
762    for i in 0..output.nrows() {
763        let sigma_i = sigmas[i];
764        for j in 0..output.ncols() {
765            let sigma_j = sigmas[j];
766            let denom = (sigma_i * sigma_j).max(f64::EPSILON);
767            output[[i, j]] /= denom;
768        }
769    }
770
771    if output.iter().any(|value| !complex_is_finite(*value)) {
772        return Err(StatsError::NumericalInstability);
773    }
774
775    Ok(())
776}
777
778/// Compute correlation matrix for complex observations into caller-provided output.
779///
780/// # Errors
781/// Returns an error if covariance computation fails or `output` shape is incompatible.
782pub fn correlation_matrix_complex_into<S>(
783    matrix: &Array2<Complex64>,
784    output: &mut ArrayBase<S, Ix2>,
785) -> Result<(), StatsError>
786where
787    S: DataMut<Elem = Complex64>,
788{
789    correlation_matrix_complex_into_impl(&matrix.view(), output)
790}
791
792/// Compute correlation matrix for complex observations from a matrix view into caller-provided
793/// output.
794///
795/// # Errors
796/// Returns an error if covariance computation fails or `output` shape is incompatible.
797pub fn correlation_matrix_complex_view_into<S>(
798    matrix: &ArrayView2<'_, Complex64>,
799    output: &mut ArrayBase<S, Ix2>,
800) -> Result<(), StatsError>
801where
802    S: DataMut<Elem = Complex64>,
803{
804    correlation_matrix_complex_into_impl(matrix, output)
805}
806
807// --- Physical AI streaming / rolling extensions ---
808
809pub mod online {
810    #![allow(clippy::missing_errors_doc)]
811    use nabled_core::scalar::NabledReal;
812
813    /// Online mean accumulator (Welford-style count/mean).
814    #[derive(Debug, Clone, PartialEq)]
815    pub struct OnlineMean<T> {
816        count: usize,
817        mean:  T,
818    }
819
820    impl<T: NabledReal> Default for OnlineMean<T> {
821        fn default() -> Self { Self { count: 0, mean: T::zero() } }
822    }
823
824    impl<T: NabledReal> OnlineMean<T> {
825        pub fn push(&mut self, value: T) {
826            self.count += 1;
827            let n = T::from_usize(self.count).unwrap_or(T::one());
828            self.mean += (value - self.mean) / n;
829        }
830
831        #[must_use]
832        pub fn mean(&self) -> T { self.mean }
833
834        pub fn reset(&mut self) {
835            self.count = 0;
836            self.mean = T::zero();
837        }
838    }
839
840    /// Online variance accumulator (Welford).
841    #[derive(Debug, Clone, PartialEq)]
842    pub struct OnlineVariance<T> {
843        count: usize,
844        mean:  T,
845        m2:    T,
846    }
847
848    impl<T: NabledReal> Default for OnlineVariance<T> {
849        fn default() -> Self { Self { count: 0, mean: T::zero(), m2: T::zero() } }
850    }
851
852    impl<T: NabledReal> OnlineVariance<T> {
853        pub fn push(&mut self, value: T) {
854            self.count += 1;
855            let n = T::from_usize(self.count).unwrap_or(T::one());
856            let delta = value - self.mean;
857            self.mean += delta / n;
858            let delta2 = value - self.mean;
859            self.m2 += delta * delta2;
860        }
861
862        #[must_use]
863        pub fn mean(&self) -> T { self.mean }
864
865        #[must_use]
866        pub fn variance(&self) -> T {
867            if self.count < 2 {
868                return T::zero();
869            }
870            self.m2 / T::from_usize(self.count - 1).unwrap_or(T::one())
871        }
872
873        pub fn reset(&mut self) {
874            self.count = 0;
875            self.mean = T::zero();
876            self.m2 = T::zero();
877        }
878    }
879}
880
881pub mod ewma {
882    #![allow(clippy::missing_errors_doc)]
883    use nabled_core::scalar::NabledReal;
884    use ndarray::{Array1, ArrayBase, ArrayView1, DataMut, Ix1};
885
886    use super::StatsError;
887
888    /// Incremental EWMA state.
889    #[derive(Debug, Clone, PartialEq)]
890    pub struct EwmaState<T> {
891        alpha: T,
892        value: Option<T>,
893    }
894
895    impl<T: NabledReal> EwmaState<T> {
896        pub fn new(alpha: T) -> Self { Self { alpha, value: None } }
897
898        pub fn push(&mut self, sample: T) -> T {
899            if let Some(prev) = self.value {
900                let next = self.alpha * sample + (T::one() - self.alpha) * prev;
901                self.value = Some(next);
902                next
903            } else {
904                self.value = Some(sample);
905                sample
906            }
907        }
908    }
909
910    #[must_use]
911    pub fn ewma<T: NabledReal>(signal: &Array1<T>, alpha: T) -> Array1<T> {
912        ewma_view(&signal.view(), alpha)
913    }
914
915    #[must_use]
916    pub fn ewma_view<T: NabledReal>(signal: &ArrayView1<'_, T>, alpha: T) -> Array1<T> {
917        let mut out = Array1::<T>::zeros(signal.len());
918        drop(ewma_into(signal, alpha, &mut out));
919        out
920    }
921
922    pub fn ewma_into<T, S>(
923        signal: &ArrayView1<'_, T>,
924        alpha: T,
925        output: &mut ArrayBase<S, Ix1>,
926    ) -> Result<(), StatsError>
927    where
928        T: NabledReal,
929        S: DataMut<Elem = T>,
930    {
931        if output.len() != signal.len() {
932            return Err(StatsError::InvalidInput("ewma output length mismatch".to_string()));
933        }
934        let mut state = EwmaState::new(alpha);
935        for (i, &sample) in signal.iter().enumerate() {
936            output[i] = state.push(sample);
937        }
938        Ok(())
939    }
940}
941
942pub mod rolling {
943    #![allow(clippy::missing_errors_doc)]
944    use nabled_core::scalar::NabledReal;
945    use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, DataMut, Ix1, Ix2};
946
947    use super::StatsError;
948
949    #[must_use]
950    pub fn rolling_mean<T: NabledReal>(signal: &ArrayView1<'_, T>, window: usize) -> Array1<T> {
951        let mut out = Array1::<T>::zeros(signal.len());
952        drop(rolling_mean_into(signal, window, &mut out));
953        out
954    }
955
956    pub fn rolling_mean_view<T: NabledReal>(
957        signal: &ArrayView1<'_, T>,
958        window: usize,
959    ) -> Array1<T> {
960        rolling_mean(signal, window)
961    }
962
963    pub fn rolling_mean_into<T, S>(
964        signal: &ArrayView1<'_, T>,
965        window: usize,
966        output: &mut ArrayBase<S, Ix1>,
967    ) -> Result<(), StatsError>
968    where
969        T: NabledReal,
970        S: DataMut<Elem = T>,
971    {
972        if window == 0 {
973            return Err(StatsError::InvalidInput("window must be positive".to_string()));
974        }
975        if output.len() != signal.len() {
976            return Err(StatsError::InvalidInput(
977                "rolling_mean output length mismatch".to_string(),
978            ));
979        }
980        for i in 0..signal.len() {
981            let start = i.saturating_sub(window - 1);
982            let slice = signal.slice(ndarray::s![start..=i]);
983            let sum = slice.iter().fold(T::zero(), |acc, v| acc + *v);
984            let count = T::from_usize(slice.len()).unwrap_or(T::one());
985            output[i] = sum / count;
986        }
987        Ok(())
988    }
989
990    #[must_use]
991    pub fn rolling_variance<T: NabledReal>(signal: &ArrayView1<'_, T>, window: usize) -> Array1<T> {
992        let mut out = Array1::<T>::zeros(signal.len());
993        drop(rolling_variance_into(signal, window, &mut out));
994        out
995    }
996
997    pub fn rolling_variance_into<T, S>(
998        signal: &ArrayView1<'_, T>,
999        window: usize,
1000        output: &mut ArrayBase<S, Ix1>,
1001    ) -> Result<(), StatsError>
1002    where
1003        T: NabledReal,
1004        S: DataMut<Elem = T>,
1005    {
1006        if window == 0 {
1007            return Err(StatsError::InvalidInput("window must be positive".to_string()));
1008        }
1009        if output.len() != signal.len() {
1010            return Err(StatsError::InvalidInput(
1011                "rolling_variance output length mismatch".to_string(),
1012            ));
1013        }
1014        for i in 0..signal.len() {
1015            let start = i.saturating_sub(window - 1);
1016            let slice = signal.slice(ndarray::s![start..=i]);
1017            let mean = slice.iter().fold(T::zero(), |acc, v| acc + *v)
1018                / T::from_usize(slice.len()).unwrap_or(T::one());
1019            let var = slice
1020                .iter()
1021                .map(|v| {
1022                    let d = *v - mean;
1023                    d * d
1024                })
1025                .fold(T::zero(), |acc, v| acc + v)
1026                / T::from_usize(slice.len()).unwrap_or(T::one());
1027            output[i] = var;
1028        }
1029        Ok(())
1030    }
1031
1032    #[must_use]
1033    pub fn rolling_covariance<T: NabledReal>(
1034        matrix: &ArrayView2<'_, T>,
1035        window: usize,
1036    ) -> Array2<T> {
1037        let mut out = Array2::<T>::zeros((matrix.nrows(), matrix.ncols() * matrix.ncols()));
1038        drop(rolling_covariance_into(matrix, window, &mut out));
1039        out
1040    }
1041
1042    pub fn rolling_covariance_view<T: NabledReal>(
1043        matrix: &ArrayView2<'_, T>,
1044        window: usize,
1045    ) -> Array2<T> {
1046        rolling_covariance(matrix, window)
1047    }
1048
1049    pub fn rolling_covariance_into<T, S>(
1050        matrix: &ArrayView2<'_, T>,
1051        window: usize,
1052        output: &mut ArrayBase<S, Ix2>,
1053    ) -> Result<(), StatsError>
1054    where
1055        T: NabledReal,
1056        S: DataMut<Elem = T>,
1057    {
1058        let cols = matrix.ncols();
1059        if output.nrows() != matrix.nrows() || output.ncols() != cols * cols {
1060            return Err(StatsError::InvalidInput(
1061                "rolling_covariance output shape mismatch".to_string(),
1062            ));
1063        }
1064        if window == 0 {
1065            return Err(StatsError::InvalidInput("window must be positive".to_string()));
1066        }
1067        for row in 0..matrix.nrows() {
1068            let start = row.saturating_sub(window - 1);
1069            let block = matrix.slice(ndarray::s![start..=row, ..]);
1070            let cov = if block.nrows() >= 2 {
1071                super::covariance_matrix_view(&block)?
1072            } else {
1073                Array2::<T>::zeros((cols, cols))
1074            };
1075            for i in 0..cols {
1076                for j in 0..cols {
1077                    output[[row, i * cols + j]] = cov[[i, j]];
1078                }
1079            }
1080        }
1081        Ok(())
1082    }
1083}
1084
1085pub mod lag {
1086    #![allow(clippy::missing_errors_doc)]
1087    use nabled_core::scalar::NabledReal;
1088    use ndarray::{Array2, ArrayView2};
1089
1090    use super::StatsError;
1091
1092    /// Zero-copy lag view: rows `[0..n-lag)` of `matrix` aligned with rows `[lag..n)`.
1093    pub fn lag_view<'a, T>(
1094        matrix: &'a ArrayView2<'_, T>,
1095        lag: usize,
1096    ) -> Result<ArrayView2<'a, T>, StatsError> {
1097        if lag >= matrix.nrows() {
1098            return Err(StatsError::InvalidInput(format!(
1099                "lag {lag} must be less than row count {}",
1100                matrix.nrows()
1101            )));
1102        }
1103        Ok(matrix.slice(ndarray::s![..matrix.nrows() - lag, ..]))
1104    }
1105
1106    /// Shift columns down by `lag` rows, filling top rows with zero.
1107    pub fn shift_columns_into<T: NabledReal>(
1108        matrix: &ArrayView2<'_, T>,
1109        lag: usize,
1110        output: &mut Array2<T>,
1111    ) -> Result<(), StatsError> {
1112        if output.dim() != matrix.dim() {
1113            return Err(StatsError::InvalidInput(
1114                "shift_columns_into output shape mismatch".to_string(),
1115            ));
1116        }
1117        output.fill(T::zero());
1118        if lag >= matrix.nrows() {
1119            return Ok(());
1120        }
1121        let rows = matrix.nrows() - lag;
1122        output.slice_mut(ndarray::s![lag.., ..]).assign(&matrix.slice(ndarray::s![..rows, ..]));
1123        Ok(())
1124    }
1125}
1126
#[cfg(test)]
mod tests {
    // Unit tests for the real/complex covariance & correlation helpers, the `online`,
    // `ewma`, `rolling` and `lag` submodules, and their owned / view / `_into` variants.
    use ndarray::{Array1, Array2};
    use num_complex::Complex64;

    use super::*;

    // Smoke test: only the output shapes are asserted here, not the values.
    #[test]
    fn covariance_and_correlation_are_well_formed() {
        let matrix =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
                .unwrap();
        let covariance = covariance_matrix(&matrix).unwrap();
        let correlation = correlation_matrix(&matrix).unwrap();
        assert_eq!(covariance.dim(), (2, 2));
        assert_eq!(correlation.dim(), (2, 2));
    }

    #[test]
    fn stats_rejects_empty_and_insufficient_inputs() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert!(matches!(covariance_matrix(&empty), Err(StatsError::EmptyMatrix)));

        let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        assert!(matches!(covariance_matrix(&one_row), Err(StatsError::InsufficientSamples)));
    }

    #[test]
    fn center_columns_zeroes_means() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
        let centered = center_columns(&matrix);
        let means = column_means(&centered);
        assert!(means.iter().all(|value| num_traits::Float::abs(*value) < 1e-12));
    }

    // A matrix with zero rows still yields one (zero) mean per column.
    #[test]
    fn column_means_handles_empty_input() {
        let matrix = Array2::<f64>::zeros((0, 3));
        let means = column_means(&matrix);
        assert_eq!(means.len(), 3);
        assert!(means.iter().all(|value| *value == 0.0));
    }

    // Deviations of +/-f64::MAX are expected to overflow when squared; the covariance
    // routine must report that as `NumericalInstability` rather than return inf/NaN.
    #[test]
    fn covariance_reports_numerical_instability() {
        let matrix = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
        let result = covariance_matrix(&matrix);
        assert!(matches!(result, Err(StatsError::NumericalInstability)));
    }

    // Column 0 is constant (zero variance); every correlation entry must still be finite.
    #[test]
    fn correlation_handles_zero_variance_column() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 10.0, 1.0, 20.0, 1.0, 30.0]).unwrap();
        let correlation = correlation_matrix(&matrix).unwrap();
        assert!(correlation[[0, 0]].is_finite());
        assert!(correlation[[0, 1]].is_finite());
        assert!(correlation[[1, 0]].is_finite());
        assert!(correlation[[1, 1]].is_finite());
    }

    #[test]
    fn view_variants_match_owned() {
        let matrix =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
                .unwrap();
        let means_owned = column_means(&matrix);
        let means_view = column_means_view(&matrix.view());
        let centered_owned = center_columns(&matrix);
        let centered_view = center_columns_view(&matrix.view());
        let covariance_owned = covariance_matrix(&matrix).unwrap();
        let covariance_view = covariance_matrix_view(&matrix.view()).unwrap();
        let correlation_owned = correlation_matrix(&matrix).unwrap();
        let correlation_view = correlation_matrix_view(&matrix.view()).unwrap();

        for i in 0..means_owned.len() {
            assert!((means_owned[i] - means_view[i]).abs() < 1e-12);
        }
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((centered_owned[[i, j]] - centered_view[[i, j]]).abs() < 1e-12);
            }
        }
        for i in 0..2 {
            for j in 0..2 {
                assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).abs() < 1e-12);
                assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).abs() < 1e-12);
            }
        }
    }

    // Smoke test for the complex API: only the output shapes are asserted.
    #[test]
    fn complex_covariance_and_correlation_are_well_formed() {
        let matrix = Array2::from_shape_vec((4, 2), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(3.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 0.5),
            Complex64::new(3.0, 0.2),
            Complex64::new(1.0, -0.3),
            Complex64::new(4.0, 0.7),
            Complex64::new(0.0, 0.0),
        ])
        .unwrap();

        let covariance = covariance_matrix_complex(&matrix).unwrap();
        let correlation = correlation_matrix_complex(&matrix).unwrap();
        assert_eq!(covariance.dim(), (2, 2));
        assert_eq!(correlation.dim(), (2, 2));
    }

    #[test]
    fn complex_view_variants_match_owned() {
        let matrix = Array2::from_shape_vec((3, 2), vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(3.0, -2.0),
            Complex64::new(4.0, 1.0),
        ])
        .unwrap();

        let means_owned = column_means_complex(&matrix);
        let means_view = column_means_complex_view(&matrix.view());
        let centered_owned = center_columns_complex(&matrix);
        let centered_view = center_columns_complex_view(&matrix.view());
        let covariance_owned = covariance_matrix_complex(&matrix).unwrap();
        let covariance_view = covariance_matrix_complex_view(&matrix.view()).unwrap();
        let correlation_owned = correlation_matrix_complex(&matrix).unwrap();
        let correlation_view = correlation_matrix_complex_view(&matrix.view()).unwrap();

        for i in 0..means_owned.len() {
            assert!((means_owned[i] - means_view[i]).norm() < 1e-12);
        }
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((centered_owned[[i, j]] - centered_view[[i, j]]).norm() < 1e-12);
            }
        }
        for i in 0..2 {
            for j in 0..2 {
                assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).norm() < 1e-12);
                assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).norm() < 1e-12);
            }
        }
    }

    // The `_view_into` variants must produce exactly (bitwise) what the owned variants return.
    #[test]
    fn stats_view_into_reuses_outputs() {
        let matrix =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
                .unwrap();

        let mut means = Array1::<f64>::zeros(2);
        let mut centered = Array2::<f64>::zeros((4, 2));
        let mut covariance = Array2::<f64>::zeros((2, 2));
        let mut correlation = Array2::<f64>::zeros((2, 2));

        column_means_view_into(&matrix.view(), &mut means).unwrap();
        center_columns_view_into(&matrix.view(), &mut centered).unwrap();
        covariance_matrix_view_into(&matrix.view(), &mut covariance).unwrap();
        correlation_matrix_view_into(&matrix.view(), &mut correlation).unwrap();

        assert_eq!(means, column_means(&matrix));
        assert_eq!(centered, center_columns(&matrix));
        assert_eq!(covariance, covariance_matrix(&matrix).unwrap());
        assert_eq!(correlation, correlation_matrix(&matrix).unwrap());
    }

    #[test]
    fn complex_stats_view_into_reuses_outputs() {
        let matrix = Array2::from_shape_vec((3, 2), vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(3.0, -2.0),
            Complex64::new(4.0, 1.0),
        ])
        .unwrap();

        let mut means = Array1::<Complex64>::zeros(2);
        let mut centered = Array2::<Complex64>::zeros((3, 2));
        let mut covariance = Array2::<Complex64>::zeros((2, 2));
        let mut correlation = Array2::<Complex64>::zeros((2, 2));

        column_means_complex_view_into(&matrix.view(), &mut means).unwrap();
        center_columns_complex_view_into(&matrix.view(), &mut centered).unwrap();
        covariance_matrix_complex_view_into(&matrix.view(), &mut covariance).unwrap();
        correlation_matrix_complex_view_into(&matrix.view(), &mut correlation).unwrap();

        assert_eq!(means, column_means_complex(&matrix));
        assert_eq!(centered, center_columns_complex(&matrix));
        assert_eq!(covariance, covariance_matrix_complex(&matrix).unwrap());
        assert_eq!(correlation, correlation_matrix_complex(&matrix).unwrap());
    }

    // Covers the owned `_into` paths: success, zero-row input (means are reset to 0),
    // shape mismatches, empty / single-row inputs, and overflow detection.
    #[test]
    fn stats_owned_into_paths_cover_empty_valid_and_error_cases() {
        assert_eq!(
            StatsError::InvalidInput("bad shape".to_string()).to_string(),
            "Invalid input: bad shape"
        );

        let matrix =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
                .unwrap();
        let mut means = Array1::<f64>::zeros(2);
        let mut centered = Array2::<f64>::zeros((4, 2));
        let mut covariance = Array2::<f64>::zeros((2, 2));
        let mut correlation = Array2::<f64>::zeros((2, 2));

        column_means_into(&matrix, &mut means).unwrap();
        center_columns_into(&matrix, &mut centered).unwrap();
        covariance_matrix_into(&matrix, &mut covariance).unwrap();
        correlation_matrix_into(&matrix, &mut correlation).unwrap();

        assert_eq!(means, column_means(&matrix));
        assert_eq!(centered, center_columns(&matrix));
        assert_eq!(covariance, covariance_matrix(&matrix).unwrap());
        assert_eq!(correlation, correlation_matrix(&matrix).unwrap());

        // Pre-filled with ones to prove the zero-row path overwrites the output.
        let empty_columns = Array2::<f64>::zeros((0, 3));
        let mut empty_means = Array1::<f64>::from_vec(vec![1.0, 1.0, 1.0]);
        column_means_into(&empty_columns, &mut empty_means).unwrap();
        assert!(empty_means.iter().all(|value| *value == 0.0));

        let mut bad_means = Array1::<f64>::zeros(3);
        assert!(matches!(
            column_means_into(&matrix, &mut bad_means),
            Err(StatsError::InvalidInput(_))
        ));

        let mut bad_centered = Array2::<f64>::zeros((4, 3));
        assert!(matches!(
            center_columns_into(&matrix, &mut bad_centered),
            Err(StatsError::InvalidInput(_))
        ));

        let empty = Array2::<f64>::zeros((0, 0));
        let mut empty_covariance = Array2::<f64>::zeros((0, 0));
        assert!(matches!(
            covariance_matrix_into(&empty, &mut empty_covariance),
            Err(StatsError::EmptyMatrix)
        ));

        let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let mut one_row_covariance = Array2::<f64>::zeros((2, 2));
        assert!(matches!(
            covariance_matrix_into(&one_row, &mut one_row_covariance),
            Err(StatsError::InsufficientSamples)
        ));

        let mut bad_covariance = Array2::<f64>::zeros((3, 3));
        assert!(matches!(
            covariance_matrix_into(&matrix, &mut bad_covariance),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            correlation_matrix_into(&matrix, &mut bad_covariance),
            Err(StatsError::InvalidInput(_))
        ));

        let unstable = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
        let mut unstable_covariance = Array2::<f64>::zeros((2, 2));
        assert!(matches!(
            covariance_matrix_into(&unstable, &mut unstable_covariance),
            Err(StatsError::NumericalInstability)
        ));
    }

    // Complex counterpart of the owned `_into` coverage test above.
    #[test]
    fn complex_stats_owned_into_paths_cover_empty_valid_and_error_cases() {
        let matrix = Array2::from_shape_vec((3, 2), vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(3.0, -2.0),
            Complex64::new(4.0, 1.0),
        ])
        .unwrap();
        let mut means = Array1::<Complex64>::zeros(2);
        let mut centered = Array2::<Complex64>::zeros((3, 2));
        let mut covariance = Array2::<Complex64>::zeros((2, 2));
        let mut correlation = Array2::<Complex64>::zeros((2, 2));

        column_means_complex_into(&matrix, &mut means).unwrap();
        center_columns_complex_into(&matrix, &mut centered).unwrap();
        covariance_matrix_complex_into(&matrix, &mut covariance).unwrap();
        correlation_matrix_complex_into(&matrix, &mut correlation).unwrap();

        assert_eq!(means, column_means_complex(&matrix));
        assert_eq!(centered, center_columns_complex(&matrix));
        assert_eq!(covariance, covariance_matrix_complex(&matrix).unwrap());
        assert_eq!(correlation, correlation_matrix_complex(&matrix).unwrap());

        let empty_columns = Array2::<Complex64>::zeros((0, 3));
        let mut empty_means = Array1::<Complex64>::from_vec(vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(1.0, 1.0),
            Complex64::new(1.0, 1.0),
        ]);
        column_means_complex_into(&empty_columns, &mut empty_means).unwrap();
        assert!(empty_means.iter().all(|value| *value == Complex64::new(0.0, 0.0)));

        let mut bad_means = Array1::<Complex64>::zeros(3);
        assert!(matches!(
            column_means_complex_into(&matrix, &mut bad_means),
            Err(StatsError::InvalidInput(_))
        ));

        let mut bad_centered = Array2::<Complex64>::zeros((3, 3));
        assert!(matches!(
            center_columns_complex_into(&matrix, &mut bad_centered),
            Err(StatsError::InvalidInput(_))
        ));

        let empty = Array2::<Complex64>::zeros((0, 0));
        let mut empty_covariance = Array2::<Complex64>::zeros((0, 0));
        assert!(matches!(
            covariance_matrix_complex_into(&empty, &mut empty_covariance),
            Err(StatsError::EmptyMatrix)
        ));

        let one_row = Array2::from_shape_vec((1, 2), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
        ])
        .unwrap();
        let mut one_row_covariance = Array2::<Complex64>::zeros((2, 2));
        assert!(matches!(
            covariance_matrix_complex_into(&one_row, &mut one_row_covariance),
            Err(StatsError::InsufficientSamples)
        ));

        let mut bad_covariance = Array2::<Complex64>::zeros((3, 3));
        assert!(matches!(
            covariance_matrix_complex_into(&matrix, &mut bad_covariance),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            correlation_matrix_complex_into(&matrix, &mut bad_covariance),
            Err(StatsError::InvalidInput(_))
        ));

        let unstable = Array2::from_shape_vec((2, 2), vec![
            Complex64::new(f64::MAX, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(-f64::MAX, 0.0),
            Complex64::new(0.0, 0.0),
        ])
        .unwrap();
        let mut unstable_covariance = Array2::<Complex64>::zeros((2, 2));
        assert!(matches!(
            covariance_matrix_complex_into(&unstable, &mut unstable_covariance),
            Err(StatsError::NumericalInstability)
        ));
    }

    #[test]
    fn stats_view_into_rejects_wrong_output_shapes() {
        let matrix =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
                .unwrap();
        let mut bad_means = Array1::<f64>::zeros(3);
        let mut bad_centered = Array2::<f64>::zeros((4, 3));
        let mut bad_covariance = Array2::<f64>::zeros((3, 3));

        assert!(matches!(
            column_means_view_into(&matrix.view(), &mut bad_means),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            center_columns_view_into(&matrix.view(), &mut bad_centered),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            covariance_matrix_view_into(&matrix.view(), &mut bad_covariance),
            Err(StatsError::InvalidInput(_))
        ));
    }

    // Absolute-tolerance float comparison used by the online/ewma/rolling/lag tests below.
    fn approx_eq(a: f64, b: f64) -> bool { num_traits::Float::abs(a - b) < 1e-10 }

    // For 1..=5 the batch mean is 3 and the sample variance (n - 1 divisor) is 2.5.
    #[test]
    fn online_mean_and_variance_track_batch_statistics() {
        let mut mean = online::OnlineMean::<f64>::default();
        let mut variance = online::OnlineVariance::<f64>::default();
        for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
            mean.push(value);
            variance.push(value);
        }
        assert!(approx_eq(mean.mean(), 3.0));
        assert!(approx_eq(variance.mean(), 3.0));
        assert!(approx_eq(variance.variance(), 2.5));
    }

    #[test]
    fn online_accumulators_reset_to_initial_state() {
        let mut mean = online::OnlineMean::<f64>::default();
        let mut variance = online::OnlineVariance::<f64>::default();
        mean.push(4.0);
        variance.push(4.0);
        mean.reset();
        variance.reset();
        assert!(approx_eq(mean.mean(), 0.0));
        assert!(approx_eq(variance.mean(), 0.0));
        assert!(approx_eq(variance.variance(), 0.0));
    }

    #[test]
    fn online_variance_is_zero_with_fewer_than_two_samples() {
        let mut variance = online::OnlineVariance::<f64>::default();
        variance.push(2.0);
        assert!(approx_eq(variance.variance(), 0.0));
    }

    // With alpha = 0.5, the first value seeds the state and the second is averaged in:
    // 0.5 * 0.0 + 0.5 * 10.0 = 5.0.
    #[test]
    fn ewma_state_and_vector_apis_match_expected_smoothing() {
        let signal = Array1::from_vec(vec![10.0_f64, 0.0]);
        let alpha = 0.5;
        let mut state = ewma::EwmaState::new(alpha);
        assert!(approx_eq(state.push(10.0), 10.0));
        assert!(approx_eq(state.push(0.0), 5.0));

        let owned = ewma::ewma(&signal, alpha);
        let viewed = ewma::ewma_view(&signal.view(), alpha);
        assert!(approx_eq(owned[0], 10.0));
        assert!(approx_eq(owned[1], 5.0));
        assert_eq!(owned, viewed);

        let mut output = Array1::<f64>::zeros(2);
        ewma::ewma_into(&signal.view(), alpha, &mut output).unwrap();
        assert_eq!(output, owned);

        let mut bad_output = Array1::<f64>::zeros(3);
        assert!(matches!(
            ewma::ewma_into(&signal.view(), alpha, &mut bad_output),
            Err(StatsError::InvalidInput(_))
        ));
    }

    // Window 3 over 1..=5: position 0 averages a single sample, position 2 is the first
    // full window [1, 2, 3] (mean 2, population variance 2/3), position 4 is [3, 4, 5].
    #[test]
    fn rolling_mean_and_variance_match_manual_windows() {
        let signal = Array1::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]);
        let window = 3;
        let mean = rolling::rolling_mean(&signal.view(), window);
        let mean_view = rolling::rolling_mean_view(&signal.view(), window);
        assert_eq!(mean, mean_view);
        assert!(approx_eq(mean[0], 1.0));
        assert!(approx_eq(mean[2], 2.0));
        assert!(approx_eq(mean[4], 4.0));

        let variance = rolling::rolling_variance(&signal.view(), window);
        assert!(approx_eq(variance[2], 2.0 / 3.0));

        let mut mean_out = Array1::<f64>::zeros(signal.len());
        rolling::rolling_mean_into(&signal.view(), window, &mut mean_out).unwrap();
        assert_eq!(mean_out, mean);

        let mut variance_out = Array1::<f64>::zeros(signal.len());
        rolling::rolling_variance_into(&signal.view(), window, &mut variance_out).unwrap();
        assert_eq!(variance_out, variance);

        assert!(matches!(
            rolling::rolling_mean_into(&signal.view(), 0, &mut mean_out),
            Err(StatsError::InvalidInput(_))
        ));
        let mut bad_mean_out = Array1::<f64>::zeros(signal.len() + 1);
        assert!(matches!(
            rolling::rolling_mean_into(&signal.view(), window, &mut bad_mean_out),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            rolling::rolling_variance_into(&signal.view(), 0, &mut variance_out),
            Err(StatsError::InvalidInput(_))
        ));
    }

    // Output is (rows, cols * cols) = (3, 4). Row 0's window holds a single observation,
    // so its packed covariance block is zero.
    #[test]
    fn rolling_covariance_into_handles_valid_and_invalid_inputs() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
        let window = 2;
        let cov = rolling::rolling_covariance(&matrix.view(), window);
        let cov_view = rolling::rolling_covariance_view(&matrix.view(), window);
        assert_eq!(cov, cov_view);
        assert!(approx_eq(cov[[0, 0]], 0.0));

        let mut output = Array2::<f64>::zeros((3, 4));
        rolling::rolling_covariance_into(&matrix.view(), window, &mut output).unwrap();
        assert_eq!(output, cov);

        let mut bad_shape = Array2::<f64>::zeros((3, 3));
        assert!(matches!(
            rolling::rolling_covariance_into(&matrix.view(), window, &mut bad_shape),
            Err(StatsError::InvalidInput(_))
        ));
        assert!(matches!(
            rolling::rolling_covariance_into(&matrix.view(), 0, &mut output),
            Err(StatsError::InvalidInput(_))
        ));
    }

    // Rows are [1, 2], [3, 4], [5, 6]. lag_view(1) keeps the first two rows.
    // shift_columns_into(1) moves every row down by one and zeros row 0.
    #[test]
    fn lag_view_and_shift_columns_into_cover_alignment_paths() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let view = matrix.view();
        let lagged = lag::lag_view(&view, 1).unwrap();
        assert_eq!(lagged.dim(), (2, 2));
        assert!(approx_eq(lagged[[0, 0]], 1.0));
        assert!(approx_eq(lagged[[1, 1]], 4.0));

        let mut shifted = Array2::<f64>::zeros((3, 2));
        lag::shift_columns_into(&view, 1, &mut shifted).unwrap();
        assert!(approx_eq(shifted[[0, 0]], 0.0));
        assert!(approx_eq(shifted[[1, 0]], 1.0));
        assert!(approx_eq(shifted[[2, 1]], 4.0));

        assert!(matches!(lag::lag_view(&view, matrix.nrows()), Err(StatsError::InvalidInput(_))));

        let mut bad_output = Array2::<f64>::zeros((2, 2));
        assert!(matches!(
            lag::shift_columns_into(&view, 1, &mut bad_output),
            Err(StatsError::InvalidInput(_))
        ));

        // Pre-filled with ones: a lag >= row count must zero the entire output.
        let mut zeroed = Array2::<f64>::ones((3, 2));
        lag::shift_columns_into(&view, matrix.nrows(), &mut zeroed).unwrap();
        assert!(zeroed.iter().all(|value| approx_eq(*value, 0.0)));
    }
}