Skip to main content

scirs2_stats/mstats/
mod.rs

1//! Masked array statistics
2//!
3//! This module provides statistical functions that work with masked arrays,
4//! following SciPy's `stats.mstats` module.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
9/// Masked array structure
10///
11/// Represents an array with associated mask indicating which values are valid/invalid.
12#[derive(Debug, Clone)]
13pub struct MaskedArray<T> {
14    /// The data array
15    pub data: Array1<T>,
16    /// The mask array (true = valid, false = masked/invalid)
17    pub mask: Array1<bool>,
18}
19
20impl<T: Copy> MaskedArray<T> {
21    /// Create a new masked array
22    pub fn new(data: Array1<T>, mask: Array1<bool>) -> StatsResult<Self> {
23        if data.len() != mask.len() {
24            return Err(StatsError::DimensionMismatch(
25                "Data and mask arrays must have the same length".to_string(),
26            ));
27        }
28
29        Ok(Self { data, mask })
30    }
31
32    /// Create a masked array with all values unmasked (valid)
33    pub fn fromdata(data: Array1<T>) -> Self {
34        let mask = Array1::from_elem(data.len(), true);
35        Self { data, mask }
36    }
37
38    /// Get the valid (unmasked) values
39    pub fn valid_values(&self) -> Vec<T> {
40        self.data
41            .iter()
42            .zip(self.mask.iter())
43            .filter_map(|(&value, &is_valid)| if is_valid { Some(value) } else { None })
44            .collect()
45    }
46
47    /// Count the number of valid values
48    pub fn count_valid(&self) -> usize {
49        self.mask.iter().filter(|&&is_valid| is_valid).count()
50    }
51
52    /// Check if the array has any valid values
53    pub fn has_valid_values(&self) -> bool {
54        self.count_valid() > 0
55    }
56}
57
58/// Masked 2D array structure
59#[derive(Debug, Clone)]
60pub struct MaskedArray2<T> {
61    /// The data array
62    pub data: Array2<T>,
63    /// The mask array (true = valid, false = masked/invalid)
64    pub mask: Array2<bool>,
65}
66
67impl<T: Copy> MaskedArray2<T> {
68    /// Create a new masked 2D array
69    pub fn new(data: Array2<T>, mask: Array2<bool>) -> StatsResult<Self> {
70        if data.shape() != mask.shape() {
71            return Err(StatsError::DimensionMismatch(
72                "Data and mask arrays must have the same shape".to_string(),
73            ));
74        }
75
76        Ok(Self { data, mask })
77    }
78
79    /// Create a masked array with all values unmasked (valid)
80    pub fn fromdata(data: Array2<T>) -> Self {
81        let mask = Array2::from_elem(data.dim(), true);
82        Self { data, mask }
83    }
84}
85
86/// Compute the mean of a masked array
87///
88/// # Arguments
89/// * `maskedarray` - The masked array
90/// * `axis` - Axis along which to compute the mean (None for overall mean)
91///
92/// # Returns
93/// * Mean of valid values
94///
95/// # Examples
96///
97/// ```
98/// use scirs2_core::ndarray::array;
99/// use scirs2_stats::mstats::{MaskedArray, masked_mean};
100///
101/// let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
102/// let mask = array![true, true, false, true, true]; // 3.0 is masked
103/// let masked_arr = MaskedArray::new(data, mask).expect("Operation failed");
104///
105/// let mean = masked_mean(&masked_arr, None).expect("Operation failed");
106/// assert!((mean - 3.0).abs() < 1e-10); // Mean of [1, 2, 4, 5] = 3.0
107/// ```
108#[allow(dead_code)]
109pub fn masked_mean<T>(maskedarray: &MaskedArray<T>, axis: Option<usize>) -> StatsResult<f64>
110where
111    T: Copy + Into<f64>,
112{
113    if !maskedarray.has_valid_values() {
114        return Err(StatsError::InvalidArgument(
115            "Array has no valid values".to_string(),
116        ));
117    }
118
119    let valid_values = maskedarray.valid_values();
120    let sum: f64 = valid_values.iter().map(|&x| x.into()).sum();
121    Ok(sum / valid_values.len() as f64)
122}
123
124/// Compute the variance of a masked array
125///
126/// # Arguments
127/// * `maskedarray` - The masked array
128/// * `ddof` - Delta degrees of freedom (0 for population variance, 1 for sample variance)
129/// * `axis` - Axis along which to compute the variance (None for overall variance)
130///
131/// # Returns
132/// * Variance of valid values
133#[allow(dead_code)]
134pub fn masked_var<T>(
135    maskedarray: &MaskedArray<T>,
136    ddof: usize,
137    axis: Option<usize>,
138) -> StatsResult<f64>
139where
140    T: Copy + Into<f64>,
141{
142    if !maskedarray.has_valid_values() {
143        return Err(StatsError::InvalidArgument(
144            "Array has no valid values".to_string(),
145        ));
146    }
147
148    let valid_values = maskedarray.valid_values();
149    let n = valid_values.len();
150
151    if n <= ddof {
152        return Err(StatsError::InvalidArgument(
153            "Number of valid values must be greater than ddof".to_string(),
154        ));
155    }
156
157    let mean = masked_mean(maskedarray, axis)?;
158    let sum_squared_diff: f64 = valid_values
159        .iter()
160        .map(|&x| {
161            let diff = x.into() - mean;
162            diff * diff
163        })
164        .sum();
165
166    Ok(sum_squared_diff / (n - ddof) as f64)
167}
168
169/// Compute the standard deviation of a masked array
170///
171/// # Arguments
172/// * `maskedarray` - The masked array
173/// * `ddof` - Delta degrees of freedom (0 for population std, 1 for sample std)
174/// * `axis` - Axis along which to compute the std (None for overall std)
175///
176/// # Returns
177/// * Standard deviation of valid values
178#[allow(dead_code)]
179pub fn masked_std<T>(
180    maskedarray: &MaskedArray<T>,
181    ddof: usize,
182    axis: Option<usize>,
183) -> StatsResult<f64>
184where
185    T: Copy + Into<f64>,
186{
187    let variance = masked_var(maskedarray, ddof, axis)?;
188    Ok(variance.sqrt())
189}
190
191/// Compute the median of a masked array
192///
193/// # Arguments
194/// * `maskedarray` - The masked array
195///
196/// # Returns
197/// * Median of valid values
198#[allow(dead_code)]
199pub fn masked_median<T>(maskedarray: &MaskedArray<T>) -> StatsResult<f64>
200where
201    T: Copy + Into<f64> + PartialOrd,
202{
203    if !maskedarray.has_valid_values() {
204        return Err(StatsError::InvalidArgument(
205            "Array has no valid values".to_string(),
206        ));
207    }
208
209    let mut valid_values = maskedarray.valid_values();
210    valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
211
212    let n = valid_values.len();
213    let median = if n % 2 == 1 {
214        valid_values[n / 2].into()
215    } else {
216        let mid1 = valid_values[n / 2 - 1].into();
217        let mid2 = valid_values[n / 2].into();
218        (mid1 + mid2) / 2.0
219    };
220
221    Ok(median)
222}
223
224/// Compute quantiles of a masked array
225///
226/// # Arguments
227/// * `maskedarray` - The masked array
228/// * `q` - Quantile or sequence of quantiles to compute (0.0 to 1.0)
229///
230/// # Returns
231/// * Array of quantiles
232#[allow(dead_code)]
233pub fn masked_quantile<T>(
234    maskedarray: &MaskedArray<T>,
235    q: ArrayView1<f64>,
236) -> StatsResult<Array1<f64>>
237where
238    T: Copy + Into<f64> + PartialOrd,
239{
240    if !maskedarray.has_valid_values() {
241        return Err(StatsError::InvalidArgument(
242            "Array has no valid values".to_string(),
243        ));
244    }
245
246    for &quantile in q.iter() {
247        if !(0.0..=1.0).contains(&quantile) {
248            return Err(StatsError::InvalidArgument(
249                "Quantiles must be between 0 and 1".to_string(),
250            ));
251        }
252    }
253
254    let mut valid_values = maskedarray.valid_values();
255    valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
256
257    let n = valid_values.len() as f64;
258    let mut quantiles = Array1::zeros(q.len());
259
260    for (i, &quantile) in q.iter().enumerate() {
261        let index = quantile * (n - 1.0);
262        let lower = index.floor() as usize;
263        let upper = index.ceil() as usize;
264        let fraction = index - lower as f64;
265
266        if lower == upper {
267            quantiles[i] = valid_values[lower].into();
268        } else {
269            let lower_val = valid_values[lower].into();
270            let upper_val = valid_values[upper].into();
271            quantiles[i] = lower_val + fraction * (upper_val - lower_val);
272        }
273    }
274
275    Ok(quantiles)
276}
277
278/// Compute the correlation coefficient between two masked arrays
279///
280/// # Arguments
281/// * `x` - First masked array
282/// * `y` - Second masked array
283/// * `method` - Correlation method ("pearson", "spearman", or "kendall")
284///
285/// # Returns
286/// * Correlation coefficient
287#[allow(dead_code)]
288pub fn masked_corrcoef<T>(x: &MaskedArray<T>, y: &MaskedArray<T>, method: &str) -> StatsResult<f64>
289where
290    T: Copy + Into<f64> + PartialOrd,
291{
292    if x.data.len() != y.data.len() {
293        return Err(StatsError::DimensionMismatch(
294            "Arrays must have the same length".to_string(),
295        ));
296    }
297
298    // Combine masks (both values must be valid)
299    let combined_mask: Array1<bool> = x
300        .mask
301        .iter()
302        .zip(y.mask.iter())
303        .map(|(&x_valid, &y_valid)| x_valid && y_valid)
304        .collect();
305
306    let valid_pairs: Vec<(T, T)> = x
307        .data
308        .iter()
309        .zip(y.data.iter())
310        .zip(combined_mask.iter())
311        .filter_map(
312            |((&x_val, &y_val), &is_valid)| {
313                if is_valid {
314                    Some((x_val, y_val))
315                } else {
316                    None
317                }
318            },
319        )
320        .collect();
321
322    if valid_pairs.is_empty() {
323        return Err(StatsError::InvalidArgument(
324            "No valid pairs found".to_string(),
325        ));
326    }
327
328    let n = valid_pairs.len() as f64;
329
330    match method {
331        "pearson" => {
332            let x_values: Vec<f64> = valid_pairs.iter().map(|(x, _)| (*x).into()).collect();
333            let y_values: Vec<f64> = valid_pairs.iter().map(|(_, y)| (*y).into()).collect();
334
335            let x_mean: f64 = x_values.iter().sum::<f64>() / n;
336            let y_mean: f64 = y_values.iter().sum::<f64>() / n;
337
338            let mut numerator = 0.0;
339            let mut x_var = 0.0;
340            let mut y_var = 0.0;
341
342            for (&x_val, &y_val) in x_values.iter().zip(y_values.iter()) {
343                let x_diff = x_val - x_mean;
344                let y_diff = y_val - y_mean;
345                numerator += x_diff * y_diff;
346                x_var += x_diff * x_diff;
347                y_var += y_diff * y_diff;
348            }
349
350            if x_var == 0.0 || y_var == 0.0 {
351                return Ok(0.0);
352            }
353
354            Ok(numerator / (x_var * y_var).sqrt())
355        }
356        "spearman" => {
357            // Convert to ranks
358            let mut x_values: Vec<(f64, usize)> = valid_pairs
359                .iter()
360                .enumerate()
361                .map(|(i, (x, _))| ((*x).into(), i))
362                .collect();
363            let mut y_values: Vec<(f64, usize)> = valid_pairs
364                .iter()
365                .enumerate()
366                .map(|(i, (_, y))| ((*y).into(), i))
367                .collect();
368
369            x_values.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
370            y_values.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
371
372            let mut x_ranks = vec![0.0; valid_pairs.len()];
373            let mut y_ranks = vec![0.0; valid_pairs.len()];
374
375            for (rank, (_, original_idx)) in x_values.iter().enumerate() {
376                x_ranks[*original_idx] = rank as f64 + 1.0;
377            }
378            for (rank, (_, original_idx)) in y_values.iter().enumerate() {
379                y_ranks[*original_idx] = rank as f64 + 1.0;
380            }
381
382            // Calculate Pearson correlation on ranks
383            let x_rank_mean = x_ranks.iter().sum::<f64>() / n;
384            let y_rank_mean = y_ranks.iter().sum::<f64>() / n;
385
386            let mut numerator = 0.0;
387            let mut x_var = 0.0;
388            let mut y_var = 0.0;
389
390            for (&x_rank, &y_rank) in x_ranks.iter().zip(y_ranks.iter()) {
391                let x_diff = x_rank - x_rank_mean;
392                let y_diff = y_rank - y_rank_mean;
393                numerator += x_diff * y_diff;
394                x_var += x_diff * x_diff;
395                y_var += y_diff * y_diff;
396            }
397
398            if x_var == 0.0 || y_var == 0.0 {
399                return Ok(0.0);
400            }
401
402            Ok(numerator / (x_var * y_var).sqrt())
403        }
404        "kendall" => {
405            // Kendall's tau
406            let mut concordant = 0;
407            let mut discordant = 0;
408
409            for i in 0..valid_pairs.len() {
410                for j in (i + 1)..valid_pairs.len() {
411                    let (x1, y1) = valid_pairs[i];
412                    let (x2, y2) = valid_pairs[j];
413
414                    let x1_f64 = x1.into();
415                    let y1_f64 = y1.into();
416                    let x2_f64 = x2.into();
417                    let y2_f64 = y2.into();
418
419                    let x_diff = x2_f64 - x1_f64;
420                    let y_diff = y2_f64 - y1_f64;
421
422                    if x_diff * y_diff > 0.0 {
423                        concordant += 1;
424                    } else if x_diff * y_diff < 0.0 {
425                        discordant += 1;
426                    }
427                    // Ties contribute 0
428                }
429            }
430
431            let total_pairs = valid_pairs.len() * (valid_pairs.len() - 1) / 2;
432            Ok((concordant - discordant) as f64 / total_pairs as f64)
433        }
434        _ => Err(StatsError::InvalidArgument(
435            "Method must be one of 'pearson', 'spearman', or 'kendall'".to_string(),
436        )),
437    }
438}
439
440/// Compute the covariance between two masked arrays
441///
442/// # Arguments
443/// * `x` - First masked array
444/// * `y` - Second masked array
445/// * `ddof` - Delta degrees of freedom
446///
447/// # Returns
448/// * Covariance
449#[allow(dead_code)]
450pub fn masked_cov<T>(x: &MaskedArray<T>, y: &MaskedArray<T>, ddof: usize) -> StatsResult<f64>
451where
452    T: Copy + Into<f64>,
453{
454    if x.data.len() != y.data.len() {
455        return Err(StatsError::DimensionMismatch(
456            "Arrays must have the same length".to_string(),
457        ));
458    }
459
460    // Combine masks (both values must be valid)
461    let combined_mask: Array1<bool> = x
462        .mask
463        .iter()
464        .zip(y.mask.iter())
465        .map(|(&x_valid, &y_valid)| x_valid && y_valid)
466        .collect();
467
468    let valid_pairs: Vec<(T, T)> = x
469        .data
470        .iter()
471        .zip(y.data.iter())
472        .zip(combined_mask.iter())
473        .filter_map(
474            |((&x_val, &y_val), &is_valid)| {
475                if is_valid {
476                    Some((x_val, y_val))
477                } else {
478                    None
479                }
480            },
481        )
482        .collect();
483
484    if valid_pairs.len() <= ddof {
485        return Err(StatsError::InvalidArgument(
486            "Number of valid pairs must be greater than ddof".to_string(),
487        ));
488    }
489
490    let n = valid_pairs.len() as f64;
491    let x_values: Vec<f64> = valid_pairs.iter().map(|(x, _)| (*x).into()).collect();
492    let y_values: Vec<f64> = valid_pairs.iter().map(|(_, y)| (*y).into()).collect();
493
494    let x_mean: f64 = x_values.iter().sum::<f64>() / n;
495    let y_mean: f64 = y_values.iter().sum::<f64>() / n;
496
497    let covariance: f64 = x_values
498        .iter()
499        .zip(y_values.iter())
500        .map(|(&x_val, &y_val)| (x_val - x_mean) * (y_val - y_mean))
501        .sum::<f64>()
502        / (n - ddof as f64);
503
504    Ok(covariance)
505}
506
507/// Compute masked skewness
508///
509/// # Arguments
510/// * `maskedarray` - The masked array
511/// * `bias` - If false, use bias-corrected formula
512///
513/// # Returns
514/// * Skewness of valid values
515#[allow(dead_code)]
516pub fn masked_skew<T>(maskedarray: &MaskedArray<T>, bias: bool) -> StatsResult<f64>
517where
518    T: Copy + Into<f64>,
519{
520    if !maskedarray.has_valid_values() {
521        return Err(StatsError::InvalidArgument(
522            "Array has no valid values".to_string(),
523        ));
524    }
525
526    let valid_values = maskedarray.valid_values();
527    let n = valid_values.len() as f64;
528
529    if n < 3.0 {
530        return Err(StatsError::InvalidArgument(
531            "Skewness requires at least 3 valid values".to_string(),
532        ));
533    }
534
535    let mean = masked_mean(maskedarray, None)?;
536    let std_dev = masked_std(maskedarray, 1, None)?;
537
538    if std_dev == 0.0 {
539        return Ok(0.0);
540    }
541
542    let m3: f64 = valid_values
543        .iter()
544        .map(|&x| {
545            let z = (x.into() - mean) / std_dev;
546            z.powi(3)
547        })
548        .sum::<f64>()
549        / n;
550
551    if bias {
552        Ok(m3)
553    } else {
554        // Bias-corrected skewness
555        let correction = ((n * (n - 1.0)).sqrt()) / (n - 2.0);
556        Ok(correction * m3)
557    }
558}
559
560/// Compute masked kurtosis
561///
562/// # Arguments
563/// * `maskedarray` - The masked array
564/// * `fisher` - If true, return Fisher's kurtosis (excess kurtosis)
565/// * `bias` - If false, use bias-corrected formula
566///
567/// # Returns
568/// * Kurtosis of valid values
569#[allow(dead_code)]
570pub fn masked_kurtosis<T>(
571    maskedarray: &MaskedArray<T>,
572    fisher: bool,
573    bias: bool,
574) -> StatsResult<f64>
575where
576    T: Copy + Into<f64>,
577{
578    if !maskedarray.has_valid_values() {
579        return Err(StatsError::InvalidArgument(
580            "Array has no valid values".to_string(),
581        ));
582    }
583
584    let valid_values = maskedarray.valid_values();
585    let n = valid_values.len() as f64;
586
587    if n < 4.0 {
588        return Err(StatsError::InvalidArgument(
589            "Kurtosis requires at least 4 valid values".to_string(),
590        ));
591    }
592
593    let mean = masked_mean(maskedarray, None)?;
594    let std_dev = masked_std(maskedarray, 1, None)?;
595
596    if std_dev == 0.0 {
597        return Err(StatsError::InvalidArgument(
598            "Standard deviation is zero".to_string(),
599        ));
600    }
601
602    let m4: f64 = valid_values
603        .iter()
604        .map(|&x| {
605            let z = (x.into() - mean) / std_dev;
606            z.powi(4)
607        })
608        .sum::<f64>()
609        / n;
610
611    let kurtosis = if bias {
612        m4
613    } else {
614        // Bias-corrected kurtosis
615        let term1 = (n - 1.0) / ((n - 2.0) * (n - 3.0));
616        let term2 = (n + 1.0) * m4 - 3.0 * (n - 1.0);
617        term1 * term2 + 3.0
618    };
619
620    if fisher {
621        Ok(kurtosis - 3.0) // Excess kurtosis
622    } else {
623        Ok(kurtosis)
624    }
625}
626
627/// Compute trimmed mean of a masked array
628///
629/// # Arguments
630/// * `maskedarray` - The masked array
631/// * `proportiontocut` - Fraction of values to trim from each end (0.0 to 0.5)
632///
633/// # Returns
634/// * Trimmed mean of valid values
635#[allow(dead_code)]
636pub fn masked_tmean<T>(maskedarray: &MaskedArray<T>, proportiontocut: f64) -> StatsResult<f64>
637where
638    T: Copy + Into<f64> + PartialOrd,
639{
640    if !(0.0..0.5).contains(&proportiontocut) {
641        return Err(StatsError::InvalidArgument(
642            "proportiontocut must be between 0 and 0.5".to_string(),
643        ));
644    }
645
646    if !maskedarray.has_valid_values() {
647        return Err(StatsError::InvalidArgument(
648            "Array has no valid values".to_string(),
649        ));
650    }
651
652    let mut valid_values = maskedarray.valid_values();
653    valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
654
655    let n = valid_values.len();
656    let ncut = (n as f64 * proportiontocut).floor() as usize;
657
658    if n <= 2 * ncut {
659        return Err(StatsError::InvalidArgument(
660            "Too many values would be trimmed".to_string(),
661        ));
662    }
663
664    let trimmed_values = &valid_values[ncut..(n - ncut)];
665    let sum: f64 = trimmed_values.iter().map(|&x| x.into()).sum();
666
667    Ok(sum / trimmed_values.len() as f64)
668}