ndarray/numeric/
impl_numeric.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul, MulAssign, Sub};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17use crate::Slice;
18
19/// # Numerical Methods for Arrays
20impl<A, D> ArrayRef<A, D>
21where D: Dimension
22{
23    /// Return the sum of all elements in the array.
24    ///
25    /// ```
26    /// use ndarray::arr2;
27    ///
28    /// let a = arr2(&[[1., 2.],
29    ///                [3., 4.]]);
30    /// assert_eq!(a.sum(), 10.);
31    /// ```
32    pub fn sum(&self) -> A
33    where A: Clone + Add<Output = A> + num_traits::Zero
34    {
35        if let Some(slc) = self.as_slice_memory_order() {
36            return numeric_util::unrolled_fold(slc, A::zero, A::add);
37        }
38        let mut sum = A::zero();
39        for row in self.rows() {
40            if let Some(slc) = row.as_slice() {
41                sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
42            } else {
43                sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
44            }
45        }
46        sum
47    }
48
49    /// Returns the [arithmetic mean] x̅ of all elements in the array:
50    ///
51    /// ```text
52    ///     1   n
53    /// x̅ = ―   ∑ xᵢ
54    ///     n  i=1
55    /// ```
56    ///
57    /// If the array is empty, `None` is returned.
58    ///
59    /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
60    ///
61    /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
62    pub fn mean(&self) -> Option<A>
63    where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
64    {
65        let n_elements = self.len();
66        if n_elements == 0 {
67            None
68        } else {
69            let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
70            Some(self.sum() / n_elements)
71        }
72    }
73
74    /// Return the product of all elements in the array.
75    ///
76    /// ```
77    /// use ndarray::arr2;
78    ///
79    /// let a = arr2(&[[1., 2.],
80    ///                [3., 4.]]);
81    /// assert_eq!(a.product(), 24.);
82    /// ```
83    pub fn product(&self) -> A
84    where A: Clone + Mul<Output = A> + num_traits::One
85    {
86        if let Some(slc) = self.as_slice_memory_order() {
87            return numeric_util::unrolled_fold(slc, A::one, A::mul);
88        }
89        let mut sum = A::one();
90        for row in self.rows() {
91            if let Some(slc) = row.as_slice() {
92                sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
93            } else {
94                sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
95            }
96        }
97        sum
98    }
99
100    /// Return the cumulative product of elements along a given axis.
101    ///
102    /// ```
103    /// use ndarray::{arr2, Axis};
104    ///
105    /// let a = arr2(&[[1., 2., 3.],
106    ///                [4., 5., 6.]]);
107    ///
108    /// // Cumulative product along rows (axis 0)
109    /// assert_eq!(
110    ///     a.cumprod(Axis(0)),
111    ///     arr2(&[[1., 2., 3.],
112    ///           [4., 10., 18.]])
113    /// );
114    ///
115    /// // Cumulative product along columns (axis 1)
116    /// assert_eq!(
117    ///     a.cumprod(Axis(1)),
118    ///     arr2(&[[1., 2., 6.],
119    ///           [4., 20., 120.]])
120    /// );
121    /// ```
122    ///
123    /// **Panics** if `axis` is out of bounds.
124    #[track_caller]
125    pub fn cumprod(&self, axis: Axis) -> Array<A, D>
126    where
127        A: Clone + Mul<Output = A> + MulAssign,
128        D: Dimension + RemoveAxis,
129    {
130        if axis.0 >= self.ndim() {
131            panic!("axis is out of bounds for array of dimension");
132        }
133
134        let mut result = self.to_owned();
135        result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone());
136        result
137    }
138
139    /// Return variance of elements in the array.
140    ///
141    /// The variance is computed using the [Welford one-pass
142    /// algorithm](https://www.jstor.org/stable/1266577).
143    ///
144    /// The parameter `ddof` specifies the "delta degrees of freedom". For
145    /// example, to calculate the population variance, use `ddof = 0`, or to
146    /// calculate the sample variance, use `ddof = 1`.
147    ///
148    /// The variance is defined as:
149    ///
150    /// ```text
151    ///               1       n
152    /// variance = ――――――――   ∑ (xᵢ - x̅)²
153    ///            n - ddof  i=1
154    /// ```
155    ///
156    /// where
157    ///
158    /// ```text
159    ///     1   n
160    /// x̅ = ―   ∑ xᵢ
161    ///     n  i=1
162    /// ```
163    ///
164    /// and `n` is the length of the array.
165    ///
166    /// **Panics** if `ddof` is less than zero or greater than `n`
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// use ndarray::array;
172    /// use approx::assert_abs_diff_eq;
173    ///
174    /// let a = array![1., -4.32, 1.14, 0.32];
175    /// let var = a.var(1.);
176    /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
177    /// ```
178    #[track_caller]
179    #[cfg(feature = "std")]
180    pub fn var(&self, ddof: A) -> A
181    where A: Float + FromPrimitive
182    {
183        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
184        let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
185        assert!(
186            !(ddof < zero || ddof > n),
187            "`ddof` must not be less than zero or greater than the length of \
188             the axis",
189        );
190        let dof = n - ddof;
191        let mut mean = A::zero();
192        let mut sum_sq = A::zero();
193        let mut i = 0;
194        self.for_each(|&x| {
195            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
196            let delta = x - mean;
197            mean = mean + delta / count;
198            sum_sq = (x - mean).mul_add(delta, sum_sq);
199            i += 1;
200        });
201        sum_sq / dof
202    }
203
204    /// Return standard deviation of elements in the array.
205    ///
206    /// The standard deviation is computed from the variance using
207    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
208    ///
209    /// The parameter `ddof` specifies the "delta degrees of freedom". For
210    /// example, to calculate the population standard deviation, use `ddof = 0`,
211    /// or to calculate the sample standard deviation, use `ddof = 1`.
212    ///
213    /// The standard deviation is defined as:
214    ///
215    /// ```text
216    ///               ⎛    1       n          ⎞
217    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
218    ///               ⎝ n - ddof  i=1         ⎠
219    /// ```
220    ///
221    /// where
222    ///
223    /// ```text
224    ///     1   n
225    /// x̅ = ―   ∑ xᵢ
226    ///     n  i=1
227    /// ```
228    ///
229    /// and `n` is the length of the array.
230    ///
231    /// **Panics** if `ddof` is less than zero or greater than `n`
232    ///
233    /// # Example
234    ///
235    /// ```
236    /// use ndarray::array;
237    /// use approx::assert_abs_diff_eq;
238    ///
239    /// let a = array![1., -4.32, 1.14, 0.32];
240    /// let stddev = a.std(1.);
241    /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
242    /// ```
243    #[track_caller]
244    #[cfg(feature = "std")]
245    pub fn std(&self, ddof: A) -> A
246    where A: Float + FromPrimitive
247    {
248        self.var(ddof).sqrt()
249    }
250
251    /// Return sum along `axis`.
252    ///
253    /// ```
254    /// use ndarray::{aview0, aview1, arr2, Axis};
255    ///
256    /// let a = arr2(&[[1., 2., 3.],
257    ///                [4., 5., 6.]]);
258    /// assert!(
259    ///     a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
260    ///     a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
261    ///
262    ///     a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
263    /// );
264    /// ```
265    ///
266    /// **Panics** if `axis` is out of bounds.
267    #[track_caller]
268    pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
269    where
270        A: Clone + Zero + Add<Output = A>,
271        D: RemoveAxis,
272    {
273        let min_stride_axis = self._dim().min_stride_axis(self._strides());
274        if axis == min_stride_axis {
275            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
276        } else {
277            let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
278            for subview in self.axis_iter(axis) {
279                res = res + &subview;
280            }
281            res
282        }
283    }
284
285    /// Return product along `axis`.
286    ///
287    /// The product of an empty array is 1.
288    ///
289    /// ```
290    /// use ndarray::{aview0, aview1, arr2, Axis};
291    ///
292    /// let a = arr2(&[[1., 2., 3.],
293    ///                [4., 5., 6.]]);
294    ///
295    /// assert!(
296    ///     a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
297    ///     a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
298    ///
299    ///     a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
300    /// );
301    /// ```
302    ///
303    /// **Panics** if `axis` is out of bounds.
304    #[track_caller]
305    pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
306    where
307        A: Clone + One + Mul<Output = A>,
308        D: RemoveAxis,
309    {
310        let min_stride_axis = self._dim().min_stride_axis(self._strides());
311        if axis == min_stride_axis {
312            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
313        } else {
314            let mut res = Array::ones(self.raw_dim().remove_axis(axis));
315            for subview in self.axis_iter(axis) {
316                res = res * &subview;
317            }
318            res
319        }
320    }
321
322    /// Return mean along `axis`.
323    ///
324    /// Return `None` if the length of the axis is zero.
325    ///
326    /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
327    /// fails for the axis length.
328    ///
329    /// ```
330    /// use ndarray::{aview0, aview1, arr2, Axis};
331    ///
332    /// let a = arr2(&[[1., 2., 3.],
333    ///                [4., 5., 6.]]);
334    /// assert!(
335    ///     a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
336    ///     a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
337    ///
338    ///     a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
339    /// );
340    /// ```
341    #[track_caller]
342    pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
343    where
344        A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
345        D: RemoveAxis,
346    {
347        let axis_length = self.len_of(axis);
348        if axis_length == 0 {
349            None
350        } else {
351            let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
352            let sum = self.sum_axis(axis);
353            Some(sum / aview0(&axis_length))
354        }
355    }
356
357    /// Return variance along `axis`.
358    ///
359    /// The variance is computed using the [Welford one-pass
360    /// algorithm](https://www.jstor.org/stable/1266577).
361    ///
362    /// The parameter `ddof` specifies the "delta degrees of freedom". For
363    /// example, to calculate the population variance, use `ddof = 0`, or to
364    /// calculate the sample variance, use `ddof = 1`.
365    ///
366    /// The variance is defined as:
367    ///
368    /// ```text
369    ///               1       n
370    /// variance = ――――――――   ∑ (xᵢ - x̅)²
371    ///            n - ddof  i=1
372    /// ```
373    ///
374    /// where
375    ///
376    /// ```text
377    ///     1   n
378    /// x̅ = ―   ∑ xᵢ
379    ///     n  i=1
380    /// ```
381    ///
382    /// and `n` is the length of the axis.
383    ///
384    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
385    /// is out of bounds, or if `A::from_usize()` fails for any any of the
386    /// numbers in the range `0..=n`.
387    ///
388    /// # Example
389    ///
390    /// ```
391    /// use ndarray::{aview1, arr2, Axis};
392    ///
393    /// let a = arr2(&[[1., 2.],
394    ///                [3., 4.],
395    ///                [5., 6.]]);
396    /// let var = a.var_axis(Axis(0), 1.);
397    /// assert_eq!(var, aview1(&[4., 4.]));
398    /// ```
399    #[track_caller]
400    #[cfg(feature = "std")]
401    pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
402    where
403        A: Float + FromPrimitive,
404        D: RemoveAxis,
405    {
406        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
407        let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
408        assert!(
409            !(ddof < zero || ddof > n),
410            "`ddof` must not be less than zero or greater than the length of \
411             the axis",
412        );
413        let dof = n - ddof;
414        let mut mean = Array::<A, _>::zeros(self._dim().remove_axis(axis));
415        let mut sum_sq = Array::<A, _>::zeros(self._dim().remove_axis(axis));
416        for (i, subview) in self.axis_iter(axis).enumerate() {
417            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
418            azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
419                let delta = x - *mean;
420                *mean = *mean + delta / count;
421                *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
422            });
423        }
424        sum_sq.mapv_into(|s| s / dof)
425    }
426
427    /// Return standard deviation along `axis`.
428    ///
429    /// The standard deviation is computed from the variance using
430    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
431    ///
432    /// The parameter `ddof` specifies the "delta degrees of freedom". For
433    /// example, to calculate the population standard deviation, use `ddof = 0`,
434    /// or to calculate the sample standard deviation, use `ddof = 1`.
435    ///
436    /// The standard deviation is defined as:
437    ///
438    /// ```text
439    ///               ⎛    1       n          ⎞
440    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
441    ///               ⎝ n - ddof  i=1         ⎠
442    /// ```
443    ///
444    /// where
445    ///
446    /// ```text
447    ///     1   n
448    /// x̅ = ―   ∑ xᵢ
449    ///     n  i=1
450    /// ```
451    ///
452    /// and `n` is the length of the axis.
453    ///
454    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
455    /// is out of bounds, or if `A::from_usize()` fails for any any of the
456    /// numbers in the range `0..=n`.
457    ///
458    /// # Example
459    ///
460    /// ```
461    /// use ndarray::{aview1, arr2, Axis};
462    ///
463    /// let a = arr2(&[[1., 2.],
464    ///                [3., 4.],
465    ///                [5., 6.]]);
466    /// let stddev = a.std_axis(Axis(0), 1.);
467    /// assert_eq!(stddev, aview1(&[2., 2.]));
468    /// ```
469    #[track_caller]
470    #[cfg(feature = "std")]
471    pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
472    where
473        A: Float + FromPrimitive,
474        D: RemoveAxis,
475    {
476        self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
477    }
478
479    /// Calculates the (forward) finite differences of order `n`, along the `axis`.
480    /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
481    ///
482    /// For `n>=2`, the process is iterated:
483    /// ```
484    /// use ndarray::{array, Axis};
485    /// let arr = array![1.0, 2.0, 5.0];
486    /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
487    /// ```
488    /// **Panics** if `axis` is out of bounds
489    ///
490    /// **Panics** if `n` is too big / the array is to short:
491    /// ```should_panic
492    /// use ndarray::{array, Axis};
493    /// array![1.0, 2.0, 3.0].diff(10, Axis(0));
494    /// ```
495    pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
496    where A: Sub<A, Output = A> + Zero + Clone
497    {
498        if n == 0 {
499            return self.to_owned();
500        }
501        assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
502        assert!(
503            n < self.shape()[axis.0],
504            "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
505            n + 1,
506            self.shape()[axis.0]
507        );
508
509        let mut inp = self.to_owned();
510        let mut out = Array::zeros({
511            let mut inp_dim = self.raw_dim();
512            // inp_dim[axis.0] >= 1 as per the 2nd assertion.
513            inp_dim[axis.0] -= 1;
514            inp_dim
515        });
516        for _ in 0..n {
517            let head = inp.slice_axis(axis, Slice::from(..-1));
518            let tail = inp.slice_axis(axis, Slice::from(1..));
519
520            azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());
521
522            // feed the output as the input to the next iteration
523            std::mem::swap(&mut inp, &mut out);
524
525            // adjust the new output array width along `axis`.
526            // Current situation: width of `inp`: k, `out`: k+1
527            // needed width:               `inp`: k, `out`: k-1
528            // slice is possible, since k >= 1.
529            out.slice_axis_inplace(axis, Slice::from(..-2));
530        }
531        inp
532    }
533}