1use ferray_core::error::{FerrayError, FerrayResult};
5use ferray_core::{Array, Dimension, Element, IxDyn};
6use num_traits::Float;
7
8use super::{borrow_data, make_result, output_shape, reduce_axis_general, validate_axis};
9
10fn lane_nansum<T: Float>(lane: &[T]) -> T {
16 let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
17 crate::parallel::pairwise_sum(&non_nan, T::zero())
18}
19
20fn lane_nanprod<T: Float>(lane: &[T]) -> T {
22 lane.iter()
23 .copied()
24 .filter(|x| !x.is_nan())
25 .fold(T::one(), |a, b| a * b)
26}
27
28fn lane_nanmean<T: Float>(lane: &[T]) -> T {
30 let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
31 if non_nan.is_empty() {
32 T::nan()
33 } else {
34 crate::parallel::pairwise_sum(&non_nan, T::zero()) / T::from(non_nan.len()).unwrap()
35 }
36}
37
38fn lane_nanvar<T: Float>(lane: &[T], ddof: usize) -> T {
40 let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
41 let count = non_nan.len();
42 if count <= ddof {
43 return T::nan();
44 }
45 let mean = crate::parallel::pairwise_sum(&non_nan, T::zero()) / T::from(count).unwrap();
46 let sq_diffs: Vec<T> = non_nan
47 .iter()
48 .map(|&x| {
49 let d = x - mean;
50 d * d
51 })
52 .collect();
53 crate::parallel::pairwise_sum(&sq_diffs, T::zero()) / T::from(count - ddof).unwrap()
54}
55
56fn lane_nanmin<T: Float>(lane: &[T]) -> T {
58 lane.iter()
59 .copied()
60 .filter(|x| !x.is_nan())
61 .reduce(|a, b| if a <= b { a } else { b })
62 .unwrap_or_else(T::nan)
63}
64
65fn lane_nanmax<T: Float>(lane: &[T]) -> T {
67 lane.iter()
68 .copied()
69 .filter(|x| !x.is_nan())
70 .reduce(|a, b| if a >= b { a } else { b })
71 .unwrap_or_else(T::nan)
72}
73
74pub fn nansum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
82where
83 T: Element + Float,
84 D: Dimension,
85{
86 let data = borrow_data(a);
87 match axis {
88 None => {
89 let total = lane_nansum(&data);
90 make_result(&[], vec![total])
91 }
92 Some(ax) => {
93 validate_axis(ax, a.ndim())?;
94 let shape = a.shape();
95 let out_s = output_shape(shape, ax);
96 let result = reduce_axis_general(&data, shape, ax, lane_nansum);
97 make_result(&out_s, result)
98 }
99 }
100}
101
102pub fn nanprod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
106where
107 T: Element + Float,
108 D: Dimension,
109{
110 let data = borrow_data(a);
111 match axis {
112 None => {
113 let total = lane_nanprod(&data);
114 make_result(&[], vec![total])
115 }
116 Some(ax) => {
117 validate_axis(ax, a.ndim())?;
118 let shape = a.shape();
119 let out_s = output_shape(shape, ax);
120 let result = reduce_axis_general(&data, shape, ax, lane_nanprod);
121 make_result(&out_s, result)
122 }
123 }
124}
125
126pub fn nanmean<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
130where
131 T: Element + Float,
132 D: Dimension,
133{
134 let data = borrow_data(a);
135 match axis {
136 None => {
137 let m = lane_nanmean(&data);
138 make_result(&[], vec![m])
139 }
140 Some(ax) => {
141 validate_axis(ax, a.ndim())?;
142 let shape = a.shape();
143 let out_s = output_shape(shape, ax);
144 let result = reduce_axis_general(&data, shape, ax, lane_nanmean);
145 make_result(&out_s, result)
146 }
147 }
148}
149
150pub fn nanvar<T, D>(
154 a: &Array<T, D>,
155 axis: Option<usize>,
156 ddof: usize,
157) -> FerrayResult<Array<T, IxDyn>>
158where
159 T: Element + Float,
160 D: Dimension,
161{
162 let data = borrow_data(a);
163 match axis {
164 None => {
165 let v = lane_nanvar(&data, ddof);
166 make_result(&[], vec![v])
167 }
168 Some(ax) => {
169 validate_axis(ax, a.ndim())?;
170 let shape = a.shape();
171 let out_s = output_shape(shape, ax);
172 let result = reduce_axis_general(&data, shape, ax, |lane| lane_nanvar(lane, ddof));
173 make_result(&out_s, result)
174 }
175 }
176}
177
178pub fn nanstd<T, D>(
182 a: &Array<T, D>,
183 axis: Option<usize>,
184 ddof: usize,
185) -> FerrayResult<Array<T, IxDyn>>
186where
187 T: Element + Float,
188 D: Dimension,
189{
190 let v = nanvar(a, axis, ddof)?;
191 let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
192 make_result(v.shape(), data)
193}
194
195pub fn nanmin<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
199where
200 T: Element + Float,
201 D: Dimension,
202{
203 if a.is_empty() {
204 return Err(FerrayError::invalid_value(
205 "cannot compute nanmin of empty array",
206 ));
207 }
208 let data = borrow_data(a);
209 match axis {
210 None => {
211 let m = lane_nanmin(&data);
212 make_result(&[], vec![m])
213 }
214 Some(ax) => {
215 validate_axis(ax, a.ndim())?;
216 let shape = a.shape();
217 let out_s = output_shape(shape, ax);
218 let result = reduce_axis_general(&data, shape, ax, lane_nanmin);
219 make_result(&out_s, result)
220 }
221 }
222}
223
224pub fn nanmax<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
228where
229 T: Element + Float,
230 D: Dimension,
231{
232 if a.is_empty() {
233 return Err(FerrayError::invalid_value(
234 "cannot compute nanmax of empty array",
235 ));
236 }
237 let data = borrow_data(a);
238 match axis {
239 None => {
240 let m = lane_nanmax(&data);
241 make_result(&[], vec![m])
242 }
243 Some(ax) => {
244 validate_axis(ax, a.ndim())?;
245 let shape = a.shape();
246 let out_s = output_shape(shape, ax);
247 let result = reduce_axis_general(&data, shape, ax, lane_nanmax);
248 make_result(&out_s, result)
249 }
250 }
251}
252
/// Cumulative sum along `axis`, with NaN entries contributing zero
/// (e.g. `[1, NaN, 3]` -> `[1, 1, 4]`).
///
/// Thin wrapper that forwards directly to `ferray_ufunc::nancumsum`. Unlike the
/// reductions in this module, the result keeps the input's dimension type `D`.
///
/// NOTE(review): the exact meaning of `axis == None` (presumably a flattened
/// cumulative sum, as in NumPy) and the error conditions are defined by
/// `ferray_ufunc::nancumsum` — confirm there.
pub fn nancumsum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    ferray_ufunc::nancumsum(a, axis)
}
263
/// Cumulative product along `axis`, with NaN entries contributing one
/// (e.g. `[2, NaN, 3]` -> `[2, 2, 6]`).
///
/// Thin wrapper that forwards directly to `ferray_ufunc::nancumprod`. Unlike the
/// reductions in this module, the result keeps the input's dimension type `D`.
///
/// NOTE(review): the exact meaning of `axis == None` (presumably a flattened
/// cumulative product, as in NumPy) and the error conditions are defined by
/// `ferray_ufunc::nancumprod` — confirm there.
pub fn nancumprod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    ferray_ufunc::nancumprod(a, axis)
}
274
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix1;

    /// Builds a 1-D `f64` array holding `values`.
    fn arr1(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Returns the first (for a full reduction, the only) element of `result`.
    fn first(result: &Array<f64, IxDyn>) -> f64 {
        *result.iter().next().unwrap()
    }

    /// Asserts `actual` is within 1e-12 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_nanmean_basic() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert_close(first(&nanmean(&a, None).unwrap()), 2.0);
    }

    #[test]
    fn test_nanmean_all_nan() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        assert!(first(&nanmean(&a, None).unwrap()).is_nan());
    }

    #[test]
    fn test_nansum_basic() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert_close(first(&nansum(&a, None).unwrap()), 4.0);
    }

    #[test]
    fn test_nansum_all_nan() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        assert_close(first(&nansum(&a, None).unwrap()), 0.0);
    }

    #[test]
    fn test_nanmin_nanmax() {
        let a = arr1(vec![3.0, f64::NAN, 1.0, 4.0]);
        assert_close(first(&nanmin(&a, None).unwrap()), 1.0);
        assert_close(first(&nanmax(&a, None).unwrap()), 4.0);
    }

    #[test]
    fn test_nanmin_nanmax_all_nan() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        assert!(first(&nanmin(&a, None).unwrap()).is_nan());
        assert!(first(&nanmax(&a, None).unwrap()).is_nan());
    }

    #[test]
    fn test_nanmin_nanmax_empty_is_error() {
        let a = arr1(Vec::new());
        assert!(nanmin(&a, None).is_err());
        assert!(nanmax(&a, None).is_err());
        // Reducing along a zero-length axis has no defined result either.
        assert!(nanmin(&a, Some(0)).is_err());
        assert!(nanmax(&a, Some(0)).is_err());
    }

    #[test]
    fn test_nanvar() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        // mean = 3, squared deviations = 4 + 0 + 4 = 8, divided by n = 3
        assert_close(first(&nanvar(&a, None, 0).unwrap()), 8.0 / 3.0);
    }

    #[test]
    fn test_nanvar_ddof() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        assert_close(first(&nanvar(&a, None, 1).unwrap()), 4.0);
        // Three non-NaN values with ddof = 3 leaves no degrees of freedom.
        assert!(first(&nanvar(&a, None, 3).unwrap()).is_nan());
    }

    #[test]
    fn test_nanstd() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        assert_close(first(&nanstd(&a, None, 0).unwrap()), (8.0_f64 / 3.0).sqrt());
    }

    #[test]
    fn test_nanprod() {
        let a = arr1(vec![2.0, f64::NAN, 3.0]);
        assert_close(first(&nanprod(&a, None).unwrap()), 6.0);
    }

    #[test]
    fn test_nanprod_all_nan_is_one() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        assert_close(first(&nanprod(&a, None).unwrap()), 1.0);
    }

    #[test]
    fn test_axis_zero_matches_full_reduction_on_1d() {
        let a = arr1(vec![2.0, f64::NAN, 3.0, 7.0]);
        assert_close(
            first(&nansum(&a, Some(0)).unwrap()),
            first(&nansum(&a, None).unwrap()),
        );
        assert_close(
            first(&nanprod(&a, Some(0)).unwrap()),
            first(&nanprod(&a, None).unwrap()),
        );
        assert_close(
            first(&nanmean(&a, Some(0)).unwrap()),
            first(&nanmean(&a, None).unwrap()),
        );
        assert_close(
            first(&nanvar(&a, Some(0), 1).unwrap()),
            first(&nanvar(&a, None, 1).unwrap()),
        );
        assert_close(
            first(&nanmin(&a, Some(0)).unwrap()),
            first(&nanmin(&a, None).unwrap()),
        );
        assert_close(
            first(&nanmax(&a, Some(0)).unwrap()),
            first(&nanmax(&a, None).unwrap()),
        );
    }

    #[test]
    fn test_invalid_axis_is_error() {
        let a = arr1(vec![1.0, 2.0]);
        assert!(nansum(&a, Some(1)).is_err());
        assert!(nanmean(&a, Some(1)).is_err());
        assert!(nanmin(&a, Some(1)).is_err());
    }

    #[test]
    fn test_nancumsum() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        let data: Vec<f64> = nancumsum(&a, None).unwrap().iter().copied().collect();
        assert_close(data[0], 1.0);
        assert_close(data[1], 1.0);
        assert_close(data[2], 4.0);
    }

    #[test]
    fn test_nancumprod() {
        let a = arr1(vec![2.0, f64::NAN, 3.0]);
        let data: Vec<f64> = nancumprod(&a, None).unwrap().iter().copied().collect();
        assert_close(data[0], 2.0);
        assert_close(data[1], 2.0);
        assert_close(data[2], 6.0);
    }
}