1use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use ndarray::{Array1, Array2, ArrayView2, Axis};
7use num_complex::Complex64;
8
/// Errors reported by the statistics routines in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatsError {
    /// The input matrix has no elements (zero rows or zero columns).
    EmptyMatrix,
    /// Fewer than two observations (rows) were supplied, so the `n - 1`
    /// sample normalisation cannot be applied.
    InsufficientSamples,
    /// A computed result contained a non-finite value (NaN or infinity).
    NumericalInstability,
}
19
20impl fmt::Display for StatsError {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 match self {
23 StatsError::EmptyMatrix => write!(f, "Matrix cannot be empty"),
24 StatsError::InsufficientSamples => {
25 write!(f, "At least two observations are required")
26 }
27 StatsError::NumericalInstability => write!(f, "Numerical instability detected"),
28 }
29 }
30}
31
32impl std::error::Error for StatsError {}
33
34fn usize_to_scalar<T: NabledReal>(value: usize) -> T {
35 T::from_usize(value).unwrap_or(T::max_value())
36}
37
38fn complex_is_finite(value: Complex64) -> bool { value.re.is_finite() && value.im.is_finite() }
39
40fn column_means_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
41 matrix.mean_axis(Axis(0)).unwrap_or_else(|| Array1::zeros(matrix.ncols()))
42}
43
44#[must_use]
46pub fn column_means<T: NabledReal>(matrix: &Array2<T>) -> Array1<T> {
47 column_means_impl(&matrix.view())
48}
49
50#[must_use]
52pub fn column_means_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array1<T> {
53 column_means_impl(matrix)
54}
55
56fn center_columns_impl<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
57 let means = column_means_impl(matrix);
58 let mut centered = Array2::<T>::zeros((matrix.nrows(), matrix.ncols()));
59 for row in 0..matrix.nrows() {
60 for col in 0..matrix.ncols() {
61 centered[[row, col]] = matrix[[row, col]] - means[col];
62 }
63 }
64 centered
65}
66
67#[must_use]
69pub fn center_columns<T: NabledReal>(matrix: &Array2<T>) -> Array2<T> {
70 center_columns_impl(&matrix.view())
71}
72
73#[must_use]
75pub fn center_columns_view<T: NabledReal>(matrix: &ArrayView2<'_, T>) -> Array2<T> {
76 center_columns_impl(matrix)
77}
78
79fn covariance_matrix_impl<T: NabledReal>(
80 matrix: &ArrayView2<'_, T>,
81) -> Result<Array2<T>, StatsError> {
82 if matrix.is_empty() {
83 return Err(StatsError::EmptyMatrix);
84 }
85 if matrix.nrows() < 2 {
86 return Err(StatsError::InsufficientSamples);
87 }
88
89 let centered = center_columns_impl(matrix);
90 let covariance: Array2<T> =
91 centered.t().dot(¢ered) / usize_to_scalar::<T>(matrix.nrows() - 1);
92
93 if covariance.iter().any(|value| !value.is_finite()) {
94 return Err(StatsError::NumericalInstability);
95 }
96
97 Ok(covariance)
98}
99
100pub fn covariance_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
105 covariance_matrix_impl(&matrix.view())
106}
107
108pub fn covariance_matrix_view<T: NabledReal>(
113 matrix: &ArrayView2<'_, T>,
114) -> Result<Array2<T>, StatsError> {
115 covariance_matrix_impl(matrix)
116}
117
118fn correlation_matrix_impl<T: NabledReal>(
119 matrix: &ArrayView2<'_, T>,
120) -> Result<Array2<T>, StatsError> {
121 let covariance = covariance_matrix_impl(matrix)?;
122 let n = covariance.nrows();
123 let mut correlation = Array2::<T>::zeros((n, n));
124
125 for i in 0..n {
126 let sigma_i = covariance[[i, i]].sqrt();
127 for j in 0..n {
128 let sigma_j = covariance[[j, j]].sqrt();
129 let denom = (sigma_i * sigma_j).max(T::epsilon());
130 correlation[[i, j]] = covariance[[i, j]] / denom;
131 }
132 }
133
134 Ok(correlation)
135}
136
137pub fn correlation_matrix<T: NabledReal>(matrix: &Array2<T>) -> Result<Array2<T>, StatsError> {
142 correlation_matrix_impl(&matrix.view())
143}
144
145pub fn correlation_matrix_view<T: NabledReal>(
150 matrix: &ArrayView2<'_, T>,
151) -> Result<Array2<T>, StatsError> {
152 correlation_matrix_impl(matrix)
153}
154
155fn column_means_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
156 if matrix.nrows() == 0 {
157 return Array1::zeros(matrix.ncols());
158 }
159
160 let mut means = Array1::<Complex64>::zeros(matrix.ncols());
161 for col in 0..matrix.ncols() {
162 let mut sum = Complex64::new(0.0, 0.0);
163 for row in 0..matrix.nrows() {
164 sum += matrix[[row, col]];
165 }
166 means[col] = sum / usize_to_scalar::<f64>(matrix.nrows());
167 }
168 means
169}
170
171#[must_use]
173pub fn column_means_complex(matrix: &Array2<Complex64>) -> Array1<Complex64> {
174 column_means_complex_impl(&matrix.view())
175}
176
177#[must_use]
179pub fn column_means_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array1<Complex64> {
180 column_means_complex_impl(matrix)
181}
182
183fn center_columns_complex_impl(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
184 let means = column_means_complex_impl(matrix);
185 let mut centered = Array2::<Complex64>::zeros((matrix.nrows(), matrix.ncols()));
186 for row in 0..matrix.nrows() {
187 for col in 0..matrix.ncols() {
188 centered[[row, col]] = matrix[[row, col]] - means[col];
189 }
190 }
191 centered
192}
193
194#[must_use]
196pub fn center_columns_complex(matrix: &Array2<Complex64>) -> Array2<Complex64> {
197 center_columns_complex_impl(&matrix.view())
198}
199
200#[must_use]
202pub fn center_columns_complex_view(matrix: &ArrayView2<'_, Complex64>) -> Array2<Complex64> {
203 center_columns_complex_impl(matrix)
204}
205
206fn covariance_matrix_complex_impl(
207 matrix: &ArrayView2<'_, Complex64>,
208) -> Result<Array2<Complex64>, StatsError> {
209 if matrix.is_empty() {
210 return Err(StatsError::EmptyMatrix);
211 }
212 if matrix.nrows() < 2 {
213 return Err(StatsError::InsufficientSamples);
214 }
215
216 let centered = center_columns_complex_impl(matrix);
217 let conjugate_transpose = centered.t().mapv(|value| value.conj());
218 let covariance: Array2<Complex64> =
219 conjugate_transpose.dot(¢ered) / usize_to_scalar::<f64>(matrix.nrows() - 1);
220
221 if covariance.iter().any(|value| !complex_is_finite(*value)) {
222 return Err(StatsError::NumericalInstability);
223 }
224
225 Ok(covariance)
226}
227
228pub fn covariance_matrix_complex(
233 matrix: &Array2<Complex64>,
234) -> Result<Array2<Complex64>, StatsError> {
235 covariance_matrix_complex_impl(&matrix.view())
236}
237
238pub fn covariance_matrix_complex_view(
243 matrix: &ArrayView2<'_, Complex64>,
244) -> Result<Array2<Complex64>, StatsError> {
245 covariance_matrix_complex_impl(matrix)
246}
247
248fn correlation_matrix_complex_impl(
249 matrix: &ArrayView2<'_, Complex64>,
250) -> Result<Array2<Complex64>, StatsError> {
251 let covariance = covariance_matrix_complex_impl(matrix)?;
252 let n = covariance.nrows();
253 let mut correlation = Array2::<Complex64>::zeros((n, n));
254
255 for i in 0..n {
256 let sigma_i = covariance[[i, i]].re.max(0.0).sqrt();
257 for j in 0..n {
258 let sigma_j = covariance[[j, j]].re.max(0.0).sqrt();
259 let denom = (sigma_i * sigma_j).max(f64::EPSILON);
260 correlation[[i, j]] = covariance[[i, j]] / denom;
261 }
262 }
263
264 if correlation.iter().any(|value| !complex_is_finite(*value)) {
265 return Err(StatsError::NumericalInstability);
266 }
267
268 Ok(correlation)
269}
270
271pub fn correlation_matrix_complex(
276 matrix: &Array2<Complex64>,
277) -> Result<Array2<Complex64>, StatsError> {
278 correlation_matrix_complex_impl(&matrix.view())
279}
280
281pub fn correlation_matrix_complex_view(
286 matrix: &ArrayView2<'_, Complex64>,
287) -> Result<Array2<Complex64>, StatsError> {
288 correlation_matrix_complex_impl(matrix)
289}
290
#[cfg(test)]
mod tests {
    use ndarray::Array2;
    use num_complex::Complex64;

    use super::*;

    /// Asserts that two real matrices have the same shape and agree entry-wise
    /// within `tolerance`.
    fn assert_close(actual: &Array2<f64>, expected: &Array2<f64>, tolerance: f64) {
        assert_eq!(actual.dim(), expected.dim());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert!((got - want).abs() < tolerance, "expected {}, got {}", want, got);
        }
    }

    /// Four observations of two perfectly anti-correlated variables.
    fn anti_correlated() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0_f64, 3.0, 2.0, 2.0, 3.0, 1.0, 4.0, 0.0]).unwrap()
    }

    #[test]
    fn covariance_and_correlation_are_well_formed() {
        let matrix = anti_correlated();
        let covariance = covariance_matrix(&matrix).unwrap();
        let correlation = correlation_matrix(&matrix).unwrap();
        assert_eq!(covariance.dim(), (2, 2));
        assert_eq!(correlation.dim(), (2, 2));

        // Both columns have sample variance 5/3 and covary with the opposite sign.
        let variance = 5.0 / 3.0;
        let expected_covariance =
            Array2::from_shape_vec((2, 2), vec![variance, -variance, -variance, variance])
                .unwrap();
        let expected_correlation =
            Array2::from_shape_vec((2, 2), vec![1.0, -1.0, -1.0, 1.0]).unwrap();
        assert_close(&covariance, &expected_covariance, 1e-12);
        assert_close(&correlation, &expected_correlation, 1e-12);
    }

    #[test]
    fn stats_rejects_empty_and_insufficient_inputs() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert!(matches!(covariance_matrix(&empty), Err(StatsError::EmptyMatrix)));
        assert!(matches!(correlation_matrix(&empty), Err(StatsError::EmptyMatrix)));

        let no_columns = Array2::<f64>::zeros((3, 0));
        assert!(matches!(covariance_matrix(&no_columns), Err(StatsError::EmptyMatrix)));

        let one_row = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        assert!(matches!(covariance_matrix(&one_row), Err(StatsError::InsufficientSamples)));
        assert!(matches!(correlation_matrix(&one_row), Err(StatsError::InsufficientSamples)));

        let empty_complex = Array2::<Complex64>::zeros((0, 0));
        assert!(matches!(
            covariance_matrix_complex(&empty_complex),
            Err(StatsError::EmptyMatrix)
        ));
        let one_row_complex =
            Array2::from_shape_vec((1, 1), vec![Complex64::new(1.0, 1.0)]).unwrap();
        assert!(matches!(
            covariance_matrix_complex(&one_row_complex),
            Err(StatsError::InsufficientSamples)
        ));
    }

    #[test]
    fn center_columns_zeroes_means() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
        let centered = center_columns(&matrix);
        let means = column_means(&centered);
        assert!(means.iter().all(|value| num_traits::Float::abs(*value) < 1e-12));
    }

    #[test]
    fn column_means_handles_empty_input() {
        let matrix = Array2::<f64>::zeros((0, 3));
        let means = column_means(&matrix);
        assert_eq!(means.len(), 3);
        assert!(means.iter().all(|value| *value == 0.0));
    }

    #[test]
    fn covariance_reports_numerical_instability() {
        let matrix = Array2::from_shape_vec((2, 2), vec![f64::MAX, 0.0, -f64::MAX, 0.0]).unwrap();
        let result = covariance_matrix(&matrix);
        assert!(matches!(result, Err(StatsError::NumericalInstability)));
    }

    #[test]
    fn correlation_handles_zero_variance_column() {
        let matrix =
            Array2::from_shape_vec((3, 2), vec![1.0_f64, 10.0, 1.0, 20.0, 1.0, 30.0]).unwrap();
        let correlation = correlation_matrix(&matrix).unwrap();
        assert!(correlation.iter().all(|value| value.is_finite()));
        // The constant first column contributes zeros instead of NaN; the second
        // column is perfectly correlated with itself.
        let expected = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 0.0, 1.0]).unwrap();
        assert_close(&correlation, &expected, 1e-12);
    }

    #[test]
    fn view_variants_match_owned() {
        let matrix = anti_correlated();
        let means_owned = column_means(&matrix);
        let means_view = column_means_view(&matrix.view());
        let centered_owned = center_columns(&matrix);
        let centered_view = center_columns_view(&matrix.view());
        let covariance_owned = covariance_matrix(&matrix).unwrap();
        let covariance_view = covariance_matrix_view(&matrix.view()).unwrap();
        let correlation_owned = correlation_matrix(&matrix).unwrap();
        let correlation_view = correlation_matrix_view(&matrix.view()).unwrap();

        for i in 0..means_owned.len() {
            assert!((means_owned[i] - means_view[i]).abs() < 1e-12);
        }
        assert_close(&centered_owned, &centered_view, 1e-12);
        assert_close(&covariance_owned, &covariance_view, 1e-12);
        assert_close(&correlation_owned, &correlation_view, 1e-12);
    }

    #[test]
    fn view_variants_accept_column_major_views() {
        let matrix = anti_correlated();
        // Store the transpose in row-major order; viewing it transposed yields the
        // same logical 4x2 matrix with column-major strides.
        let transposed =
            Array2::from_shape_vec((2, 4), vec![1.0_f64, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0])
                .unwrap();
        let column_major = transposed.t();
        assert_eq!(column_major, matrix.view());

        assert_close(&center_columns_view(&column_major), &center_columns(&matrix), 1e-12);
        assert_close(
            &covariance_matrix_view(&column_major).unwrap(),
            &covariance_matrix(&matrix).unwrap(),
            1e-12,
        );
        assert_close(
            &correlation_matrix_view(&column_major).unwrap(),
            &correlation_matrix(&matrix).unwrap(),
            1e-12,
        );
    }

    #[test]
    fn complex_covariance_and_correlation_are_well_formed() {
        let matrix = Array2::from_shape_vec((4, 2), vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(3.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 0.5),
            Complex64::new(3.0, 0.2),
            Complex64::new(1.0, -0.3),
            Complex64::new(4.0, 0.7),
            Complex64::new(0.0, 0.0),
        ])
        .unwrap();

        let covariance = covariance_matrix_complex(&matrix).unwrap();
        let correlation = correlation_matrix_complex(&matrix).unwrap();
        assert_eq!(covariance.dim(), (2, 2));
        assert_eq!(correlation.dim(), (2, 2));

        // The covariance is Hermitian: the diagonal is real and the correlation
        // diagonal is one.
        assert!((covariance[[0, 1]] - covariance[[1, 0]].conj()).norm() < 1e-12);
        for i in 0..2 {
            assert!(covariance[[i, i]].im.abs() < 1e-12);
            assert!((correlation[[i, i]] - Complex64::new(1.0, 0.0)).norm() < 1e-12);
        }
    }

    #[test]
    fn complex_view_variants_match_owned() {
        let matrix = Array2::from_shape_vec((3, 2), vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(2.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(3.0, -2.0),
            Complex64::new(4.0, 1.0),
        ])
        .unwrap();

        let means_owned = column_means_complex(&matrix);
        let means_view = column_means_complex_view(&matrix.view());
        let centered_owned = center_columns_complex(&matrix);
        let centered_view = center_columns_complex_view(&matrix.view());
        let covariance_owned = covariance_matrix_complex(&matrix).unwrap();
        let covariance_view = covariance_matrix_complex_view(&matrix.view()).unwrap();
        let correlation_owned = correlation_matrix_complex(&matrix).unwrap();
        let correlation_view = correlation_matrix_complex_view(&matrix.view()).unwrap();

        for i in 0..means_owned.len() {
            assert!((means_owned[i] - means_view[i]).norm() < 1e-12);
        }
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                assert!((centered_owned[[i, j]] - centered_view[[i, j]]).norm() < 1e-12);
            }
        }
        for i in 0..2 {
            for j in 0..2 {
                assert!((covariance_owned[[i, j]] - covariance_view[[i, j]]).norm() < 1e-12);
                assert!((correlation_owned[[i, j]] - correlation_view[[i, j]]).norm() < 1e-12);
            }
        }
    }
}