1use ferray_core::Array;
12use ferray_core::dimension::{Dimension, IxDyn};
13use ferray_core::dtype::Element;
14use ferray_core::error::{FerrayError, FerrayResult};
15use num_traits::Float;
16
17use crate::MaskedArray;
18
/// Row-major (C-order) strides for `shape`, measured in elements.
///
/// The last axis always has stride 1, and every earlier axis's stride is the
/// product of the extents of all axes after it; e.g. `[2, 3, 4]` -> `[12, 4, 1]`.
/// An empty shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk from the innermost axis outwards, folding each extent into the
    // stride of the axis just before it.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
32
/// Advance the row-major multi-index `multi` by one position within `shape`.
///
/// The last axis varies fastest. Returns `true` if the index moved to a new
/// position, or `false` if it ran past the end and wrapped back to all zeros
/// (an empty index always returns `false`).
fn increment_multi(multi: &mut [usize], shape: &[usize]) -> bool {
    let mut d = multi.len();
    while d > 0 {
        d -= 1;
        let bumped = multi[d] + 1;
        if bumped < shape[d] {
            multi[d] = bumped;
            return true;
        }
        // This digit overflowed: reset it and carry into the next-outer axis.
        multi[d] = 0;
    }
    false
}
44
45fn reduce_axis<T, D, F>(
52 ma: &MaskedArray<T, D>,
53 axis: usize,
54 fill_value: T,
55 kernel: F,
56) -> FerrayResult<MaskedArray<T, IxDyn>>
57where
58 T: Element + Copy,
59 D: Dimension,
60 F: Fn(&[(T, bool)]) -> Option<T>,
61{
62 let ndim = ma.ndim();
63 if axis >= ndim {
64 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
65 }
66 let shape = ma.shape();
67 let axis_len = shape[axis];
68
69 let out_shape: Vec<usize> = shape
71 .iter()
72 .enumerate()
73 .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
74 .collect();
75 let out_size: usize = if out_shape.is_empty() {
76 1
77 } else {
78 out_shape.iter().product()
79 };
80
81 let src_data: Vec<T> = ma.data().iter().copied().collect();
84 let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
85 let strides = compute_strides(shape);
86
87 let mut out_data = Vec::with_capacity(out_size);
88 let mut out_mask = Vec::with_capacity(out_size);
89 let mut out_multi = vec![0usize; out_shape.len()];
90 let mut in_multi = vec![0usize; ndim];
91 let mut lane: Vec<(T, bool)> = Vec::with_capacity(axis_len);
92
93 for _ in 0..out_size {
94 let mut out_dim = 0;
97 for (d, idx) in in_multi.iter_mut().enumerate() {
98 if d == axis {
99 *idx = 0;
100 } else {
101 *idx = out_multi[out_dim];
102 out_dim += 1;
103 }
104 }
105
106 lane.clear();
107 for k in 0..axis_len {
108 in_multi[axis] = k;
109 let flat = in_multi
110 .iter()
111 .zip(strides.iter())
112 .map(|(i, s)| i * s)
113 .sum::<usize>();
114 lane.push((src_data[flat], src_mask[flat]));
115 }
116
117 if let Some(value) = kernel(&lane) {
118 out_data.push(value);
119 out_mask.push(false);
120 } else {
121 out_data.push(fill_value);
122 out_mask.push(true);
123 }
124
125 if !out_shape.is_empty() {
126 increment_multi(&mut out_multi, &out_shape);
127 }
128 }
129
130 let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)?;
131 let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&out_shape), out_mask)?;
132 let mut result = MaskedArray::new(data_arr, mask_arr)?;
133 result.set_fill_value(fill_value);
134 Ok(result)
135}
136
137fn count_axis<T, D>(ma: &MaskedArray<T, D>, axis: usize) -> FerrayResult<Array<u64, IxDyn>>
140where
141 T: Element + Copy,
142 D: Dimension,
143{
144 let ndim = ma.ndim();
145 if axis >= ndim {
146 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
147 }
148 let shape = ma.shape();
149 let axis_len = shape[axis];
150 let out_shape: Vec<usize> = shape
151 .iter()
152 .enumerate()
153 .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
154 .collect();
155 let out_size: usize = if out_shape.is_empty() {
156 1
157 } else {
158 out_shape.iter().product()
159 };
160
161 let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
162 let strides = compute_strides(shape);
163 let mut out: Vec<u64> = Vec::with_capacity(out_size);
164 let mut out_multi = vec![0usize; out_shape.len()];
165 let mut in_multi = vec![0usize; ndim];
166
167 for _ in 0..out_size {
168 let mut out_dim = 0;
169 for (d, idx) in in_multi.iter_mut().enumerate() {
170 if d == axis {
171 *idx = 0;
172 } else {
173 *idx = out_multi[out_dim];
174 out_dim += 1;
175 }
176 }
177
178 let mut count: u64 = 0;
179 for k in 0..axis_len {
180 in_multi[axis] = k;
181 let flat = in_multi
182 .iter()
183 .zip(strides.iter())
184 .map(|(i, s)| i * s)
185 .sum::<usize>();
186 if !src_mask[flat] {
187 count += 1;
188 }
189 }
190 out.push(count);
191
192 if !out_shape.is_empty() {
193 increment_multi(&mut out_multi, &out_shape);
194 }
195 }
196
197 Array::<u64, IxDyn>::from_vec(IxDyn::new(&out_shape), out)
198}
199
200impl<T, D> MaskedArray<T, D>
201where
202 T: Element + Copy,
203 D: Dimension,
204{
205 pub fn count(&self) -> FerrayResult<usize> {
211 let n = self
212 .data()
213 .iter()
214 .zip(self.mask().iter())
215 .filter(|(_, m)| !**m)
216 .count();
217 Ok(n)
218 }
219
220 pub fn count_axis(&self, axis: usize) -> FerrayResult<Array<u64, IxDyn>> {
228 count_axis(self, axis)
229 }
230}
231
232impl<T, D> MaskedArray<T, D>
233where
234 T: Element + Float,
235 D: Dimension,
236{
237 pub fn sum(&self) -> FerrayResult<T> {
244 let zero = num_traits::zero::<T>();
245 let s = self
246 .data()
247 .iter()
248 .zip(self.mask().iter())
249 .filter(|(_, m)| !**m)
250 .fold(zero, |acc, (v, _)| acc + *v);
251 Ok(s)
252 }
253
254 pub fn mean(&self) -> FerrayResult<T> {
261 let zero = num_traits::zero::<T>();
262 let (sum, count) = self
263 .data()
264 .iter()
265 .zip(self.mask().iter())
266 .filter(|(_, m)| !**m)
267 .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
268 if count == 0 {
269 return Ok(T::nan());
270 }
271 let n = T::from(count).ok_or_else(|| {
276 FerrayError::invalid_value(format!(
277 "cannot convert unmasked count {count} to element type"
278 ))
279 })?;
280 Ok(sum / n)
281 }
282
283 pub fn min(&self) -> FerrayResult<T> {
290 self.data()
291 .iter()
292 .zip(self.mask().iter())
293 .filter(|(_, m)| !**m)
294 .map(|(v, _)| *v)
295 .fold(None, |acc: Option<T>, v| {
296 Some(match acc {
297 Some(a) => {
298 if a <= v {
300 a
301 } else if a > v {
302 v
303 } else {
304 a
305 }
306 }
307 None => v,
308 })
309 })
310 .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
311 }
312
313 pub fn max(&self) -> FerrayResult<T> {
320 self.data()
321 .iter()
322 .zip(self.mask().iter())
323 .filter(|(_, m)| !**m)
324 .map(|(v, _)| *v)
325 .fold(None, |acc: Option<T>, v| {
326 Some(match acc {
327 Some(a) => {
328 if a >= v {
329 a
330 } else if a < v {
331 v
332 } else {
333 a
334 }
335 }
336 None => v,
337 })
338 })
339 .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
340 }
341
342 pub fn var(&self) -> FerrayResult<T> {
349 self.var_ddof(0)
350 }
351
352 pub fn var_ddof(&self, ddof: usize) -> FerrayResult<T> {
364 let zero = num_traits::zero::<T>();
371 let mut mean = zero;
372 let mut m2 = zero;
373 let mut count = 0usize;
374
375 for (v, m) in self.data().iter().zip(self.mask().iter()) {
376 if *m {
377 continue;
378 }
379 count += 1;
380 let n_t = T::from(count).ok_or_else(|| {
381 FerrayError::invalid_value(format!("cannot convert count {count} to element type"))
382 })?;
383 let delta = *v - mean;
384 mean = mean + delta / n_t;
385 let delta2 = *v - mean;
386 m2 = m2 + delta * delta2;
387 }
388
389 if count == 0 {
390 return Ok(T::nan());
391 }
392 if count <= ddof {
393 return Ok(T::nan());
394 }
395 let n = T::from(count - ddof).ok_or_else(|| {
396 FerrayError::invalid_value(format!(
397 "cannot convert (count - ddof) = {} to element type",
398 count - ddof
399 ))
400 })?;
401 Ok(m2 / n)
402 }
403
404 pub fn std(&self) -> FerrayResult<T> {
411 Ok(self.var()?.sqrt())
412 }
413
414 pub fn std_ddof(&self, ddof: usize) -> FerrayResult<T> {
425 Ok(self.var_ddof(ddof)?.sqrt())
426 }
427
428 pub fn sum_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
439 let zero = num_traits::zero::<T>();
440 let fill = self.fill_value();
441 reduce_axis(self, axis, fill, |lane| {
442 let mut acc = zero;
443 let mut any = false;
444 for &(v, m) in lane {
445 if !m {
446 acc = acc + v;
447 any = true;
448 }
449 }
450 if any { Some(acc) } else { None }
451 })
452 }
453
454 pub fn mean_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
457 let zero = num_traits::zero::<T>();
458 let fill = self.fill_value();
459 reduce_axis(self, axis, fill, |lane| {
460 let mut acc = zero;
461 let mut count = 0usize;
462 for &(v, m) in lane {
463 if !m {
464 acc = acc + v;
465 count += 1;
466 }
467 }
468 if count == 0 {
469 None
470 } else {
471 T::from(count).map(|n| acc / n)
476 }
477 })
478 }
479
480 pub fn min_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
482 let fill = self.fill_value();
483 reduce_axis(self, axis, fill, |lane| {
484 let mut acc: Option<T> = None;
485 for &(v, m) in lane {
486 if !m {
487 acc = Some(match acc {
488 Some(a) => {
489 if a <= v {
491 a
492 } else if a > v {
493 v
494 } else {
495 a
496 }
497 }
498 None => v,
499 });
500 }
501 }
502 acc
503 })
504 }
505
506 pub fn max_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
508 let fill = self.fill_value();
509 reduce_axis(self, axis, fill, |lane| {
510 let mut acc: Option<T> = None;
511 for &(v, m) in lane {
512 if !m {
513 acc = Some(match acc {
514 Some(a) => {
515 if a >= v {
516 a
517 } else if a < v {
518 v
519 } else {
520 a
521 }
522 }
523 None => v,
524 });
525 }
526 }
527 acc
528 })
529 }
530
531 pub fn var_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
533 self.var_axis_ddof(axis, 0)
534 }
535
536 pub fn var_axis_ddof(&self, axis: usize, ddof: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
542 let zero = num_traits::zero::<T>();
544 let fill = self.fill_value();
545 reduce_axis(self, axis, fill, |lane| {
546 let mut mean = zero;
547 let mut m2 = zero;
548 let mut count = 0usize;
549 for &(v, m) in lane {
550 if m {
551 continue;
552 }
553 count += 1;
554 let n_t = T::from(count)?;
555 let delta = v - mean;
556 mean = mean + delta / n_t;
557 let delta2 = v - mean;
558 m2 = m2 + delta * delta2;
559 }
560 if count <= ddof {
561 return None;
562 }
563 let n_var = T::from(count - ddof)?;
564 Some(m2 / n_var)
565 })
566 }
567
568 pub fn std_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
570 self.std_axis_ddof(axis, 0)
571 }
572
573 pub fn std_axis_ddof(&self, axis: usize, ddof: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
575 let result = self.var_axis_ddof(axis, ddof)?;
576 let fill = self.fill_value();
577 let mask = result.mask().clone();
578 let new_data: Vec<T> = result
579 .data()
580 .iter()
581 .zip(result.mask().iter())
582 .map(|(v, m)| if *m { fill } else { v.sqrt() })
583 .collect();
584 let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(result.shape()), new_data)?;
585 let mut out = MaskedArray::new(data_arr, mask)?;
586 out.set_fill_value(fill);
587 Ok(out)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// Build a 2-D masked array from row-major data and mask buffers.
    fn ma2d(rows: usize, cols: usize, data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix2> {
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap();
        let m = Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// Build a dynamic-rank masked array from row-major data and mask buffers.
    fn ma_dyn(shape: &[usize], data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, IxDyn> {
        let d = Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap();
        let m = Array::<bool, IxDyn>::from_vec(IxDyn::new(shape), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// A 1-D array of `n` ones with every element masked.
    fn all_masked_ma1d(n: usize) -> MaskedArray<f64, Ix1> {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([n]), vec![1.0; n]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([n]), vec![true; n]).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    #[test]
    fn sum_axis_drops_axis() {
        let ma = ma2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![false; 6]);
        let s0 = ma.sum_axis(0).unwrap();
        assert_eq!(s0.shape(), &[3]);
        let d0: Vec<f64> = s0.data().iter().copied().collect();
        assert_eq!(d0, vec![5.0, 7.0, 9.0]);

        let s1 = ma.sum_axis(1).unwrap();
        assert_eq!(s1.shape(), &[2]);
        let d1: Vec<f64> = s1.data().iter().copied().collect();
        assert_eq!(d1, vec![6.0, 15.0]);
    }

    #[test]
    fn sum_axis_middle_axis_of_3d() {
        // Shape [2, 3, 2] holding 0..12, so element (i, j, k) = 6i + 2j + k.
        let data: Vec<f64> = (0..12).map(f64::from).collect();
        let ma = ma_dyn(&[2, 3, 2], data, vec![false; 12]);
        let s = ma.sum_axis(1).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        let d: Vec<f64> = s.data().iter().copied().collect();
        assert_eq!(d, vec![6.0, 9.0, 24.0, 27.0]);
    }

    #[test]
    fn sum_axis_skips_masked() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        let s0 = ma.sum_axis(0).unwrap();
        let d0: Vec<f64> = s0.data().iter().copied().collect();
        assert_eq!(d0, vec![5.0, 5.0, 9.0]);
        let m0: Vec<bool> = s0.mask().iter().copied().collect();
        assert_eq!(m0, vec![false, false, false]);
    }

    #[test]
    fn sum_axis_all_masked_lane_is_masked() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        );
        let s0 = ma.sum_axis(0).unwrap();
        let m0: Vec<bool> = s0.mask().iter().copied().collect();
        assert_eq!(m0, vec![false, true, false]);
    }

    #[test]
    fn mean_axis_skips_masked() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        let m1 = ma.mean_axis(1).unwrap();
        let d: Vec<f64> = m1.data().iter().copied().collect();
        assert_eq!(d, vec![2.0, 5.0]);
    }

    #[test]
    fn mean_axis_all_masked_lane_is_masked() {
        let ma = ma2d(2, 2, vec![1.0, 2.0, 3.0, 4.0], vec![true, true, false, false]);
        let r = ma.mean_axis(1).unwrap();
        let mask: Vec<bool> = r.mask().iter().copied().collect();
        assert_eq!(mask, vec![true, false]);
        let d: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(d[1], 3.5);
    }

    #[test]
    fn min_max_axis() {
        let ma = ma2d(2, 3, vec![3.0, 1.0, 5.0, 2.0, 4.0, 0.0], vec![false; 6]);
        let mn = ma.min_axis(0).unwrap();
        let mx = ma.max_axis(0).unwrap();
        let mn_d: Vec<f64> = mn.data().iter().copied().collect();
        let mx_d: Vec<f64> = mx.data().iter().copied().collect();
        assert_eq!(mn_d, vec![2.0, 1.0, 0.0]);
        assert_eq!(mx_d, vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn min_max_axis_skip_masked() {
        let ma = ma2d(
            2,
            3,
            vec![3.0, 1.0, 5.0, 2.0, 4.0, 0.0],
            vec![true, false, false, false, false, true],
        );
        let mn_d: Vec<f64> = ma.min_axis(0).unwrap().data().iter().copied().collect();
        let mx_d: Vec<f64> = ma.max_axis(0).unwrap().data().iter().copied().collect();
        assert_eq!(mn_d, vec![2.0, 1.0, 5.0]);
        assert_eq!(mx_d, vec![2.0, 4.0, 5.0]);
    }

    #[test]
    fn count_total() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, true],
        );
        assert_eq!(ma.count().unwrap(), 4);
    }

    #[test]
    fn count_axis_basic() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, true],
        );
        let c0 = ma.count_axis(0).unwrap();
        let v: Vec<u64> = c0.iter().copied().collect();
        assert_eq!(v, vec![2u64, 1, 1]);
    }

    #[test]
    fn count_axis_3d_first_and_last_axes() {
        // Shape [2, 3, 2]; every flat index divisible by 3 is masked.
        let data: Vec<f64> = (0..12).map(f64::from).collect();
        let mask: Vec<bool> = (0..12).map(|i| i % 3 == 0).collect();
        let ma = ma_dyn(&[2, 3, 2], data, mask);

        let last: Vec<u64> = ma.count_axis(2).unwrap().iter().copied().collect();
        assert_eq!(last, vec![1u64, 1, 2, 1, 1, 2]);

        let first: Vec<u64> = ma.count_axis(0).unwrap().iter().copied().collect();
        assert_eq!(first, vec![0u64, 2, 2, 0, 2, 2]);
    }

    #[test]
    fn axis_out_of_bounds_errors() {
        let ma = ma2d(2, 3, vec![0.0; 6], vec![false; 6]);
        assert!(ma.sum_axis(2).is_err());
    }

    #[test]
    fn var_std_axis() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let ma = ma2d(2, 5, data, vec![false; 10]);
        let v = ma.var_axis(1).unwrap();
        let s = ma.std_axis(1).unwrap();
        let v_d: Vec<f64> = v.data().iter().copied().collect();
        let s_d: Vec<f64> = s.data().iter().copied().collect();
        for &x in &v_d {
            assert!((x - 2.0).abs() < 1e-12);
        }
        for &x in &s_d {
            assert!((x - 2.0_f64.sqrt()).abs() < 1e-12);
        }
    }

    #[test]
    fn var_std_axis_ddof_mask_lanes_with_too_few_values() {
        // Row 0 has a single unmasked value, so with ddof = 1 it must be masked.
        let ma = ma2d(2, 2, vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);

        let v = ma.var_axis_ddof(1, 1).unwrap();
        let v_m: Vec<bool> = v.mask().iter().copied().collect();
        assert_eq!(v_m, vec![true, false]);
        let v_d: Vec<f64> = v.data().iter().copied().collect();
        assert!((v_d[1] - 0.5).abs() < 1e-12);

        let s = ma.std_axis_ddof(1, 1).unwrap();
        let s_m: Vec<bool> = s.mask().iter().copied().collect();
        assert_eq!(s_m, vec![true, false]);
        let s_d: Vec<f64> = s.data().iter().copied().collect();
        assert!((s_d[1] - 0.5_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn fill_value_default_is_zero() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        assert_eq!(ma.fill_value(), 0.0);
    }

    #[test]
    fn with_fill_value_sets_field() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap().with_fill_value(99.0);
        assert_eq!(ma.fill_value(), 99.0);
    }

    #[test]
    fn filled_default_uses_stored_fill_value() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, true]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap().with_fill_value(-1.0);
        let filled = ma.filled_default().unwrap();
        let v: Vec<f64> = filled.iter().copied().collect();
        assert_eq!(v, vec![1.0, -1.0, 3.0, -1.0]);
    }

    #[test]
    fn arithmetic_uses_fill_value() {
        use crate::masked_add;
        let d_a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        let d_b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let m_b = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap().with_fill_value(-999.0);
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_add(&a, &b).unwrap();
        let r_d: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(r_d, vec![11.0, -999.0, 33.0]);
        assert_eq!(r.fill_value(), -999.0);
    }

    #[test]
    fn masked_add_broadcasts_within_same_rank() {
        use crate::masked_add;
        let d_a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![false; 3]).unwrap();
        let d_b =
            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let m_b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_add(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 4]);
        let r_d: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(
            r_d,
            vec![
                11.0, 21.0, 31.0, 41.0, //
                12.0, 22.0, 32.0, 42.0, //
                13.0, 23.0, 33.0, 43.0,
            ]
        );
        let r_m: Vec<bool> = r.mask().iter().copied().collect();
        assert_eq!(r_m, vec![false; 12]);
    }

    #[test]
    fn masked_sub_broadcasts_with_mask_union() {
        use crate::masked_sub;
        let d_a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![10.0, 20.0, 30.0]).unwrap();
        let m_a = Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![false, true, false]).unwrap();
        let d_b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m_b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_sub(&a, &b).unwrap();
        let r_m: Vec<bool> = r.mask().iter().copied().collect();
        assert_eq!(
            r_m,
            vec![
                false, false, false, false, //
                true, true, true, true, //
                false, false, false, false,
            ]
        );
    }

    #[test]
    fn sum_all_masked_returns_zero() {
        let ma = all_masked_ma1d(4);
        assert_eq!(ma.sum().unwrap(), 0.0);
    }

    #[test]
    fn mean_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.mean().unwrap().is_nan());
    }

    #[test]
    fn var_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.var().unwrap().is_nan());
    }

    #[test]
    fn std_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.std().unwrap().is_nan());
    }

    #[test]
    fn min_max_all_masked_error() {
        let ma = all_masked_ma1d(4);
        assert!(ma.min().is_err());
        assert!(ma.max().is_err());
    }

    #[test]
    fn sum_var_std_all_masked_2d_matches_1d() {
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![9.0; 6]).unwrap();
        let m = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        assert_eq!(ma.sum().unwrap(), 0.0);
        assert!(ma.var().unwrap().is_nan());
        assert!(ma.std().unwrap().is_nan());
    }

    #[test]
    fn masked_add_broadcast_incompatible_errors() {
        use crate::masked_add;
        let d_a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let d_b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m_b = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        assert!(masked_add(&a, &b).is_err());
    }

    #[test]
    fn var_ddof_zero_matches_default_var() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false; 5]).unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let v0 = ma.var().unwrap();
        let v_explicit = ma.var_ddof(0).unwrap();
        assert!((v0 - v_explicit).abs() < 1e-14);
    }

    #[test]
    fn var_ddof_one_is_bessel_corrected() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false; 5]).unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let v0 = ma.var_ddof(0).unwrap();
        let v1 = ma.var_ddof(1).unwrap();
        assert!((v0 - 2.0).abs() < 1e-14, "ddof=0: expected 2.0, got {v0}");
        assert!((v1 - 2.5).abs() < 1e-14, "ddof=1: expected 2.5, got {v1}");
    }

    #[test]
    fn var_ddof_skips_masked_elements() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 99.0, 4.0, 5.0]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, false, true, false, false])
                .unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let v0 = ma.var_ddof(0).unwrap();
        let v1 = ma.var_ddof(1).unwrap();
        assert!((v0 - 2.5).abs() < 1e-14);
        assert!((v1 - 10.0 / 3.0).abs() < 1e-14);
    }

    #[test]
    fn var_ddof_returns_nan_when_count_le_ddof() {
        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let v = ma.var_ddof(1).unwrap();
        assert!(v.is_nan(), "expected NaN, got {v}");
    }

    #[test]
    fn std_ddof_is_sqrt_of_var_ddof() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false; 5]).unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let s1 = ma.std_ddof(1).unwrap();
        let v1 = ma.var_ddof(1).unwrap();
        assert!((s1 - v1.sqrt()).abs() < 1e-14);
    }

    #[test]
    fn var_welford_stable_on_high_offset_data() {
        // A naive E[x^2] - E[x]^2 loses all precision at this offset.
        let offset = 1e9_f64;
        let data: Vec<f64> = (1..=5).map(|i| offset + i as f64).collect();
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false; 5]).unwrap();
        let ma = MaskedArray::new(arr, mask).unwrap();
        let v = ma.var().unwrap();
        assert!(
            (v - 2.0).abs() < 1e-9,
            "var with offset 1e9: got {v}, expected 2.0"
        );
    }

    #[test]
    fn var_axis_ddof_one_per_row() {
        let data =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
                .unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
        let ma = MaskedArray::new(data, mask).unwrap();
        let v = ma.var_axis_ddof(1, 1).unwrap();
        let vs: Vec<f64> = v.data().iter().copied().collect();
        assert_eq!(vs.len(), 2);
        assert!((vs[0] - 1.0).abs() < 1e-12);
        assert!((vs[1] - 100.0).abs() < 1e-12);
    }
}