1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8
9#[derive(Debug, Clone)]
13pub struct MaskedArray<T> {
14 pub data: Array1<T>,
16 pub mask: Array1<bool>,
18}
19
20impl<T: Copy> MaskedArray<T> {
21 pub fn new(data: Array1<T>, mask: Array1<bool>) -> StatsResult<Self> {
23 if data.len() != mask.len() {
24 return Err(StatsError::DimensionMismatch(
25 "Data and mask arrays must have the same length".to_string(),
26 ));
27 }
28
29 Ok(Self { data, mask })
30 }
31
32 pub fn fromdata(data: Array1<T>) -> Self {
34 let mask = Array1::from_elem(data.len(), true);
35 Self { data, mask }
36 }
37
38 pub fn valid_values(&self) -> Vec<T> {
40 self.data
41 .iter()
42 .zip(self.mask.iter())
43 .filter_map(|(&value, &is_valid)| if is_valid { Some(value) } else { None })
44 .collect()
45 }
46
47 pub fn count_valid(&self) -> usize {
49 self.mask.iter().filter(|&&is_valid| is_valid).count()
50 }
51
52 pub fn has_valid_values(&self) -> bool {
54 self.count_valid() > 0
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct MaskedArray2<T> {
61 pub data: Array2<T>,
63 pub mask: Array2<bool>,
65}
66
67impl<T: Copy> MaskedArray2<T> {
68 pub fn new(data: Array2<T>, mask: Array2<bool>) -> StatsResult<Self> {
70 if data.shape() != mask.shape() {
71 return Err(StatsError::DimensionMismatch(
72 "Data and mask arrays must have the same shape".to_string(),
73 ));
74 }
75
76 Ok(Self { data, mask })
77 }
78
79 pub fn fromdata(data: Array2<T>) -> Self {
81 let mask = Array2::from_elem(data.dim(), true);
82 Self { data, mask }
83 }
84}
85
86#[allow(dead_code)]
109pub fn masked_mean<T>(maskedarray: &MaskedArray<T>, axis: Option<usize>) -> StatsResult<f64>
110where
111 T: Copy + Into<f64>,
112{
113 if !maskedarray.has_valid_values() {
114 return Err(StatsError::InvalidArgument(
115 "Array has no valid values".to_string(),
116 ));
117 }
118
119 let valid_values = maskedarray.valid_values();
120 let sum: f64 = valid_values.iter().map(|&x| x.into()).sum();
121 Ok(sum / valid_values.len() as f64)
122}
123
124#[allow(dead_code)]
134pub fn masked_var<T>(
135 maskedarray: &MaskedArray<T>,
136 ddof: usize,
137 axis: Option<usize>,
138) -> StatsResult<f64>
139where
140 T: Copy + Into<f64>,
141{
142 if !maskedarray.has_valid_values() {
143 return Err(StatsError::InvalidArgument(
144 "Array has no valid values".to_string(),
145 ));
146 }
147
148 let valid_values = maskedarray.valid_values();
149 let n = valid_values.len();
150
151 if n <= ddof {
152 return Err(StatsError::InvalidArgument(
153 "Number of valid values must be greater than ddof".to_string(),
154 ));
155 }
156
157 let mean = masked_mean(maskedarray, axis)?;
158 let sum_squared_diff: f64 = valid_values
159 .iter()
160 .map(|&x| {
161 let diff = x.into() - mean;
162 diff * diff
163 })
164 .sum();
165
166 Ok(sum_squared_diff / (n - ddof) as f64)
167}
168
169#[allow(dead_code)]
179pub fn masked_std<T>(
180 maskedarray: &MaskedArray<T>,
181 ddof: usize,
182 axis: Option<usize>,
183) -> StatsResult<f64>
184where
185 T: Copy + Into<f64>,
186{
187 let variance = masked_var(maskedarray, ddof, axis)?;
188 Ok(variance.sqrt())
189}
190
191#[allow(dead_code)]
199pub fn masked_median<T>(maskedarray: &MaskedArray<T>) -> StatsResult<f64>
200where
201 T: Copy + Into<f64> + PartialOrd,
202{
203 if !maskedarray.has_valid_values() {
204 return Err(StatsError::InvalidArgument(
205 "Array has no valid values".to_string(),
206 ));
207 }
208
209 let mut valid_values = maskedarray.valid_values();
210 valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
211
212 let n = valid_values.len();
213 let median = if n % 2 == 1 {
214 valid_values[n / 2].into()
215 } else {
216 let mid1 = valid_values[n / 2 - 1].into();
217 let mid2 = valid_values[n / 2].into();
218 (mid1 + mid2) / 2.0
219 };
220
221 Ok(median)
222}
223
224#[allow(dead_code)]
233pub fn masked_quantile<T>(
234 maskedarray: &MaskedArray<T>,
235 q: ArrayView1<f64>,
236) -> StatsResult<Array1<f64>>
237where
238 T: Copy + Into<f64> + PartialOrd,
239{
240 if !maskedarray.has_valid_values() {
241 return Err(StatsError::InvalidArgument(
242 "Array has no valid values".to_string(),
243 ));
244 }
245
246 for &quantile in q.iter() {
247 if !(0.0..=1.0).contains(&quantile) {
248 return Err(StatsError::InvalidArgument(
249 "Quantiles must be between 0 and 1".to_string(),
250 ));
251 }
252 }
253
254 let mut valid_values = maskedarray.valid_values();
255 valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
256
257 let n = valid_values.len() as f64;
258 let mut quantiles = Array1::zeros(q.len());
259
260 for (i, &quantile) in q.iter().enumerate() {
261 let index = quantile * (n - 1.0);
262 let lower = index.floor() as usize;
263 let upper = index.ceil() as usize;
264 let fraction = index - lower as f64;
265
266 if lower == upper {
267 quantiles[i] = valid_values[lower].into();
268 } else {
269 let lower_val = valid_values[lower].into();
270 let upper_val = valid_values[upper].into();
271 quantiles[i] = lower_val + fraction * (upper_val - lower_val);
272 }
273 }
274
275 Ok(quantiles)
276}
277
278#[allow(dead_code)]
288pub fn masked_corrcoef<T>(x: &MaskedArray<T>, y: &MaskedArray<T>, method: &str) -> StatsResult<f64>
289where
290 T: Copy + Into<f64> + PartialOrd,
291{
292 if x.data.len() != y.data.len() {
293 return Err(StatsError::DimensionMismatch(
294 "Arrays must have the same length".to_string(),
295 ));
296 }
297
298 let combined_mask: Array1<bool> = x
300 .mask
301 .iter()
302 .zip(y.mask.iter())
303 .map(|(&x_valid, &y_valid)| x_valid && y_valid)
304 .collect();
305
306 let valid_pairs: Vec<(T, T)> = x
307 .data
308 .iter()
309 .zip(y.data.iter())
310 .zip(combined_mask.iter())
311 .filter_map(
312 |((&x_val, &y_val), &is_valid)| {
313 if is_valid {
314 Some((x_val, y_val))
315 } else {
316 None
317 }
318 },
319 )
320 .collect();
321
322 if valid_pairs.is_empty() {
323 return Err(StatsError::InvalidArgument(
324 "No valid pairs found".to_string(),
325 ));
326 }
327
328 let n = valid_pairs.len() as f64;
329
330 match method {
331 "pearson" => {
332 let x_values: Vec<f64> = valid_pairs.iter().map(|(x, _)| (*x).into()).collect();
333 let y_values: Vec<f64> = valid_pairs.iter().map(|(_, y)| (*y).into()).collect();
334
335 let x_mean: f64 = x_values.iter().sum::<f64>() / n;
336 let y_mean: f64 = y_values.iter().sum::<f64>() / n;
337
338 let mut numerator = 0.0;
339 let mut x_var = 0.0;
340 let mut y_var = 0.0;
341
342 for (&x_val, &y_val) in x_values.iter().zip(y_values.iter()) {
343 let x_diff = x_val - x_mean;
344 let y_diff = y_val - y_mean;
345 numerator += x_diff * y_diff;
346 x_var += x_diff * x_diff;
347 y_var += y_diff * y_diff;
348 }
349
350 if x_var == 0.0 || y_var == 0.0 {
351 return Ok(0.0);
352 }
353
354 Ok(numerator / (x_var * y_var).sqrt())
355 }
356 "spearman" => {
357 let mut x_values: Vec<(f64, usize)> = valid_pairs
359 .iter()
360 .enumerate()
361 .map(|(i, (x, _))| ((*x).into(), i))
362 .collect();
363 let mut y_values: Vec<(f64, usize)> = valid_pairs
364 .iter()
365 .enumerate()
366 .map(|(i, (_, y))| ((*y).into(), i))
367 .collect();
368
369 x_values.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
370 y_values.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
371
372 let mut x_ranks = vec![0.0; valid_pairs.len()];
373 let mut y_ranks = vec![0.0; valid_pairs.len()];
374
375 for (rank, (_, original_idx)) in x_values.iter().enumerate() {
376 x_ranks[*original_idx] = rank as f64 + 1.0;
377 }
378 for (rank, (_, original_idx)) in y_values.iter().enumerate() {
379 y_ranks[*original_idx] = rank as f64 + 1.0;
380 }
381
382 let x_rank_mean = x_ranks.iter().sum::<f64>() / n;
384 let y_rank_mean = y_ranks.iter().sum::<f64>() / n;
385
386 let mut numerator = 0.0;
387 let mut x_var = 0.0;
388 let mut y_var = 0.0;
389
390 for (&x_rank, &y_rank) in x_ranks.iter().zip(y_ranks.iter()) {
391 let x_diff = x_rank - x_rank_mean;
392 let y_diff = y_rank - y_rank_mean;
393 numerator += x_diff * y_diff;
394 x_var += x_diff * x_diff;
395 y_var += y_diff * y_diff;
396 }
397
398 if x_var == 0.0 || y_var == 0.0 {
399 return Ok(0.0);
400 }
401
402 Ok(numerator / (x_var * y_var).sqrt())
403 }
404 "kendall" => {
405 let mut concordant = 0;
407 let mut discordant = 0;
408
409 for i in 0..valid_pairs.len() {
410 for j in (i + 1)..valid_pairs.len() {
411 let (x1, y1) = valid_pairs[i];
412 let (x2, y2) = valid_pairs[j];
413
414 let x1_f64 = x1.into();
415 let y1_f64 = y1.into();
416 let x2_f64 = x2.into();
417 let y2_f64 = y2.into();
418
419 let x_diff = x2_f64 - x1_f64;
420 let y_diff = y2_f64 - y1_f64;
421
422 if x_diff * y_diff > 0.0 {
423 concordant += 1;
424 } else if x_diff * y_diff < 0.0 {
425 discordant += 1;
426 }
427 }
429 }
430
431 let total_pairs = valid_pairs.len() * (valid_pairs.len() - 1) / 2;
432 Ok((concordant - discordant) as f64 / total_pairs as f64)
433 }
434 _ => Err(StatsError::InvalidArgument(
435 "Method must be one of 'pearson', 'spearman', or 'kendall'".to_string(),
436 )),
437 }
438}
439
440#[allow(dead_code)]
450pub fn masked_cov<T>(x: &MaskedArray<T>, y: &MaskedArray<T>, ddof: usize) -> StatsResult<f64>
451where
452 T: Copy + Into<f64>,
453{
454 if x.data.len() != y.data.len() {
455 return Err(StatsError::DimensionMismatch(
456 "Arrays must have the same length".to_string(),
457 ));
458 }
459
460 let combined_mask: Array1<bool> = x
462 .mask
463 .iter()
464 .zip(y.mask.iter())
465 .map(|(&x_valid, &y_valid)| x_valid && y_valid)
466 .collect();
467
468 let valid_pairs: Vec<(T, T)> = x
469 .data
470 .iter()
471 .zip(y.data.iter())
472 .zip(combined_mask.iter())
473 .filter_map(
474 |((&x_val, &y_val), &is_valid)| {
475 if is_valid {
476 Some((x_val, y_val))
477 } else {
478 None
479 }
480 },
481 )
482 .collect();
483
484 if valid_pairs.len() <= ddof {
485 return Err(StatsError::InvalidArgument(
486 "Number of valid pairs must be greater than ddof".to_string(),
487 ));
488 }
489
490 let n = valid_pairs.len() as f64;
491 let x_values: Vec<f64> = valid_pairs.iter().map(|(x, _)| (*x).into()).collect();
492 let y_values: Vec<f64> = valid_pairs.iter().map(|(_, y)| (*y).into()).collect();
493
494 let x_mean: f64 = x_values.iter().sum::<f64>() / n;
495 let y_mean: f64 = y_values.iter().sum::<f64>() / n;
496
497 let covariance: f64 = x_values
498 .iter()
499 .zip(y_values.iter())
500 .map(|(&x_val, &y_val)| (x_val - x_mean) * (y_val - y_mean))
501 .sum::<f64>()
502 / (n - ddof as f64);
503
504 Ok(covariance)
505}
506
507#[allow(dead_code)]
516pub fn masked_skew<T>(maskedarray: &MaskedArray<T>, bias: bool) -> StatsResult<f64>
517where
518 T: Copy + Into<f64>,
519{
520 if !maskedarray.has_valid_values() {
521 return Err(StatsError::InvalidArgument(
522 "Array has no valid values".to_string(),
523 ));
524 }
525
526 let valid_values = maskedarray.valid_values();
527 let n = valid_values.len() as f64;
528
529 if n < 3.0 {
530 return Err(StatsError::InvalidArgument(
531 "Skewness requires at least 3 valid values".to_string(),
532 ));
533 }
534
535 let mean = masked_mean(maskedarray, None)?;
536 let std_dev = masked_std(maskedarray, 1, None)?;
537
538 if std_dev == 0.0 {
539 return Ok(0.0);
540 }
541
542 let m3: f64 = valid_values
543 .iter()
544 .map(|&x| {
545 let z = (x.into() - mean) / std_dev;
546 z.powi(3)
547 })
548 .sum::<f64>()
549 / n;
550
551 if bias {
552 Ok(m3)
553 } else {
554 let correction = ((n * (n - 1.0)).sqrt()) / (n - 2.0);
556 Ok(correction * m3)
557 }
558}
559
560#[allow(dead_code)]
570pub fn masked_kurtosis<T>(
571 maskedarray: &MaskedArray<T>,
572 fisher: bool,
573 bias: bool,
574) -> StatsResult<f64>
575where
576 T: Copy + Into<f64>,
577{
578 if !maskedarray.has_valid_values() {
579 return Err(StatsError::InvalidArgument(
580 "Array has no valid values".to_string(),
581 ));
582 }
583
584 let valid_values = maskedarray.valid_values();
585 let n = valid_values.len() as f64;
586
587 if n < 4.0 {
588 return Err(StatsError::InvalidArgument(
589 "Kurtosis requires at least 4 valid values".to_string(),
590 ));
591 }
592
593 let mean = masked_mean(maskedarray, None)?;
594 let std_dev = masked_std(maskedarray, 1, None)?;
595
596 if std_dev == 0.0 {
597 return Err(StatsError::InvalidArgument(
598 "Standard deviation is zero".to_string(),
599 ));
600 }
601
602 let m4: f64 = valid_values
603 .iter()
604 .map(|&x| {
605 let z = (x.into() - mean) / std_dev;
606 z.powi(4)
607 })
608 .sum::<f64>()
609 / n;
610
611 let kurtosis = if bias {
612 m4
613 } else {
614 let term1 = (n - 1.0) / ((n - 2.0) * (n - 3.0));
616 let term2 = (n + 1.0) * m4 - 3.0 * (n - 1.0);
617 term1 * term2 + 3.0
618 };
619
620 if fisher {
621 Ok(kurtosis - 3.0) } else {
623 Ok(kurtosis)
624 }
625}
626
627#[allow(dead_code)]
636pub fn masked_tmean<T>(maskedarray: &MaskedArray<T>, proportiontocut: f64) -> StatsResult<f64>
637where
638 T: Copy + Into<f64> + PartialOrd,
639{
640 if !(0.0..0.5).contains(&proportiontocut) {
641 return Err(StatsError::InvalidArgument(
642 "proportiontocut must be between 0 and 0.5".to_string(),
643 ));
644 }
645
646 if !maskedarray.has_valid_values() {
647 return Err(StatsError::InvalidArgument(
648 "Array has no valid values".to_string(),
649 ));
650 }
651
652 let mut valid_values = maskedarray.valid_values();
653 valid_values.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
654
655 let n = valid_values.len();
656 let ncut = (n as f64 * proportiontocut).floor() as usize;
657
658 if n <= 2 * ncut {
659 return Err(StatsError::InvalidArgument(
660 "Too many values would be trimmed".to_string(),
661 ));
662 }
663
664 let trimmed_values = &valid_values[ncut..(n - ncut)];
665 let sum: f64 = trimmed_values.iter().map(|&x| x.into()).sum();
666
667 Ok(sum / trimmed_values.len() as f64)
668}