// redstone_ml/ndarray/reduce.rs

1use crate::dtype::{NumericDataType, RawDataType};
2use crate::flat_index_generator::FlatIndexGenerator;
3use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
4use crate::ops::reduce_max::ReduceMax;
5use crate::ops::reduce_min::ReduceMin;
6use crate::ops::reduce_min_magnitude::ReduceMinMagnitude;
7use crate::ops::reduce_product::ReduceProduct;
8use crate::ops::reduce_sum::ReduceSum;
9use crate::partial_ord::*;
10use crate::util::to_vec::ToVec;
11use crate::{AxisType, Constructors, FloatDataType, NdArray, StridedMemory};
12use num::NumCast;
13use std::collections::VecDeque;
14use crate::ops::reduce_max_magnitude::ReduceMaxMagnitude;
15
16/// Returns a tuple `(output_shape, map_stride)`
17///
18/// - `output_shape` is the shape of the output ndarray after the reduction operation
19///
20/// - `map_stride` maps a flat iteration over the input ndarray to iteration over the output ndarray.
21///   For example, if the reduce operation is addition, `reduce` iterates through the input ndarray
22///   element-by-element and `map_stride` describes iteration over the output ndarray
23///   to add each element to the correct location.
24///   It should now make sense why map_stride contains zeros on every reduced axis
25fn reduced_shape_and_stride(axes: &[isize], shape: &[usize]) -> (Vec<usize>, Vec<usize>) {
26    let ndims = shape.len();
27    let mut axis_mask = vec![false; ndims];
28
29    for &axis in axes.iter() {
30        let axis = axis.as_absolute(ndims);
31        if axis_mask[axis] {
32            panic!("duplicate axes specified");
33        }
34        axis_mask[axis] = true;
35    }
36
37    let mut new_stride = VecDeque::with_capacity(ndims);
38    let mut new_shape = VecDeque::with_capacity(ndims - axes.len());
39
40    let mut stride = 1;
41    for axis in (0..ndims).rev() {
42        if axis_mask[axis] {
43            new_stride.push_front(0);
44        } else {
45            new_stride.push_front(stride);
46            new_shape.push_front(shape[axis]);
47            stride *= shape[axis];
48        }
49    }
50
51    (Vec::from(new_shape), Vec::from(new_stride))
52}
53
54impl<T: RawDataType> NdArray<'_, T> {
55    /// Reduces the elements of a contiguous ndarray into a scalar using the specified function.
56    ///
57    /// # Safety
58    /// - Ensure that the underlying ndarray has uniform stride in memory
59    ///
60    /// # Parameters
61    /// - `func`: A closure or function that takes two arguments (the next value to be reduced
62    ///   and the value of the accumulator) and returns a reduction of both.
63    ///   For example, when the reduction operation is addition, `|src, acc| src + acc`
64    /// - `default`: The initial value used as the accumulator for the reduction.
65    /// - `stride`: The number of `T` elements in memory between consecutive elements of `self`
66    unsafe fn reduce_uniform_stride(&self, func: impl Fn(T, T) -> T, default: T, stride: usize) -> NdArray<'static, T> {
67        let mut output = default;
68
69        let mut src = self.ptr();
70        for _ in 0..self.size() {
71            output = func(*src, output);
72            src = src.add(stride);
73        }
74
75        NdArray::scalar(output)
76    }
77
78    fn reduce_along(&self, func: impl Fn(T, T) -> T, axes: impl ToVec<isize>, default: T) -> NdArray<'static, T> {
79        let (out_shape, map_stride) = reduced_shape_and_stride(&axes.to_vec(), &self.shape);
80        let (map_shape, map_stride) = collapse_to_uniform_stride(&self.shape, &map_stride);
81
82        let mut output = vec![default; out_shape.iter().product()];
83
84        let mut dst_indices = FlatIndexGenerator::from(&map_shape, &map_stride);
85        let dst: *mut T = output.as_mut_ptr();
86
87        for el in self.flatiter() {
88            unsafe {
89                let dst_i = dst_indices.next().unwrap();
90                let dst_ptr = dst.add(dst_i);
91                *dst_ptr = func(el, *dst_ptr);
92            }
93        }
94
95        unsafe { NdArray::from_contiguous_owned_buffer(out_shape, output) }
96    }
97
98    fn reduce(&self, func: impl Fn(T, T) -> T, default: T) -> NdArray<'static, T> {
99        if let Some(stride) = self.has_uniform_stride() {
100            return unsafe { self.reduce_uniform_stride(func, default, stride) };
101        }
102
103        let mut output = default;
104
105        for el in self.flatiter() {
106            output = func(el, output);
107        }
108
109        NdArray::scalar(output)
110    }
111}
112
113impl<T: NumericDataType> NdArray<'_, T> {
114    /// Computes the sum of all elements in the array.
115    ///
116    /// # Example
117    /// ```
118    /// use redstone_ml::*;
119    ///
120    /// let array = NdArray::new(vec![1, 2, 3, 4]);
121    /// let sum = array.sum();
122    /// assert_eq!(sum.value(), 1 + 2 + 3 + 4);
123    /// ```
124    pub fn sum(&self) -> NdArray<'static, T> {
125        let output = unsafe { <T as ReduceSum>::sum(self.ptr(), self.shape(), self.stride()) };
126        NdArray::scalar(output)
127    }
128
129    pub fn sum_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
130        self.reduce_along(|val, acc| acc + val, axes, T::zero())
131    }
132
133    /// Computes the product of all elements in the array.
134    ///
135    /// # Example
136    /// ```
137    /// use redstone_ml::*;
138    ///
139    /// let array = NdArray::new(vec![1, 2, 3, 4]);
140    /// let prod = array.product();
141    /// assert_eq!(prod.value(), 1 * 2 * 3 * 4);
142    /// ```
143    pub fn product(&self) -> NdArray<'static, T> {
144        let output = unsafe { <T as ReduceProduct>::product(self.ptr(), self.shape(), self.stride()) };
145        NdArray::scalar(output)
146    }
147
148    pub fn product_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
149        self.reduce_along(|val, acc| acc * val, axes, T::one())
150    }
151
152    /// Computes the minimum of all elements in the array.
153    ///
154    /// # Example
155    /// ```
156    /// use redstone_ml::*;
157    ///
158    /// let array = NdArray::new(vec![-1, 3, -7, 8]);
159    /// let min = array.min();
160    /// assert_eq!(min.value(), -7);
161    /// ```
162    pub fn min(&self) -> NdArray<'static, T> {
163        let output = unsafe { <T as ReduceMin>::min(self.ptr(), self.shape(), self.stride()) };
164        NdArray::scalar(output)
165    }
166
167    pub fn min_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
168        self.reduce_along(partial_min, axes, T::max_value())
169    }
170
171    /// Computes the maximum of all elements in the array.
172    ///
173    /// # Example
174    /// ```
175    /// use redstone_ml::*;
176    ///
177    /// let array = NdArray::new(vec![-1, 3, -7, 8]);
178    /// let max = array.max();
179    /// assert_eq!(max.value(), 8);
180    /// ```
181    pub fn max(&self) -> NdArray<'static, T> {
182        let output = unsafe { <T as ReduceMax>::max(self.ptr(), self.shape(), self.stride()) };
183        NdArray::scalar(output)
184    }
185
186    pub fn max_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
187        self.reduce_along(partial_max, axes, T::min_value())
188    }
189
190    /// Computes the minimum absolute value of all elements in the array.
191    ///
192    /// # Example
193    /// ```
194    /// use redstone_ml::*;
195    ///
196    /// let array = NdArray::new(vec![-1, 3, -7, 8]);
197    /// let min = array.min_magnitude();
198    /// assert_eq!(min.value(), 1);
199    /// ```
200    pub fn min_magnitude(&self) -> NdArray<'static, T> {
201        let output = unsafe { <T as ReduceMinMagnitude>::min_magnitude(self.ptr(), self.shape(), self.stride()) };
202        NdArray::scalar(output)
203    }
204
205    pub fn min_magnitude_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
206        self.reduce_along(partial_min_magnitude, axes, T::max_value())
207    }
208
209    /// Computes the maximum absolute value of all elements in the array.
210    ///
211    /// # Example
212    /// ```
213    /// use redstone_ml::*;
214    ///
215    /// let array = NdArray::new(vec![-1, 3, -9, 8]);
216    /// let max = array.max_magnitude();
217    /// assert_eq!(max.value(), 9);
218    /// ```
219    pub fn max_magnitude(&self) -> NdArray<'static, T> {
220        let output = unsafe { <T as ReduceMaxMagnitude>::max_magnitude(self.ptr(), self.shape(), self.stride()) };
221        NdArray::scalar(output)
222    }
223
224    pub fn max_magnitude_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T> {
225        self.reduce_along(partial_max_magnitude, axes, T::zero())
226    }
227
228    /// Computes the mean of all elements in the array.
229    ///
230    /// # Example
231    /// ```
232    /// use redstone_ml::*;
233    ///
234    /// let array = NdArray::new(vec![1.0, 3.0, 5.0, 7.0]);
235    /// let mean = array.mean();
236    /// assert_eq!(mean.value(), 4.0);
237    /// ```
238    pub fn mean(&self) -> NdArray<'static, T>
239    where
240        T: FloatDataType
241    {
242        let n: T = NumCast::from(self.size()).unwrap();
243        self.sum() / n
244    }
245
246    pub fn mean_along(&self, axes: impl ToVec<isize>) -> NdArray<'static, T>
247    where
248        T: FloatDataType
249    {
250        let axes = axes.to_vec();
251
252        let mut n = 1;
253        for &axis in axes.iter() {
254            assert!(axis >= 0, "negative axes are not currently supported");
255            n *= self.shape()[axis as usize];
256        }
257
258        let n: T = NumCast::from(n).unwrap();
259        self.sum_along(axes) / n
260    }
261}
262
263
#[cfg(test)]
mod tests {
    use super::reduced_shape_and_stride;

    /// Reduces `shape` over `axes` and checks both halves of the returned tuple.
    fn check(axes: &[isize], shape: &[usize], expected_shape: &[usize], expected_stride: &[usize]) {
        let (new_shape, new_stride) = reduced_shape_and_stride(axes, shape);
        assert_eq!(new_shape, expected_shape, "output shape for axes {:?} of {:?}", axes, shape);
        assert_eq!(new_stride, expected_stride, "map stride for axes {:?} of {:?}", axes, shape);
    }

    #[test]
    fn test_reduce_shape_and_stride() {
        check(&[1], &[3, 2], &[3], &[1, 0]);

        let shape: [usize; 3] = [4, 2, 3];
        // (axes, expected output shape, expected map stride)
        let cases: [(&[isize], &[usize], &[usize]); 6] = [
            (&[0], &[2, 3], &[0, 3, 1]),
            (&[1], &[4, 3], &[3, 0, 1]),
            (&[2], &[4, 2], &[2, 1, 0]),
            (&[0, 1], &[3], &[0, 0, 1]),
            (&[0, 2], &[2], &[0, 1, 0]),
            (&[1, 2], &[4], &[1, 0, 0]),
        ];
        for &(axes, expected_shape, expected_stride) in cases.iter() {
            check(axes, &shape, expected_shape, expected_stride);
        }
    }

    #[test]
    fn test_reduce_shape_and_stride_all_axes() {
        // Reducing every axis leaves a 0-dimensional output that every element maps onto.
        check(&[0, 1, 2], &[4, 2, 3], &[], &[0, 0, 0]);
    }

    #[test]
    fn test_reduce_shape_and_stride_no_axes() {
        // Reducing nothing keeps the shape and yields plain row-major strides.
        check(&[], &[4, 2, 3], &[4, 2, 3], &[6, 3, 1]);
    }

    #[test]
    #[should_panic(expected = "duplicate axes specified")]
    fn test_reduce_shape_and_stride_duplicate_axes() {
        let _ = reduced_shape_and_stride(&[1, 1], &[4, 2, 3]);
    }
}