1use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, DataMut, Ix1, Ix2};
7use num_complex::Complex64;
8
/// Errors produced by the statistics routines in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StatsError {
    /// The input matrix has zero rows or zero columns.
    EmptyMatrix,
    /// Fewer than two observations (rows) were supplied to a sample statistic
    /// that divides by `n - 1`.
    InsufficientSamples,
    /// An argument or caller-provided output buffer failed validation; the
    /// message describes which check failed.
    InvalidInput(String),
    /// A computed result contained a non-finite value (NaN or infinity).
    NumericalInstability,
}
21
22impl fmt::Display for StatsError {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 match self {
25 StatsError::EmptyMatrix => write!(f, "Matrix cannot be empty"),
26 StatsError::InsufficientSamples => {
27 write!(f, "At least two observations are required")
28 }
29 StatsError::InvalidInput(message) => write!(f, "Invalid input: {message}"),
30 StatsError::NumericalInstability => write!(f, "Numerical instability detected"),
31 }
32 }
33}
34
// `StatsError` implements `Debug` (derived) and `Display`, so the default
// `std::error::Error` methods are sufficient; this impl lets it be used as a
// `dyn Error` (e.g. boxed and propagated with `?`).
impl std::error::Error for StatsError {}
36
37fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
38 T::from_usize(value).unwrap_or(T::max_value())
39}
40
41fn complex_is_finite(value: Complex64) -> bool { value.re.is_finite() && value.im.is_finite() }
42
43fn validate_vector_output_len<T, S>(
44 output: &ArrayBase<S, Ix1>,
45 expected_len: usize,
46 name: &str,
47) -> Result<(), StatsError>
48where
49 S: DataMut<Elem = T>,
50{
51 if output.len() != expected_len {
52 return Err(StatsError::InvalidInput(format!(
53 "{name} output length must match expected length {expected_len}",
54 )));
55 }
56 Ok(())
57}
58
59fn validate_matrix_output_shape<T, S>(
60 output: &ArrayBase<S, Ix2>,
61 expected_rows: usize,
62 expected_cols: usize,
63 name: &str,
64) -> Result<(), StatsError>
65where
66 S: DataMut<Elem = T>,
67{
68 if output.nrows() != expected_rows || output.ncols() != expected_cols {
69 return Err(StatsError::InvalidInput(format!(
70 "{name} output shape must be ({expected_rows}, {expected_cols})",
71 )));
72 }
73 Ok(())
74}
75
76fn column_means_into_impl<T, S>(
77 matrix: &ArrayView2<'_, T>,
78 output: &mut ArrayBase<S, Ix1>,
79) -> Result<(), StatsError>
80where
81 T: NabledReal,
82 S: DataMut<Elem = T>,
83{
84 validate_vector_output_len(output, matrix.ncols(), "column_means")?;
85
86 if matrix.nrows() == 0 {
87 output.fill(T::zero());
88 return Ok(());
89 }
90
91 let denom = usize_to_scalar::<T>(matrix.nrows());
92 for col in 0..matrix.ncols() {
93 let mut sum = T::zero();
94 for row in 0..matrix.nrows() {
95 sum += matrix[[row, col]];
96 }
97 output[col] = sum / denom;
98 }
99
100 Ok(())
101}
102
103fn column_means_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
104 matrix.mean_axis(Axis(0)).unwrap_or_else(|| Array1::zeros(matrix.ncols()))
105}
106
107#[must_use]
109pub fn column_means<T: NabledReal>(matrix: &Array2<T>) -> Array1<T> {
110 column_means_impl(&matrix.view())
111}
112
113#[must_use]
115pub fn column_means_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
116 column_means_impl(matrix)
117}
118
119pub fn column_means_into<T, S>(
124 matrix: &Array2<T>,
125 output: &mut ArrayBase<S, Ix1>,
126) -> Result<(), StatsError>
127where
128 T: NabledReal,
129 S: DataMut<Elem = T>,
130{
131 column_means_into_impl(&matrix.view(), output)
132}
133
134pub fn column_means_view_into<T, S>(
139 matrix: &ArrayView2<'_, T>,
140 output: &mut ArrayBase<S, Ix1>,
141) -> Result<(), StatsError>
142where
143 T: NabledReal,
144 S: DataMut<Elem = T>,
145{
146 column_means_into_impl(matrix, output)
147}
148
149fn center_columns_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
150 let means = column_means_impl(matrix);
151 let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
152 for row in 0..matrix.nrows() {
153 for col in 0..matrix.ncols() {
154 centered[[row, col]] = matrix[[row, col]] - means[col];
155 }
156 }
157 centered
158}
159
160#[must_use]
162pub fn center_columns<T: NabledReal>(matrix: &Array2<T>) -> Array2<T> {
163 center_columns_impl(&matrix.view())
164}
165
166#[must_use]
168pub fn center_columns_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
169 center_columns_impl(matrix)
170}
171
172fn center_columns_into_impl<T, S>(
173 matrix: &ArrayView2<'_, T>,
174 output: &mut ArrayBase<S, Ix2>,
175) -> Result<(), StatsError>
176where
177 T: NabledReal,
178 S: DataMut<Elem = T>,
179{
180 validate_matrix_output_shape(output, matrix.nrows(), matrix.ncols(), "center_columns")?;
181
182 let mut means = Array1::<T>::zeros(matrix.ncols());
183 column_means_into_impl(matrix, &mut means)?;
184
185 for row in 0..matrix.nrows() {
186 for col in 0..matrix.ncols() {
187 output[[row, col]] = matrix[[row, col]] - means[col];
188 }
189 }
190
191 Ok(())
192}
193
194pub fn center_columns_into<T, S>(
199 matrix: &Array2<T>,
200 output: &mut ArrayBase<S, Ix2>,
201) -> Result<(), StatsError>
202where
203 T: NabledReal,
204 S: DataMut<Elem = T>,
205{
206 center_columns_into_impl(&matrix.view(), output)
207}
208
209pub fn center_columns_view_into<T, S>(
214 matrix: &ArrayView2<'_, T>,
215 output: &mut ArrayBase<S, Ix2>,
216) -> Result<(), StatsError>
217where
218 T: NabledReal,
219 S: DataMut<Elem = T>,
220{
221 center_columns_into_impl(matrix, output)
222}
223
224fn covariance_matrix_impl<T: NabledReal>(
225 matrix: &ArrayView2<'_, T>,
226) -> Result<Array2<T>, StatsError> {
227 if matrix.is_empty() {
228 return Err(StatsError::EmptyMatrix);
229 }
230 if matrix.nrows() < 2 {
231 return Err(StatsError::InsufficientSamples);
232 }
233
234 let centered = center_columns_impl(matrix);
235 let covariance: Array2<T> =
236 centered.t().dot(¢ered) / usize_to_scalar::<T>(matrix.nrows() - 1);
237
238 if covariance.iter().any(|value| !value.is_finite()) {
239 return Err(StatsError::NumericalInstability);
240 }
241
242 Ok(covariance)
243}
244
245pub fn covariance_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
250 covariance_matrix_impl(&matrix.view())
251}
252
253pub fn covariance_matrix_view<T: NabledReal>(
258 matrix: &ArrayView2<'_, T>,
259) -> Result<Array2<T>, StatsError> {
260 covariance_matrix_impl(matrix)
261}
262
263fn covariance_matrix_into_impl<T, S>(
264 matrix: &ArrayView2<'_, T>,
265 output: &mut ArrayBase<S, Ix2>,
266) -> Result<(), StatsError>
267where
268 T: NabledReal,
269 S: DataMut<Elem = T>,
270{
271 if matrix.is_empty() {
272 return Err(StatsError::EmptyMatrix);
273 }
274 if matrix.nrows() < 2 {
275 return Err(StatsError::InsufficientSamples);
276 }
277 validate_matrix_output_shape(output, matrix.ncols(), matrix.ncols(), "covariance_matrix")?;
278
279 let mut means = Array1::<T>::zeros(matrix.ncols());
280 column_means_into_impl(matrix, &mut means)?;
281 let denom = usize_to_scalar::<T>(matrix.nrows() - 1);
282
283 for i in 0..matrix.ncols() {
284 for j in i..matrix.ncols() {
285 let mut sum = T::zero();
286 for row in 0..matrix.nrows() {
287 let left = matrix[[row, i]] - means[i];
288 let right = matrix[[row, j]] - means[j];
289 sum += left * right;
290 }
291 let value = sum / denom;
292 output[[i, j]] = value;
293 if i != j {
294 output[[j, i]] = value;
295 }
296 }
297 }
298
299 if output.iter().any(|value| !value.is_finite()) {
300 return Err(StatsError::NumericalInstability);
301 }
302
303 Ok(())
304}
305
306pub fn covariance_matrix_into<T, S>(
311 matrix: &Array2<T>,
312 output: &mut ArrayBase<S, Ix2>,
313) -> Result<(), StatsError>
314where
315 T: NabledReal,
316 S: DataMut<Elem = T>,
317{
318 covariance_matrix_into_impl(&matrix.view(), output)
319}
320
321pub fn covariance_matrix_view_into<T, S>(
326 matrix: &ArrayView2<'_, T>,
327 output: &mut ArrayBase<S, Ix2>,
328) -> Result<(), StatsError>
329where
330 T: NabledReal,
331 S: DataMut<Elem = T>,
332{
333 covariance_matrix_into_impl(matrix, output)
334}
335
336fn correlation_matrix_impl<T: NabledReal>(
337 matrix: &ArrayView2<'_, T>,
338) -> Result<Array2<T>, StatsError> {
339 let covariance = covariance_matrix_impl(matrix)?;
340 let n = covariance.nrows();
341 let mut correlation = Array2::<T>::zeros((n, n));
342
343 for i in 0..n {
344 let sigma_i = covariance[[i, i]].sqrt();
345 for j in 0..n {
346 let sigma_j = covariance[[j, j]].sqrt();
347 let denom = (sigma_i * sigma_j).max(T::epsilon());
348 correlation[[i, j]] = covariance[[i, j]] / denom;
349 }
350 }
351
352 Ok(correlation)
353}
354
355pub fn correlation_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
360 correlation_matrix_impl(&matrix.view())
361}
362
363pub fn correlation_matrix_view<T: NabledReal>(
368 matrix: &ArrayView2<'_, T>,
369) -> Result<Array2<T>, StatsError> {
370 correlation_matrix_impl(matrix)
371}
372
373fn correlation_matrix_into_impl<T, S>(
374 matrix: &ArrayView2<'_, T>,
375 output: &mut ArrayBase<S, Ix2>,
376) -> Result<(), StatsError>
377where
378 T: NabledReal,
379 S: DataMut<Elem = T>,
380{
381 covariance_matrix_into_impl(matrix, output)?;
382 let mut sigmas = Array1::<T>::zeros(output.nrows());
383 for i in 0..output.nrows() {
384 sigmas[i] = output[[i, i]].sqrt();
385 }
386
387 for i in 0..output.nrows() {
388 let sigma_i = sigmas[i];
389 for j in 0..output.ncols() {
390 let sigma_j = sigmas[j];
391 let denom = (sigma_i * sigma_j).max(T::epsilon());
392 output[[i, j]] /= denom;
393 }
394 }
395
396 if output.iter().any(|value| !value.is_finite()) {
397 return Err(StatsError::NumericalInstability);
398 }
399
400 Ok(())
401}
402
403pub fn correlation_matrix_into<T, S>(
408 matrix: &Array2<T>,
409 output: &mut ArrayBase<S, Ix2>,
410) -> Result<(), StatsError>
411where
412 T: NabledReal,
413 S: DataMut<Elem = T>,
414{
415 correlation_matrix_into_impl(&matrix.view(), output)
416}
417
418pub fn correlation_matrix_view_into<T, S>(
423 matrix: &ArrayView2<'_, T>,
424 output: &mut ArrayBase<S, Ix2>,
425) -> Result<(), StatsError>
426where
427 T: NabledReal,
428 S: DataMut<Elem = T>,
429{
430 correlation_matrix_into_impl(matrix, output)
431}
432
433fn column_means_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
434 if matrix.nrows() == 0 {
435 return Array1::zeros(matrix.ncols());
436 }
437
438 let mut means = Array1::<Complex64>::zeros(matrix.ncols());
439 for col in 0..matrix.ncols() {
440 let mut sum = Complex64::new(0.0, 0.0);
441 for row in 0..matrix.nrows() {
442 sum += matrix[[row, col]];
443 }
444 means[col] = sum / usize_to_scalar::<f64>(matrix.nrows());
445 }
446 means
447}
448
449#[must_use]
451pub fn column_means_complex(matrix: &Array2<Complex64>) -> Array1<Complex64> {
452 column_means_complex_impl(&matrix.view())
453}
454
455#[must_use]
457pub fn column_means_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
458 column_means_complex_impl(matrix)
459}
460
461fn column_means_complex_into_impl<S>(
462 matrix: &ArrayView2<'_, Complex64>,
463 output: &mut ArrayBase<S, Ix1>,
464) -> Result<(), StatsError>
465where
466 S: DataMut<Elem = Complex64>,
467{
468 validate_vector_output_len(output, matrix.ncols(), "column_means_complex")?;
469
470 if matrix.nrows() == 0 {
471 output.fill(Complex64::new(0.0, 0.0));
472 return Ok(());
473 }
474
475 let denom = usize_to_scalar::<f64>(matrix.nrows());
476 for col in 0..matrix.ncols() {
477 let mut sum = Complex64::new(0.0, 0.0);
478 for row in 0..matrix.nrows() {
479 sum += matrix[[row, col]];
480 }
481 output[col] = sum / denom;
482 }
483
484 Ok(())
485}
486
487pub fn column_means_complex_into<S>(
492 matrix: &Array2<Complex64>,
493 output: &mut ArrayBase<S, Ix1>,
494) -> Result<(), StatsError>
495where
496 S: DataMut<Elem = Complex64>,
497{
498 column_means_complex_into_impl(&matrix.view(), output)
499}
500
501pub fn column_means_complex_view_into<S>(
506 matrix: &ArrayView2<'_, Complex64>,
507 output: &mut ArrayBase<S, Ix1>,
508) -> Result<(), StatsError>
509where
510 S: DataMut<Elem = Complex64>,
511{
512 column_means_complex_into_impl(matrix, output)
513}
514
515fn center_columns_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
516 let means = column_means_complex_impl(matrix);
517 let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
518 for row in 0..matrix.nrows() {
519 for col in 0..matrix.ncols() {
520 centered[[row, col]] = matrix[[row, col]] - means[col];
521 }
522 }
523 centered
524}
525
526#[must_use]
528pub fn center_columns_complex(matrix: &Array2<Complex64>) -> Array2<Complex64> {
529 center_columns_complex_impl(&matrix.view())
530}
531
532#[must_use]
534pub fn center_columns_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
535 center_columns_complex_impl(matrix)
536}
537
538fn center_columns_complex_into_impl<S>(
539 matrix: &ArrayView2<'_, Complex64>,
540 output: &mut ArrayBase<S, Ix2>,
541) -> Result<(), StatsError>
542where
543 S: DataMut<Elem = Complex64>,
544{
545 validate_matrix_output_shape(output, matrix.nrows(), matrix.ncols(), "center_columns_complex")?;
546
547 let mut means = Array1::<Complex64>::zeros(matrix.ncols());
548 column_means_complex_into_impl(matrix, &mut means)?;
549
550 for row in 0..matrix.nrows() {
551 for col in 0..matrix.ncols() {
552 output[[row, col]] = matrix[[row, col]] - means[col];
553 }
554 }
555
556 Ok(())
557}
558
559pub fn center_columns_complex_into<S>(
564 matrix: &Array2<Complex64>,
565 output: &mut ArrayBase<S, Ix2>,
566) -> Result<(), StatsError>
567where
568 S: DataMut<Elem = Complex64>,
569{
570 center_columns_complex_into_impl(&matrix.view(), output)
571}
572
573pub fn center_columns_complex_view_into<S>(
579 matrix: &ArrayView2<'_, Complex64>,
580 output: &mut ArrayBase<S, Ix2>,
581) -> Result<(), StatsError>
582where
583 S: DataMut<Elem = Complex64>,
584{
585 center_columns_complex_into_impl(matrix, output)
586}
587
588fn covariance_matrix_complex_impl(
589 matrix: &ArrayView2<'_, Complex64>,
590) -> Result<Array2<Complex64>, StatsError> {
591 if matrix.is_empty() {
592 return Err(StatsError::EmptyMatrix);
593 }
594 if matrix.nrows() < 2 {
595 return Err(StatsError::InsufficientSamples);
596 }
597
598 let centered = center_columns_complex_impl(matrix);
599 let conjugate_transpose = centered.t().mapv(|value| value.conj());
600 let covariance: Array2<Complex64> =
601 conjugate_transpose.dot(¢ered) / usize_to_scalar::<f64>(matrix.nrows() - 1);
602
603 if covariance.iter().any(|value| !complex_is_finite(*value)) {
604 return Err(StatsError::NumericalInstability);
605 }
606
607 Ok(covariance)
608}
609
610pub fn covariance_matrix_complex(
615 matrix: &Array2<Complex64>,
616) -> Result<Array2<Complex64>, StatsError> {
617 covariance_matrix_complex_impl(&matrix.view())
618}
619
620pub fn covariance_matrix_complex_view(
625 matrix: &ArrayView2<'_, Complex64>,
626) -> Result<Array2<Complex64>, StatsError> {
627 covariance_matrix_complex_impl(matrix)
628}
629
630fn covariance_matrix_complex_into_impl<S>(
631 matrix: &ArrayView2<'_, Complex64>,
632 output: &mut ArrayBase<S, Ix2>,
633) -> Result<(), StatsError>
634where
635 S: DataMut<Elem = Complex64>,
636{
637 if matrix.is_empty() {
638 return Err(StatsError::EmptyMatrix);
639 }
640 if matrix.nrows() < 2 {
641 return Err(StatsError::InsufficientSamples);
642 }
643 validate_matrix_output_shape(
644 output,
645 matrix.ncols(),
646 matrix.ncols(),
647 "covariance_matrix_complex",
648 )?;
649
650 let mut means = Array1::<Complex64>::zeros(matrix.ncols());
651 column_means_complex_into_impl(matrix, &mut means)?;
652 let denom = usize_to_scalar::<f64>(matrix.nrows() - 1);
653
654 for i in 0..matrix.ncols() {
655 for j in i..matrix.ncols() {
656 let mut sum = Complex64::new(0.0, 0.0);
657 for row in 0..matrix.nrows() {
658 let left = matrix[[row, i]] - means[i];
659 let right = matrix[[row, j]] - means[j];
660 sum += left.conj() * right;
661 }
662 let value = sum / denom;
663 output[[i, j]] = value;
664 if i != j {
665 output[[j, i]] = value.conj();
666 }
667 }
668 }
669
670 if output.iter().any(|value| !complex_is_finite(*value)) {
671 return Err(StatsError::NumericalInstability);
672 }
673
674 Ok(())
675}
676
677pub fn covariance_matrix_complex_into<S>(
682 matrix: &Array2<Complex64>,
683 output: &mut ArrayBase<S, Ix2>,
684) -> Result<(), StatsError>
685where
686 S: DataMut<Elem = Complex64>,
687{
688 covariance_matrix_complex_into_impl(&matrix.view(), output)
689}
690
691pub fn covariance_matrix_complex_view_into<S>(
697 matrix: &ArrayView2<'_, Complex64>,
698 output: &mut ArrayBase<S, Ix2>,
699) -> Result<(), StatsError>
700where
701 S: DataMut<Elem = Complex64>,
702{
703 covariance_matrix_complex_into_impl(matrix, output)
704}
705
706fn correlation_matrix_complex_impl(
707 matrix: &ArrayView2<'_, Complex64>,
708) -> Result<Array2<Complex64>, StatsError> {
709 let covariance = covariance_matrix_complex_impl(matrix)?;
710 let n = covariance.nrows();
711 let mut correlation = Array2::<Complex64>::zeros((n, n));
712
713 for i in 0..n {
714 let sigma_i = covariance[[i, i]].re.max(0.0).sqrt();
715 for j in 0..n {
716 let sigma_j = covariance[[j, j]].re.max(0.0).sqrt();
717 let denom = (sigma_i * sigma_j).max(f64::EPSILON);
718 correlation[[i, j]] = covariance[[i, j]] / denom;
719 }
720 }
721
722 if correlation.iter().any(|value| !complex_is_finite(*value)) {
723 return Err(StatsError::NumericalInstability);
724 }
725
726 Ok(correlation)
727}
728
729pub fn correlation_matrix_complex(
734 matrix: &Array2<Complex64>,
735) -> Result<Array2<Complex64>, StatsError> {
736 correlation_matrix_complex_impl(&matrix.view())
737}
738
739pub fn correlation_matrix_complex_view(
744 matrix: &ArrayView2<'_, Complex64>,
745) -> Result<Array2<Complex64>, StatsError> {
746 correlation_matrix_complex_impl(matrix)
747}
748
749fn correlation_matrix_complex_into_impl<S>(
750 matrix: &ArrayView2<'_, Complex64>,
751 output: &mut ArrayBase<S, Ix2>,
752) -> Result<(), StatsError>
753where
754 S: DataMut<Elem = Complex64>,
755{
756 covariance_matrix_complex_into_impl(matrix, output)?;
757 let mut sigmas = Array1::<f64>::zeros(output.nrows());
758 for i in 0..output.nrows() {
759 sigmas[i] = output[[i, i]].re.max(0.0).sqrt();
760 }
761
762 for i in 0..output.nrows() {
763 let sigma_i = sigmas[i];
764 for j in 0..output.ncols() {
765 let sigma_j = sigmas[j];
766 let denom = (sigma_i * sigma_j).max(f64::EPSILON);
767 output[[i, j]] /= denom;
768 }
769 }
770
771 if output.iter().any(|value| !complex_is_finite(*value)) {
772 return Err(StatsError::NumericalInstability);
773 }
774
775 Ok(())
776}
777
778pub fn correlation_matrix_complex_into<S>(
783 matrix: &Array2<Complex64>,
784 output: &mut ArrayBase<S, Ix2>,
785) -> Result<(), StatsError>
786where
787 S: DataMut<Elem = Complex64>,
788{
789 correlation_matrix_complex_into_impl(&matrix.view(), output)
790}
791
792pub fn correlation_matrix_complex_view_into<S>(
798 matrix: &ArrayView2<'_, Complex64>,
799 output: &mut ArrayBase<S, Ix2>,
800) -> Result<(), StatsError>
801where
802 S: DataMut<Elem = Complex64>,
803{
804 correlation_matrix_complex_into_impl(matrix, output)
805}
806
807pub mod online {
810 #![allow(clippy::missing_errors_doc)]
811 use nabled_core::scalar::NabledReal;
812
813 #[derive(Debug, Clone, PartialEq)]
815 pub struct OnlineMean<T> {
816 count: usize,
817 mean: T,
818 }
819
820 impl<T: NabledReal> Default for OnlineMean<T> {
821 fn default() -> Self { Self { count: 0, mean: T::zero() } }
822 }
823
824 impl<T: NabledReal> OnlineMean<T> {
825 pub fn push(&mut self, value: T) {
826 self.count += 1;
827 let n = T::from_usize(self.count).unwrap_or(T::one());
828 self.mean += (value - self.mean) / n;
829 }
830
831 #[must_use]
832 pub fn mean(&self) -> T { self.mean }
833
834 pub fn reset(&mut self) {
835 self.count = 0;
836 self.mean = T::zero();
837 }
838 }
839
840 #[derive(Debug, Clone, PartialEq)]
842 pub struct OnlineVariance<T> {
843 count: usize,
844 mean: T,
845 m2: T,
846 }
847
848 impl<T: NabledReal> Default for OnlineVariance<T> {
849 fn default() -> Self { Self { count: 0, mean: T::zero(), m2: T::zero() } }
850 }
851
852 impl<T: NabledReal> OnlineVariance<T> {
853 pub fn push(&mut self, value: T) {
854 self.count += 1;
855 let n = T::from_usize(self.count).unwrap_or(T::one());
856 let delta = value - self.mean;
857 self.mean += delta / n;
858 let delta2 = value - self.mean;
859 self.m2 += delta * delta2;
860 }
861
862 #[must_use]
863 pub fn mean(&self) -> T { self.mean }
864
865 #[must_use]
866 pub fn variance(&self) -> T {
867 if self.count < 2 {
868 return T::zero();
869 }
870 self.m2 / T::from_usize(self.count - 1).unwrap_or(T::one())
871 }
872
873 pub fn reset(&mut self) {
874 self.count = 0;
875 self.mean = T::zero();
876 self.m2 = T::zero();
877 }
878 }
879}
880
881pub mod ewma {
882 #![allow(clippy::missing_errors_doc)]
883 use nabled_core::scalar::NabledReal;
884 use ndarray::{Array1, ArrayBase, ArrayView1, DataMut, Ix1};
885
886 use super::StatsError;
887
888 #[derive(Debug, Clone, PartialEq)]
890 pub struct EwmaState<T> {
891 alpha: T,
892 value: Option<T>,
893 }
894
895 impl<T: NabledReal> EwmaState<T> {
896 pub fn new(alpha: T) -> Self { Self { alpha, value: None } }
897
898 pub fn push(&mut self, sample: T) -> T {
899 if let Some(prev) = self.value {
900 let next = self.alpha * sample + (T::one() - self.alpha) * prev;
901 self.value = Some(next);
902 next
903 } else {
904 self.value = Some(sample);
905 sample
906 }
907 }
908 }
909
910 #[must_use]
911 pub fn ewma<T: NabledReal>(signal: &Array1<T>, alpha: T) -> Array1<T> {
912 ewma_view(&signal.view(), alpha)
913 }
914
915 #[must_use]
916 pub fn ewma_view<T: NabledReal>(signal: &ArrayView1<'_, T>, alpha: T) -> Array1<T> {
917 let mut out = Array1::<T>::zeros(signal.len());
918 drop(ewma_into(signal, alpha, &mut out));
919 out
920 }
921
922 pub fn ewma_into<T, S>(
923 signal: &ArrayView1<'_, T>,
924 alpha: T,
925 output: &mut ArrayBase<S, Ix1>,
926 ) -> Result<(), StatsError>
927 where
928 T: NabledReal,
929 S: DataMut<Elem = T>,
930 {
931 if output.len() != signal.len() {
932 return Err(StatsError::InvalidInput("ewma output length mismatch".to_string()));
933 }
934 let mut state = EwmaState::new(alpha);
935 for (i, &sample) in signal.iter().enumerate() {
936 output[i] = state.push(sample);
937 }
938 Ok(())
939 }
940}
941
942pub mod rolling {
943 #![allow(clippy::missing_errors_doc)]
944 use nabled_core::scalar::NabledReal;
945 use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, DataMut, Ix1, Ix2};
946
947 use super::StatsError;
948
949 #[must_use]
950 pub fn rolling_mean<T: NabledReal>(signal: &ArrayView1<'_, T>, window: usize) -> Array1<T> {
951 let mut out = Array1::<T>::zeros(signal.len());
952 drop(rolling_mean_into(signal, window, &mut out));
953 out
954 }
955
956 pub fn rolling_mean_view<T: NabledReal>(
957 signal: &ArrayView1<'_, T>,
958 window: usize,
959 ) -> Array1<T> {
960 rolling_mean(signal, window)
961 }
962
963 pub fn rolling_mean_into<T, S>(
964 signal: &ArrayView1<'_, T>,
965 window: usize,
966 output: &mut ArrayBase<S, Ix1>,
967 ) -> Result<(), StatsError>
968 where
969 T: NabledReal,
970 S: DataMut<Elem = T>,
971 {
972 if window == 0 {
973 return Err(StatsError::InvalidInput("window must be positive".to_string()));
974 }
975 if output.len() != signal.len() {
976 return Err(StatsError::InvalidInput(
977 "rolling_mean output length mismatch".to_string(),
978 ));
979 }
980 for i in 0..signal.len() {
981 let start = i.saturating_sub(window - 1);
982 let slice = signal.slice(ndarray::s![start..=i]);
983 let sum = slice.iter().fold(T::zero(), |acc, v| acc + *v);
984 let count = T::from_usize(slice.len()).unwrap_or(T::one());
985 output[i] = sum / count;
986 }
987 Ok(())
988 }
989
990 #[must_use]
991 pub fn rolling_variance<T: NabledReal>(signal: &ArrayView1<'_, T>, window: usize) -> Array1<T> {
992 let mut out = Array1::<T>::zeros(signal.len());
993 drop(rolling_variance_into(signal, window, &mut out));
994 out
995 }
996
997 pub fn rolling_variance_into<T, S>(
998 signal: &ArrayView1<'_, T>,
999 window: usize,
1000 output: &mut ArrayBase<S, Ix1>,
1001 ) -> Result<(), StatsError>
1002 where
1003 T: NabledReal,
1004 S: DataMut<Elem = T>,
1005 {
1006 if window == 0 {
1007 return Err(StatsError::InvalidInput("window must be positive".to_string()));
1008 }
1009 if output.len() != signal.len() {
1010 return Err(StatsError::InvalidInput(
1011 "rolling_variance output length mismatch".to_string(),
1012 ));
1013 }
1014 for i in 0..signal.len() {
1015 let start = i.saturating_sub(window - 1);
1016 let slice = signal.slice(ndarray::s![start..=i]);
1017 let mean = slice.iter().fold(T::zero(), |acc, v| acc + *v)
1018 / T::from_usize(slice.len()).unwrap_or(T::one());
1019 let var = slice
1020 .iter()
1021 .map(|v| {
1022 let d = *v - mean;
1023 d * d
1024 })
1025 .fold(T::zero(), |acc, v| acc + v)
1026 / T::from_usize(slice.len()).unwrap_or(T::one());
1027 output[i] = var;
1028 }
1029 Ok(())
1030 }
1031
1032 #[must_use]
1033 pub fn rolling_covariance<T: NabledReal>(
1034 matrix: &ArrayView2<'_, T>,
1035 window: usize,
1036 ) -> Array2<T> {
1037 let mut out = Array2::<T>::zeros((matrix.nrows(), matrix.ncols() * matrix.ncols()));
1038 drop(rolling_covariance_into(matrix, window, &mut out));
1039 out
1040 }
1041
1042 pub fn rolling_covariance_view<T: NabledReal>(
1043 matrix: &ArrayView2<'_, T>,
1044 window: usize,
1045 ) -> Array2<T> {
1046 rolling_covariance(matrix, window)
1047 }
1048
1049 pub fn rolling_covariance_into<T, S>(
1050 matrix: &ArrayView2<'_, T>,
1051 window: usize,
1052 output: &mut ArrayBase<S, Ix2>,
1053 ) -> Result<(), StatsError>
1054 where
1055 T: NabledReal,
1056 S: DataMut<Elem = T>,
1057 {
1058 let cols = matrix.ncols();
1059 if output.nrows() != matrix.nrows() || output.ncols() != cols * cols {
1060 return Err(StatsError::InvalidInput(
1061 "rolling_covariance output shape mismatch".to_string(),
1062 ));
1063 }
1064 if window == 0 {
1065 return Err(StatsError::InvalidInput("window must be positive".to_string()));
1066 }
1067 for row in 0..matrix.nrows() {
1068 let start = row.saturating_sub(window - 1);
1069 let block = matrix.slice(ndarray::s![start..=row, ..]);
1070 let cov = if block.nrows() >= 2 {
1071 super::covariance_matrix_view(&block)?
1072 } else {
1073 Array2::<T>::zeros((cols, cols))
1074 };
1075 for i in 0..cols {
1076 for j in 0..cols {
1077 output[[row, i * cols + j]] = cov[[i, j]];
1078 }
1079 }
1080 }
1081 Ok(())
1082 }
1083}
1084
1085pub mod lag {
1086 #![allow(clippy::missing_errors_doc)]
1087 use nabled_core::scalar::NabledReal;
1088 use ndarray::{Array2, ArrayView2};
1089
1090 use super::StatsError;
1091
1092 pub fn lag_view<'a, T>(
1094 matrix: &'a ArrayView2<'_, T>,
1095 lag: usize,
1096 ) -> Result<ArrayView2<'a, T>, StatsError> {
1097 if lag >= matrix.nrows() {
1098 return Err(StatsError::InvalidInput(format!(
1099 "lag {lag} must be less than row count {}",
1100 matrix.nrows()
1101 )));
1102 }
1103 Ok(matrix.slice(ndarray::s![..matrix.nrows() - lag, ..]))
1104 }
1105
1106 pub fn shift_columns_into<T: NabledReal>(
1108 matrix: &ArrayView2<'_, T>,
1109 lag: usize,
1110 output: &mut Array2<T>,
1111 ) -> Result<(), StatsError> {
1112 if output.dim() != matrix.dim() {
1113 return Err(StatsError::InvalidInput(
1114 "shift_columns_into output shape mismatch".to_string(),
1115 ));
1116 }
1117 output.fill(T::zero());
1118 if lag >= matrix.nrows() {
1119 return Ok(());
1120 }
1121 let rows = matrix.nrows() - lag;
1122 output.slice_mut(ndarray::s![lag.., ..]).assign(&matrix.slice(ndarray::s![..rows, ..]));
1123 Ok(())
1124 }
1125}
1126
1127#[cfg(test)]
1128mod tests {
1129 use ndarray::{Array1, Array2};
1130 use num_complex::Complex64;
1131
1132 use super::*;
1133
1134 #[test]
1135 fn covariance_and_correlation_are_well_formed() {
1136 let matrix =
1137 Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
1138 .unwrap();
1139 let covariance = covariance_matrix(&matrix).unwrap();
1140 let correlation = correlation_matrix(&matrix).unwrap();
1141 assert_eq!(covariance.dim(), (2, 2));
1142 assert_eq!(correlation.dim(), (2, 2));
1143 }
1144
1145 #[test]
1146 fn stats_rejects_empty_and_insufficient_inputs() {
1147 let empty = Array2::<f64>::zeros((0, 0));
1148 assert!(matches!(covariance_matrix(&empty), Err(StatsError::EmptyMatrix)));
1149
1150 let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
1151 assert!(matches!(covariance_matrix(&one_row), Err(StatsError::InsufficientSamples)));
1152 }
1153
1154 #[test]
1155 fn center_columns_zeroes_means() {
1156 let matrix =
1157 Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
1158 let centered = center_columns(&matrix);
1159 let means = column_means(¢ered);
1160 assert!(means.iter().all(|value| num_traits::Float::abs(*value) < 1e-12));
1161 }
1162
1163 #[test]
1164 fn column_means_handles_empty_input() {
1165 let matrix = Array2::<f64>::zeros((0, 3));
1166 let means = column_means(&matrix);
1167 assert_eq!(means.len(), 3);
1168 assert!(means.iter().all(|value| *value == 0.0));
1169 }
1170
1171 #[test]
1172 fn covariance_reports_numerical_instability() {
1173 let matrix = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
1174 let result = covariance_matrix(&matrix);
1175 assert!(matches!(result, Err(StatsError::NumericalInstability)));
1176 }
1177
1178 #[test]
1179 fn correlation_handles_zero_variance_column() {
1180 let matrix =
1181 Array2::from_shape_vec((3, 2), vec![1.0_f64, 10.0, 1.0, 20.0, 1.0, 30.0]).unwrap();
1182 let correlation = correlation_matrix(&matrix).unwrap();
1183 assert!(correlation[[0, 0]].is_finite());
1184 assert!(correlation[[0, 1]].is_finite());
1185 assert!(correlation[[1, 0]].is_finite());
1186 assert!(correlation[[1, 1]].is_finite());
1187 }
1188
1189 #[test]
1190 fn view_variants_match_owned() {
1191 let matrix =
1192 Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
1193 .unwrap();
1194 let means_owned = column_means(&matrix);
1195 let means_view = column_means_view(&matrix.view());
1196 let centered_owned = center_columns(&matrix);
1197 let centered_view = center_columns_view(&matrix.view());
1198 let covariance_owned = covariance_matrix(&matrix).unwrap();
1199 let covariance_view = covariance_matrix_view(&matrix.view()).unwrap();
1200 let correlation_owned = correlation_matrix(&matrix).unwrap();
1201 let correlation_view = correlation_matrix_view(&matrix.view()).unwrap();
1202
1203 for i in 0..means_owned.len() {
1204 assert!((means_owned[i] - means_view[i]).abs() < 1e-12);
1205 }
1206 for i in 0..matrix.nrows() {
1207 for j in 0..matrix.ncols() {
1208 assert!((centered_owned[[i, j]] - centered_view[[i, j]]).abs() < 1e-12);
1209 }
1210 }
1211 for i in 0..2 {
1212 for j in 0..2 {
1213 assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).abs() < 1e-12);
1214 assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).abs() < 1e-12);
1215 }
1216 }
1217 }
1218
1219 #[test]
1220 fn complex_covariance_and_correlation_are_well_formed() {
1221 let matrix = Array2::from_shape_vec((4, 2), vec![
1222 Complex64::new(1.0, 0.0),
1223 Complex64::new(3.0, 1.0),
1224 Complex64::new(2.0, -1.0),
1225 Complex64::new(2.0, 0.5),
1226 Complex64::new(3.0, 0.2),
1227 Complex64::new(1.0, -0.3),
1228 Complex64::new(4.0, 0.7),
1229 Complex64::new(0.0, 0.0),
1230 ])
1231 .unwrap();
1232
1233 let covariance = covariance_matrix_complex(&matrix).unwrap();
1234 let correlation = correlation_matrix_complex(&matrix).unwrap();
1235 assert_eq!(covariance.dim(), (2, 2));
1236 assert_eq!(correlation.dim(), (2, 2));
1237 }
1238
1239 #[test]
1240 fn complex_view_variants_match_owned() {
1241 let matrix = Array2::from_shape_vec((3, 2), vec![
1242 Complex64::new(1.0, 1.0),
1243 Complex64::new(2.0, -1.0),
1244 Complex64::new(2.0, 2.0),
1245 Complex64::new(3.0, 0.0),
1246 Complex64::new(3.0, -2.0),
1247 Complex64::new(4.0, 1.0),
1248 ])
1249 .unwrap();
1250
1251 let means_owned = column_means_complex(&matrix);
1252 let means_view = column_means_complex_view(&matrix.view());
1253 let centered_owned = center_columns_complex(&matrix);
1254 let centered_view = center_columns_complex_view(&matrix.view());
1255 let covariance_owned = covariance_matrix_complex(&matrix).unwrap();
1256 let covariance_view = covariance_matrix_complex_view(&matrix.view()).unwrap();
1257 let correlation_owned = correlation_matrix_complex(&matrix).unwrap();
1258 let correlation_view = correlation_matrix_complex_view(&matrix.view()).unwrap();
1259
1260 for i in 0..means_owned.len() {
1261 assert!((means_owned[i] - means_view[i]).norm() < 1e-12);
1262 }
1263 for i in 0..matrix.nrows() {
1264 for j in 0..matrix.ncols() {
1265 assert!((centered_owned[[i, j]] - centered_view[[i, j]]).norm() < 1e-12);
1266 }
1267 }
1268 for i in 0..2 {
1269 for j in 0..2 {
1270 assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).norm() < 1e-12);
1271 assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).norm() < 1e-12);
1272 }
1273 }
1274 }
1275
1276 #[test]
1277 fn stats_view_into_reuses_outputs() {
1278 let matrix =
1279 Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
1280 .unwrap();
1281
1282 let mut means = Array1::<f64>::zeros(2);
1283 let mut centered = Array2::<f64>::zeros((4, 2));
1284 let mut covariance = Array2::<f64>::zeros((2, 2));
1285 let mut correlation = Array2::<f64>::zeros((2, 2));
1286
1287 column_means_view_into(&matrix.view(), &mut means).unwrap();
1288 center_columns_view_into(&matrix.view(), &mut centered).unwrap();
1289 covariance_matrix_view_into(&matrix.view(), &mut covariance).unwrap();
1290 correlation_matrix_view_into(&matrix.view(), &mut correlation).unwrap();
1291
1292 assert_eq!(means, column_means(&matrix));
1293 assert_eq!(centered, center_columns(&matrix));
1294 assert_eq!(covariance, covariance_matrix(&matrix).unwrap());
1295 assert_eq!(correlation, correlation_matrix(&matrix).unwrap());
1296 }
1297
1298 #[test]
1299 fn complex_stats_view_into_reuses_outputs() {
1300 let matrix = Array2::from_shape_vec((3, 2), vec![
1301 Complex64::new(1.0, 1.0),
1302 Complex64::new(2.0, -1.0),
1303 Complex64::new(2.0, 2.0),
1304 Complex64::new(3.0, 0.0),
1305 Complex64::new(3.0, -2.0),
1306 Complex64::new(4.0, 1.0),
1307 ])
1308 .unwrap();
1309
1310 let mut means = Array1::<Complex64>::zeros(2);
1311 let mut centered = Array2::<Complex64>::zeros((3, 2));
1312 let mut covariance = Array2::<Complex64>::zeros((2, 2));
1313 let mut correlation = Array2::<Complex64>::zeros((2, 2));
1314
1315 column_means_complex_view_into(&matrix.view(), &mut means).unwrap();
1316 center_columns_complex_view_into(&matrix.view(), &mut centered).unwrap();
1317 covariance_matrix_complex_view_into(&matrix.view(), &mut covariance).unwrap();
1318 correlation_matrix_complex_view_into(&matrix.view(), &mut correlation).unwrap();
1319
1320 assert_eq!(means, column_means_complex(&matrix));
1321 assert_eq!(centered, center_columns_complex(&matrix));
1322 assert_eq!(covariance, covariance_matrix_complex(&matrix).unwrap());
1323 assert_eq!(correlation, correlation_matrix_complex(&matrix).unwrap());
1324 }
1325
1326 #[test]
1327 fn stats_owned_into_paths_cover_empty_valid_and_error_cases() {
1328 assert_eq!(
1329 StatsError::InvalidInput("bad shape".to_string()).to_string(),
1330 "Invalid input: bad shape"
1331 );
1332
1333 let matrix =
1334 Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
1335 .unwrap();
1336 let mut means = Array1::<f64>::zeros(2);
1337 let mut centered = Array2::<f64>::zeros((4, 2));
1338 let mut covariance = Array2::<f64>::zeros((2, 2));
1339 let mut correlation = Array2::<f64>::zeros((2, 2));
1340
1341 column_means_into(&matrix, &mut means).unwrap();
1342 center_columns_into(&matrix, &mut centered).unwrap();
1343 covariance_matrix_into(&matrix, &mut covariance).unwrap();
1344 correlation_matrix_into(&matrix, &mut correlation).unwrap();
1345
1346 assert_eq!(means, column_means(&matrix));
1347 assert_eq!(centered, center_columns(&matrix));
1348 assert_eq!(covariance, covariance_matrix(&matrix).unwrap());
1349 assert_eq!(correlation, correlation_matrix(&matrix).unwrap());
1350
1351 let empty_columns = Array2::<f64>::zeros((0, 3));
1352 let mut empty_means = Array1::<f64>::from_vec(vec![1.0, 1.0, 1.0]);
1353 column_means_into(&empty_columns, &mut empty_means).unwrap();
1354 assert!(empty_means.iter().all(|value| *value == 0.0));
1355
1356 let mut bad_means = Array1::<f64>::zeros(3);
1357 assert!(matches!(
1358 column_means_into(&matrix, &mut bad_means),
1359 Err(StatsError::InvalidInput(_))
1360 ));
1361
1362 let mut bad_centered = Array2::<f64>::zeros((4, 3));
1363 assert!(matches!(
1364 center_columns_into(&matrix, &mut bad_centered),
1365 Err(StatsError::InvalidInput(_))
1366 ));
1367
1368 let empty = Array2::<f64>::zeros((0, 0));
1369 let mut empty_covariance = Array2::<f64>::zeros((0, 0));
1370 assert!(matches!(
1371 covariance_matrix_into(&empty, &mut empty_covariance),
1372 Err(StatsError::EmptyMatrix)
1373 ));
1374
1375 let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
1376 let mut one_row_covariance = Array2::<f64>::zeros((2, 2));
1377 assert!(matches!(
1378 covariance_matrix_into(&one_row, &mut one_row_covariance),
1379 Err(StatsError::InsufficientSamples)
1380 ));
1381
1382 let mut bad_covariance = Array2::<f64>::zeros((3, 3));
1383 assert!(matches!(
1384 covariance_matrix_into(&matrix, &mut bad_covariance),
1385 Err(StatsError::InvalidInput(_))
1386 ));
1387 assert!(matches!(
1388 correlation_matrix_into(&matrix, &mut bad_covariance),
1389 Err(StatsError::InvalidInput(_))
1390 ));
1391
1392 let unstable = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
1393 let mut unstable_covariance = Array2::<f64>::zeros((2, 2));
1394 assert!(matches!(
1395 covariance_matrix_into(&unstable, &mut unstable_covariance),
1396 Err(StatsError::NumericalInstability)
1397 ));
1398 }
1399
1400 #[test]
1401 fn complex_stats_owned_into_paths_cover_empty_valid_and_error_cases() {
1402 let matrix = Array2::from_shape_vec((3, 2), vec![
1403 Complex64::new(1.0, 1.0),
1404 Complex64::new(2.0, -1.0),
1405 Complex64::new(2.0, 2.0),
1406 Complex64::new(3.0, 0.0),
1407 Complex64::new(3.0, -2.0),
1408 Complex64::new(4.0, 1.0),
1409 ])
1410 .unwrap();
1411 let mut means = Array1::<Complex64>::zeros(2);
1412 let mut centered = Array2::<Complex64>::zeros((3, 2));
1413 let mut covariance = Array2::<Complex64>::zeros((2, 2));
1414 let mut correlation = Array2::<Complex64>::zeros((2, 2));
1415
1416 column_means_complex_into(&matrix, &mut means).unwrap();
1417 center_columns_complex_into(&matrix, &mut centered).unwrap();
1418 covariance_matrix_complex_into(&matrix, &mut covariance).unwrap();
1419 correlation_matrix_complex_into(&matrix, &mut correlation).unwrap();
1420
1421 assert_eq!(means, column_means_complex(&matrix));
1422 assert_eq!(centered, center_columns_complex(&matrix));
1423 assert_eq!(covariance, covariance_matrix_complex(&matrix).unwrap());
1424 assert_eq!(correlation, correlation_matrix_complex(&matrix).unwrap());
1425
1426 let empty_columns = Array2::<Complex64>::zeros((0, 3));
1427 let mut empty_means = Array1::<Complex64>::from_vec(vec![
1428 Complex64::new(1.0, 1.0),
1429 Complex64::new(1.0, 1.0),
1430 Complex64::new(1.0, 1.0),
1431 ]);
1432 column_means_complex_into(&empty_columns, &mut empty_means).unwrap();
1433 assert!(empty_means.iter().all(|value| *value == Complex64::new(0.0, 0.0)));
1434
1435 let mut bad_means = Array1::<Complex64>::zeros(3);
1436 assert!(matches!(
1437 column_means_complex_into(&matrix, &mut bad_means),
1438 Err(StatsError::InvalidInput(_))
1439 ));
1440
1441 let mut bad_centered = Array2::<Complex64>::zeros((3, 3));
1442 assert!(matches!(
1443 center_columns_complex_into(&matrix, &mut bad_centered),
1444 Err(StatsError::InvalidInput(_))
1445 ));
1446
1447 let empty = Array2::<Complex64>::zeros((0, 0));
1448 let mut empty_covariance = Array2::<Complex64>::zeros((0, 0));
1449 assert!(matches!(
1450 covariance_matrix_complex_into(&empty, &mut empty_covariance),
1451 Err(StatsError::EmptyMatrix)
1452 ));
1453
1454 let one_row = Array2::from_shape_vec((1, 2), vec![
1455 Complex64::new(1.0, 0.0),
1456 Complex64::new(2.0, 0.0),
1457 ])
1458 .unwrap();
1459 let mut one_row_covariance = Array2::<Complex64>::zeros((2, 2));
1460 assert!(matches!(
1461 covariance_matrix_complex_into(&one_row, &mut one_row_covariance),
1462 Err(StatsError::InsufficientSamples)
1463 ));
1464
1465 let mut bad_covariance = Array2::<Complex64>::zeros((3, 3));
1466 assert!(matches!(
1467 covariance_matrix_complex_into(&matrix, &mut bad_covariance),
1468 Err(StatsError::InvalidInput(_))
1469 ));
1470 assert!(matches!(
1471 correlation_matrix_complex_into(&matrix, &mut bad_covariance),
1472 Err(StatsError::InvalidInput(_))
1473 ));
1474
1475 let unstable = Array2::from_shape_vec((2, 2), vec![
1476 Complex64::new(f64::MAX, 0.0),
1477 Complex64::new(0.0, 0.0),
1478 Complex64::new(-f64::MAX, 0.0),
1479 Complex64::new(0.0, 0.0),
1480 ])
1481 .unwrap();
1482 let mut unstable_covariance = Array2::<Complex64>::zeros((2, 2));
1483 assert!(matches!(
1484 covariance_matrix_complex_into(&unstable, &mut unstable_covariance),
1485 Err(StatsError::NumericalInstability)
1486 ));
1487 }
1488
1489 #[test]
1490 fn stats_view_into_rejects_wrong_output_shapes() {
1491 let matrix =
1492 Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0])
1493 .unwrap();
1494 let mut bad_means = Array1::<f64>::zeros(3);
1495 let mut bad_centered = Array2::<f64>::zeros((4, 3));
1496 let mut bad_covariance = Array2::<f64>::zeros((3, 3));
1497
1498 assert!(matches!(
1499 column_means_view_into(&matrix.view(), &mut bad_means),
1500 Err(StatsError::InvalidInput(_))
1501 ));
1502 assert!(matches!(
1503 center_columns_view_into(&matrix.view(), &mut bad_centered),
1504 Err(StatsError::InvalidInput(_))
1505 ));
1506 assert!(matches!(
1507 covariance_matrix_view_into(&matrix.view(), &mut bad_covariance),
1508 Err(StatsError::InvalidInput(_))
1509 ));
1510 }
1511
1512 fn approx_eq(a: f64, b: f64) -> bool { num_traits::Float::abs(a - b) < 1e-10 }
1513
1514 #[test]
1515 fn online_mean_and_variance_track_batch_statistics() {
1516 let mut mean = online::OnlineMean::<f64>::default();
1517 let mut variance = online::OnlineVariance::<f64>::default();
1518 for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
1519 mean.push(value);
1520 variance.push(value);
1521 }
1522 assert!(approx_eq(mean.mean(), 3.0));
1523 assert!(approx_eq(variance.mean(), 3.0));
1524 assert!(approx_eq(variance.variance(), 2.5));
1525 }
1526
1527 #[test]
1528 fn online_accumulators_reset_to_initial_state() {
1529 let mut mean = online::OnlineMean::<f64>::default();
1530 let mut variance = online::OnlineVariance::<f64>::default();
1531 mean.push(4.0);
1532 variance.push(4.0);
1533 mean.reset();
1534 variance.reset();
1535 assert!(approx_eq(mean.mean(), 0.0));
1536 assert!(approx_eq(variance.mean(), 0.0));
1537 assert!(approx_eq(variance.variance(), 0.0));
1538 }
1539
1540 #[test]
1541 fn online_variance_is_zero_with_fewer_than_two_samples() {
1542 let mut variance = online::OnlineVariance::<f64>::default();
1543 variance.push(2.0);
1544 assert!(approx_eq(variance.variance(), 0.0));
1545 }
1546
1547 #[test]
1548 fn ewma_state_and_vector_apis_match_expected_smoothing() {
1549 let signal = Array1::from_vec(vec![10.0_f64, 0.0]);
1550 let alpha = 0.5;
1551 let mut state = ewma::EwmaState::new(alpha);
1552 assert!(approx_eq(state.push(10.0), 10.0));
1553 assert!(approx_eq(state.push(0.0), 5.0));
1554
1555 let owned = ewma::ewma(&signal, alpha);
1556 let viewed = ewma::ewma_view(&signal.view(), alpha);
1557 assert!(approx_eq(owned[0], 10.0));
1558 assert!(approx_eq(owned[1], 5.0));
1559 assert_eq!(owned, viewed);
1560
1561 let mut output = Array1::<f64>::zeros(2);
1562 ewma::ewma_into(&signal.view(), alpha, &mut output).unwrap();
1563 assert_eq!(output, owned);
1564
1565 let mut bad_output = Array1::<f64>::zeros(3);
1566 assert!(matches!(
1567 ewma::ewma_into(&signal.view(), alpha, &mut bad_output),
1568 Err(StatsError::InvalidInput(_))
1569 ));
1570 }
1571
1572 #[test]
1573 fn rolling_mean_and_variance_match_manual_windows() {
1574 let signal = Array1::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]);
1575 let window = 3;
1576 let mean = rolling::rolling_mean(&signal.view(), window);
1577 let mean_view = rolling::rolling_mean_view(&signal.view(), window);
1578 assert_eq!(mean, mean_view);
1579 assert!(approx_eq(mean[0], 1.0));
1580 assert!(approx_eq(mean[2], 2.0));
1581 assert!(approx_eq(mean[4], 4.0));
1582
1583 let variance = rolling::rolling_variance(&signal.view(), window);
1584 assert!(approx_eq(variance[2], 2.0 / 3.0));
1585
1586 let mut mean_out = Array1::<f64>::zeros(signal.len());
1587 rolling::rolling_mean_into(&signal.view(), window, &mut mean_out).unwrap();
1588 assert_eq!(mean_out, mean);
1589
1590 let mut variance_out = Array1::<f64>::zeros(signal.len());
1591 rolling::rolling_variance_into(&signal.view(), window, &mut variance_out).unwrap();
1592 assert_eq!(variance_out, variance);
1593
1594 assert!(matches!(
1595 rolling::rolling_mean_into(&signal.view(), 0, &mut mean_out),
1596 Err(StatsError::InvalidInput(_))
1597 ));
1598 let mut bad_mean_out = Array1::<f64>::zeros(signal.len() + 1);
1599 assert!(matches!(
1600 rolling::rolling_mean_into(&signal.view(), window, &mut bad_mean_out),
1601 Err(StatsError::InvalidInput(_))
1602 ));
1603 assert!(matches!(
1604 rolling::rolling_variance_into(&signal.view(), 0, &mut variance_out),
1605 Err(StatsError::InvalidInput(_))
1606 ));
1607 }
1608
1609 #[test]
1610 fn rolling_covariance_into_handles_valid_and_invalid_inputs() {
1611 let matrix =
1612 Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
1613 let window = 2;
1614 let cov = rolling::rolling_covariance(&matrix.view(), window);
1615 let cov_view = rolling::rolling_covariance_view(&matrix.view(), window);
1616 assert_eq!(cov, cov_view);
1617 assert!(approx_eq(cov[[0, 0]], 0.0));
1618
1619 let mut output = Array2::<f64>::zeros((3, 4));
1620 rolling::rolling_covariance_into(&matrix.view(), window, &mut output).unwrap();
1621 assert_eq!(output, cov);
1622
1623 let mut bad_shape = Array2::<f64>::zeros((3, 3));
1624 assert!(matches!(
1625 rolling::rolling_covariance_into(&matrix.view(), window, &mut bad_shape),
1626 Err(StatsError::InvalidInput(_))
1627 ));
1628 assert!(matches!(
1629 rolling::rolling_covariance_into(&matrix.view(), 0, &mut output),
1630 Err(StatsError::InvalidInput(_))
1631 ));
1632 }
1633
1634 #[test]
1635 fn lag_view_and_shift_columns_into_cover_alignment_paths() {
1636 let matrix =
1637 Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1638 let view = matrix.view();
1639 let lagged = lag::lag_view(&view, 1).unwrap();
1640 assert_eq!(lagged.dim(), (2, 2));
1641 assert!(approx_eq(lagged[[0, 0]], 1.0));
1642 assert!(approx_eq(lagged[[1, 1]], 4.0));
1643
1644 let mut shifted = Array2::<f64>::zeros((3, 2));
1645 lag::shift_columns_into(&view, 1, &mut shifted).unwrap();
1646 assert!(approx_eq(shifted[[0, 0]], 0.0));
1647 assert!(approx_eq(shifted[[1, 0]], 1.0));
1648 assert!(approx_eq(shifted[[2, 1]], 4.0));
1649
1650 assert!(matches!(lag::lag_view(&view, matrix.nrows()), Err(StatsError::InvalidInput(_))));
1651
1652 let mut bad_output = Array2::<f64>::zeros((2, 2));
1653 assert!(matches!(
1654 lag::shift_columns_into(&view, 1, &mut bad_output),
1655 Err(StatsError::InvalidInput(_))
1656 ));
1657
1658 let mut zeroed = Array2::<f64>::ones((3, 2));
1659 lag::shift_columns_into(&view, matrix.nrows(), &mut zeroed).unwrap();
1660 assert!(zeroed.iter().all(|value| approx_eq(*value, 0.0)));
1661 }
1662}