ndarray/numeric/impl_numeric.rs
1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul, MulAssign, Sub};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17use crate::Slice;
18
19/// # Numerical Methods for Arrays
20impl<A, D> ArrayRef<A, D>
21where D: Dimension
22{
23 /// Return the sum of all elements in the array.
24 ///
25 /// ```
26 /// use ndarray::arr2;
27 ///
28 /// let a = arr2(&[[1., 2.],
29 /// [3., 4.]]);
30 /// assert_eq!(a.sum(), 10.);
31 /// ```
32 pub fn sum(&self) -> A
33 where A: Clone + Add<Output = A> + num_traits::Zero
34 {
35 if let Some(slc) = self.as_slice_memory_order() {
36 return numeric_util::unrolled_fold(slc, A::zero, A::add);
37 }
38 let mut sum = A::zero();
39 for row in self.rows() {
40 if let Some(slc) = row.as_slice() {
41 sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
42 } else {
43 sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
44 }
45 }
46 sum
47 }
48
49 /// Returns the [arithmetic mean] x̅ of all elements in the array:
50 ///
51 /// ```text
52 /// 1 n
53 /// x̅ = ― ∑ xᵢ
54 /// n i=1
55 /// ```
56 ///
57 /// If the array is empty, `None` is returned.
58 ///
59 /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
60 ///
61 /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
62 pub fn mean(&self) -> Option<A>
63 where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
64 {
65 let n_elements = self.len();
66 if n_elements == 0 {
67 None
68 } else {
69 let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
70 Some(self.sum() / n_elements)
71 }
72 }
73
74 /// Return the product of all elements in the array.
75 ///
76 /// ```
77 /// use ndarray::arr2;
78 ///
79 /// let a = arr2(&[[1., 2.],
80 /// [3., 4.]]);
81 /// assert_eq!(a.product(), 24.);
82 /// ```
83 pub fn product(&self) -> A
84 where A: Clone + Mul<Output = A> + num_traits::One
85 {
86 if let Some(slc) = self.as_slice_memory_order() {
87 return numeric_util::unrolled_fold(slc, A::one, A::mul);
88 }
89 let mut sum = A::one();
90 for row in self.rows() {
91 if let Some(slc) = row.as_slice() {
92 sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
93 } else {
94 sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
95 }
96 }
97 sum
98 }
99
100 /// Return the cumulative product of elements along a given axis.
101 ///
102 /// ```
103 /// use ndarray::{arr2, Axis};
104 ///
105 /// let a = arr2(&[[1., 2., 3.],
106 /// [4., 5., 6.]]);
107 ///
108 /// // Cumulative product along rows (axis 0)
109 /// assert_eq!(
110 /// a.cumprod(Axis(0)),
111 /// arr2(&[[1., 2., 3.],
112 /// [4., 10., 18.]])
113 /// );
114 ///
115 /// // Cumulative product along columns (axis 1)
116 /// assert_eq!(
117 /// a.cumprod(Axis(1)),
118 /// arr2(&[[1., 2., 6.],
119 /// [4., 20., 120.]])
120 /// );
121 /// ```
122 ///
123 /// **Panics** if `axis` is out of bounds.
124 #[track_caller]
125 pub fn cumprod(&self, axis: Axis) -> Array<A, D>
126 where
127 A: Clone + Mul<Output = A> + MulAssign,
128 D: Dimension + RemoveAxis,
129 {
130 if axis.0 >= self.ndim() {
131 panic!("axis is out of bounds for array of dimension");
132 }
133
134 let mut result = self.to_owned();
135 result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone());
136 result
137 }
138
139 /// Return variance of elements in the array.
140 ///
141 /// The variance is computed using the [Welford one-pass
142 /// algorithm](https://www.jstor.org/stable/1266577).
143 ///
144 /// The parameter `ddof` specifies the "delta degrees of freedom". For
145 /// example, to calculate the population variance, use `ddof = 0`, or to
146 /// calculate the sample variance, use `ddof = 1`.
147 ///
148 /// The variance is defined as:
149 ///
150 /// ```text
151 /// 1 n
152 /// variance = ―――――――― ∑ (xᵢ - x̅)²
153 /// n - ddof i=1
154 /// ```
155 ///
156 /// where
157 ///
158 /// ```text
159 /// 1 n
160 /// x̅ = ― ∑ xᵢ
161 /// n i=1
162 /// ```
163 ///
164 /// and `n` is the length of the array.
165 ///
166 /// **Panics** if `ddof` is less than zero or greater than `n`
167 ///
168 /// # Example
169 ///
170 /// ```
171 /// use ndarray::array;
172 /// use approx::assert_abs_diff_eq;
173 ///
174 /// let a = array![1., -4.32, 1.14, 0.32];
175 /// let var = a.var(1.);
176 /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
177 /// ```
178 #[track_caller]
179 #[cfg(feature = "std")]
180 pub fn var(&self, ddof: A) -> A
181 where A: Float + FromPrimitive
182 {
183 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
184 let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
185 assert!(
186 !(ddof < zero || ddof > n),
187 "`ddof` must not be less than zero or greater than the length of \
188 the axis",
189 );
190 let dof = n - ddof;
191 let mut mean = A::zero();
192 let mut sum_sq = A::zero();
193 let mut i = 0;
194 self.for_each(|&x| {
195 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
196 let delta = x - mean;
197 mean = mean + delta / count;
198 sum_sq = (x - mean).mul_add(delta, sum_sq);
199 i += 1;
200 });
201 sum_sq / dof
202 }
203
204 /// Return standard deviation of elements in the array.
205 ///
206 /// The standard deviation is computed from the variance using
207 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
208 ///
209 /// The parameter `ddof` specifies the "delta degrees of freedom". For
210 /// example, to calculate the population standard deviation, use `ddof = 0`,
211 /// or to calculate the sample standard deviation, use `ddof = 1`.
212 ///
213 /// The standard deviation is defined as:
214 ///
215 /// ```text
216 /// ⎛ 1 n ⎞
217 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
218 /// ⎝ n - ddof i=1 ⎠
219 /// ```
220 ///
221 /// where
222 ///
223 /// ```text
224 /// 1 n
225 /// x̅ = ― ∑ xᵢ
226 /// n i=1
227 /// ```
228 ///
229 /// and `n` is the length of the array.
230 ///
231 /// **Panics** if `ddof` is less than zero or greater than `n`
232 ///
233 /// # Example
234 ///
235 /// ```
236 /// use ndarray::array;
237 /// use approx::assert_abs_diff_eq;
238 ///
239 /// let a = array![1., -4.32, 1.14, 0.32];
240 /// let stddev = a.std(1.);
241 /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
242 /// ```
243 #[track_caller]
244 #[cfg(feature = "std")]
245 pub fn std(&self, ddof: A) -> A
246 where A: Float + FromPrimitive
247 {
248 self.var(ddof).sqrt()
249 }
250
251 /// Return sum along `axis`.
252 ///
253 /// ```
254 /// use ndarray::{aview0, aview1, arr2, Axis};
255 ///
256 /// let a = arr2(&[[1., 2., 3.],
257 /// [4., 5., 6.]]);
258 /// assert!(
259 /// a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
260 /// a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
261 ///
262 /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
263 /// );
264 /// ```
265 ///
266 /// **Panics** if `axis` is out of bounds.
267 #[track_caller]
268 pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
269 where
270 A: Clone + Zero + Add<Output = A>,
271 D: RemoveAxis,
272 {
273 let min_stride_axis = self._dim().min_stride_axis(self._strides());
274 if axis == min_stride_axis {
275 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
276 } else {
277 let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
278 for subview in self.axis_iter(axis) {
279 res = res + &subview;
280 }
281 res
282 }
283 }
284
285 /// Return product along `axis`.
286 ///
287 /// The product of an empty array is 1.
288 ///
289 /// ```
290 /// use ndarray::{aview0, aview1, arr2, Axis};
291 ///
292 /// let a = arr2(&[[1., 2., 3.],
293 /// [4., 5., 6.]]);
294 ///
295 /// assert!(
296 /// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
297 /// a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
298 ///
299 /// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
300 /// );
301 /// ```
302 ///
303 /// **Panics** if `axis` is out of bounds.
304 #[track_caller]
305 pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
306 where
307 A: Clone + One + Mul<Output = A>,
308 D: RemoveAxis,
309 {
310 let min_stride_axis = self._dim().min_stride_axis(self._strides());
311 if axis == min_stride_axis {
312 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
313 } else {
314 let mut res = Array::ones(self.raw_dim().remove_axis(axis));
315 for subview in self.axis_iter(axis) {
316 res = res * &subview;
317 }
318 res
319 }
320 }
321
322 /// Return mean along `axis`.
323 ///
324 /// Return `None` if the length of the axis is zero.
325 ///
326 /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
327 /// fails for the axis length.
328 ///
329 /// ```
330 /// use ndarray::{aview0, aview1, arr2, Axis};
331 ///
332 /// let a = arr2(&[[1., 2., 3.],
333 /// [4., 5., 6.]]);
334 /// assert!(
335 /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
336 /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
337 ///
338 /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
339 /// );
340 /// ```
341 #[track_caller]
342 pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
343 where
344 A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
345 D: RemoveAxis,
346 {
347 let axis_length = self.len_of(axis);
348 if axis_length == 0 {
349 None
350 } else {
351 let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
352 let sum = self.sum_axis(axis);
353 Some(sum / aview0(&axis_length))
354 }
355 }
356
357 /// Return variance along `axis`.
358 ///
359 /// The variance is computed using the [Welford one-pass
360 /// algorithm](https://www.jstor.org/stable/1266577).
361 ///
362 /// The parameter `ddof` specifies the "delta degrees of freedom". For
363 /// example, to calculate the population variance, use `ddof = 0`, or to
364 /// calculate the sample variance, use `ddof = 1`.
365 ///
366 /// The variance is defined as:
367 ///
368 /// ```text
369 /// 1 n
370 /// variance = ―――――――― ∑ (xᵢ - x̅)²
371 /// n - ddof i=1
372 /// ```
373 ///
374 /// where
375 ///
376 /// ```text
377 /// 1 n
378 /// x̅ = ― ∑ xᵢ
379 /// n i=1
380 /// ```
381 ///
382 /// and `n` is the length of the axis.
383 ///
384 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
385 /// is out of bounds, or if `A::from_usize()` fails for any any of the
386 /// numbers in the range `0..=n`.
387 ///
388 /// # Example
389 ///
390 /// ```
391 /// use ndarray::{aview1, arr2, Axis};
392 ///
393 /// let a = arr2(&[[1., 2.],
394 /// [3., 4.],
395 /// [5., 6.]]);
396 /// let var = a.var_axis(Axis(0), 1.);
397 /// assert_eq!(var, aview1(&[4., 4.]));
398 /// ```
399 #[track_caller]
400 #[cfg(feature = "std")]
401 pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
402 where
403 A: Float + FromPrimitive,
404 D: RemoveAxis,
405 {
406 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
407 let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
408 assert!(
409 !(ddof < zero || ddof > n),
410 "`ddof` must not be less than zero or greater than the length of \
411 the axis",
412 );
413 let dof = n - ddof;
414 let mut mean = Array::<A, _>::zeros(self._dim().remove_axis(axis));
415 let mut sum_sq = Array::<A, _>::zeros(self._dim().remove_axis(axis));
416 for (i, subview) in self.axis_iter(axis).enumerate() {
417 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
418 azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
419 let delta = x - *mean;
420 *mean = *mean + delta / count;
421 *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
422 });
423 }
424 sum_sq.mapv_into(|s| s / dof)
425 }
426
427 /// Return standard deviation along `axis`.
428 ///
429 /// The standard deviation is computed from the variance using
430 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
431 ///
432 /// The parameter `ddof` specifies the "delta degrees of freedom". For
433 /// example, to calculate the population standard deviation, use `ddof = 0`,
434 /// or to calculate the sample standard deviation, use `ddof = 1`.
435 ///
436 /// The standard deviation is defined as:
437 ///
438 /// ```text
439 /// ⎛ 1 n ⎞
440 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
441 /// ⎝ n - ddof i=1 ⎠
442 /// ```
443 ///
444 /// where
445 ///
446 /// ```text
447 /// 1 n
448 /// x̅ = ― ∑ xᵢ
449 /// n i=1
450 /// ```
451 ///
452 /// and `n` is the length of the axis.
453 ///
454 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
455 /// is out of bounds, or if `A::from_usize()` fails for any any of the
456 /// numbers in the range `0..=n`.
457 ///
458 /// # Example
459 ///
460 /// ```
461 /// use ndarray::{aview1, arr2, Axis};
462 ///
463 /// let a = arr2(&[[1., 2.],
464 /// [3., 4.],
465 /// [5., 6.]]);
466 /// let stddev = a.std_axis(Axis(0), 1.);
467 /// assert_eq!(stddev, aview1(&[2., 2.]));
468 /// ```
469 #[track_caller]
470 #[cfg(feature = "std")]
471 pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
472 where
473 A: Float + FromPrimitive,
474 D: RemoveAxis,
475 {
476 self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
477 }
478
479 /// Calculates the (forward) finite differences of order `n`, along the `axis`.
480 /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
481 ///
482 /// For `n>=2`, the process is iterated:
483 /// ```
484 /// use ndarray::{array, Axis};
485 /// let arr = array![1.0, 2.0, 5.0];
486 /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
487 /// ```
488 /// **Panics** if `axis` is out of bounds
489 ///
490 /// **Panics** if `n` is too big / the array is to short:
491 /// ```should_panic
492 /// use ndarray::{array, Axis};
493 /// array![1.0, 2.0, 3.0].diff(10, Axis(0));
494 /// ```
495 pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
496 where A: Sub<A, Output = A> + Zero + Clone
497 {
498 if n == 0 {
499 return self.to_owned();
500 }
501 assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
502 assert!(
503 n < self.shape()[axis.0],
504 "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
505 n + 1,
506 self.shape()[axis.0]
507 );
508
509 let mut inp = self.to_owned();
510 let mut out = Array::zeros({
511 let mut inp_dim = self.raw_dim();
512 // inp_dim[axis.0] >= 1 as per the 2nd assertion.
513 inp_dim[axis.0] -= 1;
514 inp_dim
515 });
516 for _ in 0..n {
517 let head = inp.slice_axis(axis, Slice::from(..-1));
518 let tail = inp.slice_axis(axis, Slice::from(1..));
519
520 azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());
521
522 // feed the output as the input to the next iteration
523 std::mem::swap(&mut inp, &mut out);
524
525 // adjust the new output array width along `axis`.
526 // Current situation: width of `inp`: k, `out`: k+1
527 // needed width: `inp`: k, `out`: k-1
528 // slice is possible, since k >= 1.
529 out.slice_axis_inplace(axis, Slice::from(..-2));
530 }
531 inp
532 }
533}