1use crate::dtype::{NumericDataType, RawDataType};
2use crate::flat_index_generator::FlatIndexGenerator;
3use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
4use crate::ops::reduce_max::ReduceMax;
5use crate::ops::reduce_min::ReduceMin;
6use crate::ops::reduce_min_magnitude::ReduceMinMagnitude;
7use crate::ops::reduce_product::ReduceProduct;
8use crate::ops::reduce_sum::ReduceSum;
9use crate::partial_ord::*;
10use crate::util::to_vec::ToVec;
11use crate::{AxisType, Constructors, FloatDataType, NdArray, StridedMemory};
12use num::NumCast;
13use std::collections::VecDeque;
14use crate::ops::reduce_max_magnitude::ReduceMaxMagnitude;
15
16fn reduced_shape_and_stride(axes: &[isize], shape: &[usize]) -> (Vec<usize>, Vec<usize>) {
26 let ndims = shape.len();
27 let mut axis_mask = vec![false; ndims];
28
29 for &axis in axes.iter() {
30 let axis = axis.as_absolute(ndims);
31 if axis_mask[axis] {
32 panic!("duplicate axes specified");
33 }
34 axis_mask[axis] = true;
35 }
36
37 let mut new_stride = VecDeque::with_capacity(ndims);
38 let mut new_shape = VecDeque::with_capacity(ndims - axes.len());
39
40 let mut stride = 1;
41 for axis in (0..ndims).rev() {
42 if axis_mask[axis] {
43 new_stride.push_front(0);
44 } else {
45 new_stride.push_front(stride);
46 new_shape.push_front(shape[axis]);
47 stride *= shape[axis];
48 }
49 }
50
51 (Vec::from(new_shape), Vec::from(new_stride))
52}
53
54impl<T: RawDataType> NdArray<'_, T> {
55 unsafe fn reduce_uniform_stride(&self, func: impl Fn(T, T) -> T, default: T, stride: usize) -> NdArray<'static, T> {
67 let mut output = default;
68
69 let mut src = self.ptr();
70 for _ in 0..self.size() {
71 output = func(*src, output);
72 src = src.add(stride);
73 }
74
75 NdArray::scalar(output)
76 }
77
78 fn reduce_along(&self, func: impl Fn(T, T) -> T, axes: impl ToVec<isize>, default: T) -> NdArray<'static, T> {
79 let (out_shape, map_stride) = reduced_shape_and_stride(&axes.to_vec(), &self.shape);
80 let (map_shape, map_stride) = collapse_to_uniform_stride(&self.shape, &map_stride);
81
82 let mut output = vec![default; out_shape.iter().product()];
83
84 let mut dst_indices = FlatIndexGenerator::from(&map_shape, &map_stride);
85 let dst: *mut T = output.as_mut_ptr();
86
87 for el in self.flatiter() {
88 unsafe {
89 let dst_i = dst_indices.next().unwrap();
90 let dst_ptr = dst.add(dst_i);
91 *dst_ptr = func(el, *dst_ptr);
92 }
93 }
94
95 unsafe { NdArray::from_contiguous_owned_buffer(out_shape, output) }
96 }
97
98 fn reduce(&self, func: impl Fn(T, T) -> T, default: T) -> NdArray<'static, T> {
99 if let Some(stride) = self.has_uniform_stride() {
100 return unsafe { self.reduce_uniform_stride(func, default, stride) };
101 }
102
103 let mut output = default;
104
105 for el in self.flatiter() {
106 output = func(el, output);
107 }
108
109 NdArray::scalar(output)
110 }
111}
112
113impl<T: NumericDataType> NdArray<'_, T> {
114 pub fn sum(&self) -> NdArray<'static, T> {
125 let output = unsafe { <T as ReduceSum>::sum(self.ptr(), self.shape(), self.stride()) };
126 NdArray::scalar(output)
127 }
128
129 pub fn sum_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
130 self.reduce_along(|val, acc| acc + val, axes, T::zero())
131 }
132
133 pub fn product(&self) -> NdArray<'static, T> {
144 let output = unsafe { <T as ReduceProduct>::product(self.ptr(), self.shape(), self.stride()) };
145 NdArray::scalar(output)
146 }
147
148 pub fn product_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
149 self.reduce_along(|val, acc| acc * val, axes, T::one())
150 }
151
152 pub fn min(&self) -> NdArray<'static, T> {
163 let output = unsafe { <T as ReduceMin>::min(self.ptr(), self.shape(), self.stride()) };
164 NdArray::scalar(output)
165 }
166
167 pub fn min_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
168 self.reduce_along(partial_min, axes, T::max_value())
169 }
170
171 pub fn max(&self) -> NdArray<'static, T> {
182 let output = unsafe { <T as ReduceMax>::max(self.ptr(), self.shape(), self.stride()) };
183 NdArray::scalar(output)
184 }
185
186 pub fn max_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
187 self.reduce_along(partial_max, axes, T::min_value())
188 }
189
190 pub fn min_magnitude(&self) -> NdArray<'static, T> {
201 let output = unsafe { <T as ReduceMinMagnitude>::min_magnitude(self.ptr(), self.shape(), self.stride()) };
202 NdArray::scalar(output)
203 }
204
205 pub fn min_magnitude_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
206 self.reduce_along(partial_min_magnitude, axes, T::max_value())
207 }
208
209 pub fn max_magnitude(&self) -> NdArray<'static, T> {
220 let output = unsafe { <T as ReduceMaxMagnitude>::max_magnitude(self.ptr(), self.shape(), self.stride()) };
221 NdArray::scalar(output)
222 }
223
224 pub fn max_magnitude_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
225 self.reduce_along(partial_max_magnitude, axes, T::zero())
226 }
227
228 pub fn mean(&self) -> NdArray<'static, T>
239 where
240 T: FloatDataType
241 {
242 let n: T = NumCast::from(self.size()).unwrap();
243 self.sum() / n
244 }
245
246 pub fn mean_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T>
247 where
248 T: FloatDataType
249 {
250 let axes = axes.to_vec();
251
252 let mut n = 1;
253 for &axis in axes.iter() {
254 assert!(axis >= 0, "negative axes are not currently supported");
255 n *= self.shape()[axis as usize];
256 }
257
258 let n: T = NumCast::from(n).unwrap();
259 self.sum_along(axes) / n
260 }
261}
262
263
#[cfg(test)]
mod tests {
    use super::reduced_shape_and_stride;

    /// Reduces `shape` over `axes` and asserts that the result matches the
    /// expected output shape and per-input-axis stride map.
    fn check(shape: &[usize], axes: &[isize], want_shape: &[usize], want_stride: &[usize]) {
        let (new_shape, new_stride) = reduced_shape_and_stride(axes, shape);
        assert_eq!(new_shape, want_shape);
        assert_eq!(new_stride, want_stride);
    }

    #[test]
    fn test_reduce_shape_and_stride() {
        // 2-D input, trailing axis reduced.
        check(&[3, 2], &[1], &[3], &[1, 0]);

        // 3-D input, a single axis reduced.
        check(&[4, 2, 3], &[0], &[2, 3], &[0, 3, 1]);
        check(&[4, 2, 3], &[1], &[4, 3], &[3, 0, 1]);
        check(&[4, 2, 3], &[2], &[4, 2], &[2, 1, 0]);

        // 3-D input, two axes reduced together.
        check(&[4, 2, 3], &[0, 1], &[3], &[0, 0, 1]);
        check(&[4, 2, 3], &[0, 2], &[2], &[0, 1, 0]);
        check(&[4, 2, 3], &[1, 2], &[4], &[1, 0, 0]);
    }
}