1use num_traits::Float;
18
19use crate::array::owned::Array;
20use crate::array::view::ArrayView;
21use crate::dimension::{Axis, Dimension, IxDyn};
22use crate::dtype::Element;
23use crate::error::FerrayResult;
24
#[inline]
/// One step of a NaN-propagating min/max reduction.
///
/// `acc` is the running result and `x` is the next element. When either
/// value is unordered with respect to itself (i.e. NaN for floats), that NaN
/// is returned; an already-NaN accumulator wins over a NaN `x`. Otherwise
/// `x` replaces `acc` only when it is strictly smaller (`take_min == true`)
/// or strictly larger (`take_min == false`), so ties keep the accumulator.
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    use std::cmp::Ordering;

    // A value is NaN exactly when it does not compare to itself.
    let is_nan = |v: T| v.partial_cmp(&v).is_none();

    if is_nan(acc) {
        acc
    } else if is_nan(x) {
        x
    } else {
        let wanted = if take_min {
            Ordering::Less
        } else {
            Ordering::Greater
        };
        if x.partial_cmp(&acc) == Some(wanted) {
            x
        } else {
            acc
        }
    }
}
47
48impl<T, D> Array<T, D>
53where
54 T: Element + Copy,
55 D: Dimension,
56{
57 pub fn sum(&self) -> T
69 where
70 T: std::ops::Add<Output = T>,
71 {
72 let mut acc = T::zero();
73 for &x in self.iter() {
74 acc = acc + x;
75 }
76 acc
77 }
78
79 pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
84 where
85 T: std::ops::Add<Output = T>,
86 D::NdarrayDim: ndarray::RemoveAxis,
87 {
88 self.fold_axis(axis, T::zero(), |acc, &x| *acc + x)
89 }
90
91 pub fn prod(&self) -> T
95 where
96 T: std::ops::Mul<Output = T>,
97 {
98 let mut acc = T::one();
99 for &x in self.iter() {
100 acc = acc * x;
101 }
102 acc
103 }
104
105 pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
107 where
108 T: std::ops::Mul<Output = T>,
109 D::NdarrayDim: ndarray::RemoveAxis,
110 {
111 self.fold_axis(axis, T::one(), |acc, &x| *acc * x)
112 }
113}
114
115impl<T, D> Array<T, D>
120where
121 T: Element + Copy + PartialOrd,
122 D: Dimension,
123{
124 pub fn min(&self) -> Option<T> {
130 let mut iter = self.iter().copied();
131 let first = iter.next()?;
132 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
133 }
134
135 pub fn max(&self) -> Option<T> {
139 let mut iter = self.iter().copied();
140 let first = iter.next()?;
141 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
142 }
143
144 pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
150 where
151 D::NdarrayDim: ndarray::RemoveAxis,
152 {
153 let ndim = self.ndim();
160 if axis.index() >= ndim {
161 return Err(crate::error::FerrayError::axis_out_of_bounds(
162 axis.index(),
163 ndim,
164 ));
165 }
166 if self.shape()[axis.index()] == 0 {
167 return Err(crate::error::FerrayError::shape_mismatch(
168 "cannot compute min along empty axis",
169 ));
170 }
171 self.fold_axis_min_max(axis, true)
175 }
176
177 pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
181 where
182 D::NdarrayDim: ndarray::RemoveAxis,
183 {
184 let ndim = self.ndim();
185 if axis.index() >= ndim {
186 return Err(crate::error::FerrayError::axis_out_of_bounds(
187 axis.index(),
188 ndim,
189 ));
190 }
191 if self.shape()[axis.index()] == 0 {
192 return Err(crate::error::FerrayError::shape_mismatch(
193 "cannot compute max along empty axis",
194 ));
195 }
196 self.fold_axis_min_max(axis, false)
197 }
198
199 fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
203 where
204 D::NdarrayDim: ndarray::RemoveAxis,
205 {
206 let nd_axis = ndarray::Axis(axis.index());
207 let lanes = self.inner.lanes(nd_axis);
210 let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
211 for lane in self.inner.lanes(nd_axis) {
212 let mut iter = lane.iter().copied();
213 let first = iter.next().unwrap(); let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
215 out.push(result);
216 }
217
218 let mut out_shape: Vec<usize> = self.shape().to_vec();
220 out_shape.remove(axis.index());
221 Array::from_vec(IxDyn::from(&out_shape[..]), out)
222 }
223}
224
225impl<T, D> Array<T, D>
230where
231 T: Element + Float,
232 D: Dimension,
233{
234 pub fn mean(&self) -> Option<T> {
236 let n = self.size();
237 if n == 0 {
238 return None;
239 }
240 let sum: T = self
241 .iter()
242 .copied()
243 .fold(<T as Element>::zero(), |acc, x| acc + x);
244 Some(sum / T::from(n).unwrap())
245 }
246
247 pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
249 where
250 D::NdarrayDim: ndarray::RemoveAxis,
251 {
252 let ndim = self.ndim();
253 if axis.index() >= ndim {
254 return Err(crate::error::FerrayError::axis_out_of_bounds(
255 axis.index(),
256 ndim,
257 ));
258 }
259 let n = self.shape()[axis.index()];
260 if n == 0 {
261 return Err(crate::error::FerrayError::shape_mismatch(
262 "cannot compute mean along empty axis",
263 ));
264 }
265 let sums = self.sum_axis(axis)?;
266 let n_t = T::from(n).unwrap();
267 Ok(sums.mapv(|x| x / n_t))
268 }
269
270 pub fn var(&self, ddof: usize) -> Option<T> {
274 let n = self.size();
275 if n == 0 || ddof >= n {
276 return None;
277 }
278 let mean = self.mean()?;
279 let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
280 acc + (x - mean) * (x - mean)
281 });
282 Some(sum_sq / T::from(n - ddof).unwrap())
283 }
284
285 pub fn std(&self, ddof: usize) -> Option<T> {
287 self.var(ddof).map(|v| v.sqrt())
288 }
289}
290
291impl<D> Array<bool, D>
296where
297 D: Dimension,
298{
299 pub fn any(&self) -> bool {
301 self.iter().any(|&x| x)
302 }
303
304 pub fn all(&self) -> bool {
306 self.iter().all(|&x| x)
307 }
308}
309
310impl<T, D> ArrayView<'_, T, D>
315where
316 T: Element + Copy,
317 D: Dimension,
318{
319 pub fn sum(&self) -> T
321 where
322 T: std::ops::Add<Output = T>,
323 {
324 let mut acc = T::zero();
325 for &x in self.iter() {
326 acc = acc + x;
327 }
328 acc
329 }
330
331 pub fn prod(&self) -> T
333 where
334 T: std::ops::Mul<Output = T>,
335 {
336 let mut acc = T::one();
337 for &x in self.iter() {
338 acc = acc * x;
339 }
340 acc
341 }
342}
343
344impl<T, D> ArrayView<'_, T, D>
345where
346 T: Element + Copy + PartialOrd,
347 D: Dimension,
348{
349 pub fn min(&self) -> Option<T> {
351 let mut iter = self.iter().copied();
352 let first = iter.next()?;
353 Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
354 }
355
356 pub fn max(&self) -> Option<T> {
358 let mut iter = self.iter().copied();
359 let first = iter.next()?;
360 Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
361 }
362}
363
364impl<T, D> ArrayView<'_, T, D>
365where
366 T: Element + Float,
367 D: Dimension,
368{
369 pub fn mean(&self) -> Option<T> {
371 let n = self.size();
372 if n == 0 {
373 return None;
374 }
375 let sum: T = self
376 .iter()
377 .copied()
378 .fold(<T as Element>::zero(), |acc, x| acc + x);
379 Some(sum / T::from(n).unwrap())
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Build a 1-D f64 array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Build a `rows x cols` f64 array from row-major `data`.
    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // --- sum / prod ---------------------------------------------------------

    #[test]
    fn sum_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.sum(), 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);

        let s1 = a.sum_axis(Axis(1)).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.prod(), 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p0 = a.prod_axis(Axis(0)).unwrap();
        assert_eq!(
            p0.iter().copied().collect::<Vec<_>>(),
            vec![4.0, 10.0, 18.0]
        );

        let p1 = a.prod_axis(Axis(1)).unwrap();
        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
    }

    // --- min / max ----------------------------------------------------------

    #[test]
    fn min_max_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.min(), Some(1.0));
        assert_eq!(a.max(), Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.min(), None);
        assert_eq!(a.max(), None);
    }

    #[test]
    fn min_max_int() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
        assert_eq!(a.min(), Some(-5));
        assert_eq!(a.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        let mn0 = a.min_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
        let mx0 = a.max_axis(Axis(0)).unwrap();
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);

        let mn1 = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
        let mx1 = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
    }

    #[test]
    fn axis_reductions_reject_out_of_bounds_axis() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(a.min_axis(Axis(2)).is_err());
        assert!(a.max_axis(Axis(2)).is_err());
        assert!(a.mean_axis(Axis(2)).is_err());
    }

    #[test]
    fn axis_reductions_reject_empty_axis() {
        // Reducing over a zero-length axis has no identity for min/max/mean.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(a.min_axis(Axis(0)).is_err());
        assert!(a.max_axis(Axis(0)).is_err());
        assert!(a.mean_axis(Axis(0)).is_err());
    }

    #[test]
    fn min_max_axis_with_zero_length_other_axis_is_empty() {
        // Shape (0, 3) reduced over axis 1 (length 3): zero lanes, shape [0].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        let mn = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn.shape(), &[0]);
        let mx = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx.shape(), &[0]);
    }

    #[test]
    fn nan_propagates_in_min_max_axis() {
        let a = arr2(2, 2, vec![1.0, f64::NAN, 3.0, 4.0]);

        let mn: Vec<f64> = a.min_axis(Axis(0)).unwrap().iter().copied().collect();
        assert_eq!(mn[0], 1.0);
        assert!(mn[1].is_nan());

        let mx: Vec<f64> = a.max_axis(Axis(1)).unwrap().iter().copied().collect();
        assert!(mx[0].is_nan());
        assert_eq!(mx[1], 4.0);
    }

    // --- mean / var / std ---------------------------------------------------

    #[test]
    fn mean_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.mean(), Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.mean(), None);
    }

    #[test]
    fn mean_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
        let m1 = a.mean_axis(Axis(1)).unwrap();
        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
    }

    #[test]
    fn var_population() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(a.var(0), Some(2.0));
    }

    #[test]
    fn var_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(a.var(1), Some(2.5));
    }

    #[test]
    fn std_basic() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(0).unwrap();
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let a = arr1(vec![1.0, 2.0]);
        assert_eq!(a.var(2), None);
        assert_eq!(a.var(5), None);
    }

    // --- any / all ----------------------------------------------------------

    #[test]
    fn any_all_bool() {
        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let false_arr =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();

        assert!(true_arr.all());
        assert!(true_arr.any());

        assert!(!mixed.all());
        assert!(mixed.any());

        assert!(!false_arr.all());
        assert!(!false_arr.any());

        // Empty: `all` is vacuously true, `any` is false.
        assert!(empty.all());
        assert!(!empty.any());
    }

    // --- views --------------------------------------------------------------

    #[test]
    fn view_sum_min_max_mean() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let v = a.view();
        assert_eq!(v.sum(), 10.0);
        assert_eq!(v.min(), Some(1.0));
        assert_eq!(v.max(), Some(4.0));
        assert_eq!(v.mean(), Some(2.5));
    }

    #[test]
    fn view_prod() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.view().prod(), 24.0);
    }

    // --- NaN semantics ------------------------------------------------------

    #[test]
    fn nan_propagates_in_min_max() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(a.min().unwrap().is_nan());
        assert!(a.max().unwrap().is_nan());

        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
        assert!(b.min().unwrap().is_nan());
        assert!(b.max().unwrap().is_nan());

        let c = arr1(vec![1.0, 3.0, f64::NAN]);
        assert!(c.min().unwrap().is_nan());
        assert!(c.max().unwrap().is_nan());
    }
}