//! `lumen_core/tensor/reduce.rs` — axis-wise and whole-tensor reduction
//! operations (sum/min/max/mean, variance, argmin/argmax, all/any).

use crate::{AutogradMetaT, Dim, NumDType, Result, Storage, StorageRef, Tensor, WithDType};
use paste::paste;

use super::ResettableIterator;

/// Generates a family of reduction methods from a single reducer:
///   - `$fn_name(axis)`: reduce along `axis`, then squeeze that axis away;
///   - `[< $fn_name _keepdim >](axis)`: reduce along `axis`, keeping it as size 1;
///   - `[< $fn_name _all >]()`: flatten the tensor and reduce to a scalar.
///
/// `$reduce` names the `ReduceOp` implementor supplying the kernel; `$op` is
/// the `crate::ReduceOp` variant recorded in the autograd metadata.
/// NOTE(review): `compute_reduec_axis_op` carries a "reduec" typo — it is
/// spelled that way at its definition below, so it is kept here too.
macro_rules! reduce_impl {
    ($fn_name:ident, $reduce:ident, $op:ident) => {
        paste! {
            // Reduce + squeeze: the result drops the reduced axis.
            pub fn $fn_name<D: Dim>(&self, axis: D) -> Result<Self> {
                let (storage, dims) = self.compute_reduec_axis_op(axis, $reduce::op, stringify!($fn_name))?;
                let meta = T::AutogradMeta::on_reduce_op(self, &dims, crate::ReduceOp::$op);
                let res = Self::from_storage(storage, dims, meta);
                res.squeeze(axis)
            }

            // Reduce, keeping the reduced axis as size 1.
            pub fn [< $fn_name _keepdim >]<D: Dim>(&self, axis: D) -> Result<Self> {
                let (storage, dims) = self.compute_reduec_axis_op(axis, $reduce::op, stringify!([< $fn_name _keepdim >]))?;
                let meta = T::AutogradMeta::on_reduce_op(self, &dims, crate::ReduceOp::$op);
                Ok(Self::from_storage(storage, dims, meta))
            }

            // Flatten to 1-D, then reduce axis 0 -> scalar tensor.
            pub fn [< $fn_name _all >](&self) -> Result<Self> {
                self.flatten_all()?.$fn_name(0)
            }
        }
    };
}

29impl<T: NumDType> Tensor<T> {
30    reduce_impl!(sum, ReduceSum, Sum);
31    reduce_impl!(min, ReduceMin, Min);
32    reduce_impl!(max, ReduceMax, Max);
33    reduce_impl!(mean, ReduceMean, Mean);
34
35    pub fn var_keepdim<D: Dim>(&self, axis: D) -> Result<Self> {
36        let mean = self.mean_keepdim(axis)?; // (..., 1, ...)
37        let delta = self.broadcast_sub(&mean)?; // (..., n, ...)
38        let delta_pow = &delta * &delta; // (..., n, ...)
39
40        delta_pow.mean_keepdim(axis)
41    }
42
43    pub fn var_unbiased_keepdim<D: Dim>(&self, axis: D) -> Result<Self> {
44        let n = T::from_usize(self.dim(axis)?);
45        let biased_var = self.var_keepdim(axis)?;
46        
47        // cor scale: N / (N - 1)
48        let correction = n / (n - T::one());
49        Ok(correction * biased_var)
50    }
51    
52    pub fn var<D: Dim>(&self, axis: D) -> Result<Self> {
53        let v = self.var_keepdim(axis)?;
54        let v = v.squeeze(axis)?;
55        Ok(v)
56    }
57
58    pub fn var_unbiased<D: Dim>(&self, axis: D) -> Result<Self> {
59        let v = self.var_unbiased_keepdim(axis)?;
60        let v = v.squeeze(axis)?;
61        Ok(v)
62    }
63
64    pub fn var_all(&self) -> Result<Self> {
65        self.flatten_all()?.var(0)
66    }
67
68    pub fn var_unbiased_all(&self) -> Result<Self> {
69        self.flatten_all()?.var_unbiased(0)
70    }
71
72    pub fn argmin_keepdim<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
73        let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMin::op, "argmin")?;
74        Ok(Tensor::from_storage(storage, dims, Default::default()))
75    }
76
77    pub fn argmin<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
78        let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMin::op, "argmin_keepdim")?;
79        let res = Tensor::from_storage(storage, dims, Default::default());
80        res.squeeze(axis)
81    }
82
83    pub fn argmax_keepdim<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
84        let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMax::op, "argmax")?;
85        Ok(Tensor::from_storage(storage, dims, Default::default()))
86    }
87
88    pub fn argmax<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
89        let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMax::op, "argmax_keepdim")?;
90        let res = Tensor::from_storage(storage, dims, Default::default());
91        res.squeeze(axis)
92    }
93}
94
95impl Tensor<bool> {
96    pub fn all(&self) -> crate::Result<bool> {
97        self.iter().map(|mut i| i.all(|a| a))
98    }
99
100    pub fn any(&self) -> crate::Result<bool> {
101        self.iter().map(|mut i| i.any(|a| a))
102    } 
103
104    pub fn all_axis<D: Dim>(&self, axis: D) -> Result<Tensor<bool>> {
105        self.reduec_axis_op(axis, ReduceAll::op, Default::default(), "all")
106    }
107
108    pub fn any_axis<D: Dim>(&self, axis: D) -> Result<Tensor<bool>> {
109        self.reduec_axis_op(axis, ReduceAny::op, Default::default(), "any")
110    }
111}
112
impl<T: WithDType> Tensor<T> {
    /// Reduce along `reduce_dim` with `f` and wrap the result in a new
    /// tensor carrying `meta`.
    fn reduec_axis_op<'a, F, R: WithDType, D: Dim>(&'a self, reduce_dim: D, f: F, meta: R::AutogradMeta, op_name: &'static str) -> Result<Tensor<R>> 
    where 
        F: Fn(&mut DimArrayIter<'a, T>) -> R
    {
        let (storage, shape) = self.compute_reduec_axis_op(reduce_dim, f, op_name)?;
        Ok(Tensor::<R>::from_storage(storage, shape, meta))
    }

    /// Core reduction kernel: applies `f` to every 1-D lane along
    /// `reduce_dim` and returns the reduced storage plus the output shape
    /// (reduced dim kept as size 1; non-keepdim callers squeeze afterwards).
    /// `op_name` only feeds axis-resolution error messages.
    /// (Name note: "reduec" is a typo kept because every caller spells it
    /// this way; rename all together if cleaning it up.)
    fn compute_reduec_axis_op<'a, F, R: WithDType, D: Dim>(&'a self, reduce_dim: D, f: F, op_name: &'static str) -> Result<(Storage<R>, Vec<usize>)> 
    where 
        F: Fn(&mut DimArrayIter<'a, T>) -> R
    {
        let reduce_dim = reduce_dim.to_index(self.shape(), op_name)?;
        assert!(reduce_dim < self.layout().dims().len());
        let reduce_dim_stride = self.layout().stride()[reduce_dim];
        let reduce_dim_size = self.layout().dims()[reduce_dim];

        // One output element per lane: total elements / lane length.
        let dst_len = self.layout().element_count() / reduce_dim_size;
        let mut dst: Vec<R> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();

        // Narrow the reduced dim to size 1 so `storage_indices` visits the
        // starting storage offset of each lane exactly once.
        let layout = self.layout().narrow(reduce_dim, 0, 1)?;
        for (dst_index, src_index) in layout.storage_indices().enumerate() {
            let arr: DimArray<'_, T> = DimArray {
                src: self.storage_ref(src_index)?,
                size: reduce_dim_size,
                stride: reduce_dim_stride
            };
            let mut iter: DimArrayIter<'_, T> = arr.into_iter();
            dst_to_set[dst_index].write(f(&mut iter));
        }
        // SAFETY: the narrowed layout yields exactly `dst_len` lane offsets
        // (element_count / reduce_dim_size), and the loop above writes each
        // of the first `dst_len` spare slots exactly once, so all `dst_len`
        // elements are initialized before the length is set.
        unsafe { dst.set_len(dst_len) };

        let storage = Storage::new(dst);
        let mut shape = self.dims().to_vec();
        // Keep the reduced dim as size 1 (keepdim semantics); wrappers that
        // want it gone call `squeeze` on the resulting tensor.
        shape[reduce_dim] = 1;

        Ok((storage, shape))
    }
}

/// A reduction over one strided 1-D lane of tensor storage.
pub trait ReduceOp<D: WithDType> {
    /// Element type of the reduced result (e.g. `u32` for argmin/argmax).
    type Output: WithDType;
    /// Consume the lane iterator and produce the reduced value.
    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output;
}

161pub struct ReduceAll;
162impl ReduceOp<bool> for ReduceAll {
163    type Output = bool;
164    fn op(arr: &mut DimArrayIter<'_, bool>) -> Self::Output {
165        arr.into_iter().all(|b| b)
166    }
167}
168
169pub struct ReduceAny;
170impl ReduceOp<bool> for ReduceAny {
171    type Output = bool;
172    fn op(arr: &mut DimArrayIter<'_, bool>) -> Self::Output {
173        arr.into_iter().any(|b| b)
174    }
175}
176
177pub struct ReduceSum;
178impl<D: NumDType> ReduceOp<D> for ReduceSum {
179    type Output = D;
180    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
181        arr.into_iter().sum::<D>()
182    }
183} 
184
185pub struct ReduceMean;
186impl<D: NumDType> ReduceOp<D> for ReduceMean {
187    type Output = D;
188    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
189        let len = arr.len();
190        arr.into_iter().sum::<D>() / D::from_usize(len)
191    }
192} 
193
194pub struct ReduceVar;
195impl<D: NumDType> ReduceOp<D> for ReduceVar {
196    type Output = D;
197    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
198        let len = arr.len();
199        if len == 0 { return D::zero(); }
200
201        let mean = ReduceMean::op(arr);
202        
203        arr.reset();
204        let mut sum_sq_diff = D::zero();
205        while let Some(v) = arr.next() {
206            let diff = v - mean;
207            sum_sq_diff += diff * diff;
208        }
209
210        sum_sq_diff / D::from_usize(len)
211    }
212} 
213
214pub struct ReduceProduct;
215impl<D: NumDType> ReduceOp<D> for ReduceProduct {
216    type Output = D;
217    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
218        arr.into_iter().product::<D>()
219    }
220} 
221
222pub struct ReduceMin;
223impl<D: NumDType> ReduceOp<D> for ReduceMin {
224    type Output = D;
225    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
226        arr.into_iter()
227            .reduce(|a, b| D::minimum(a, b)).unwrap()
228    }
229} 
230
231pub struct ReduceArgMin;
232impl<D: NumDType> ReduceOp<D> for ReduceArgMin {
233    type Output = u32;
234    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
235        arr.into_iter()
236            .enumerate()
237            .reduce(|(ia, a), (ib, b)| {
238                if a.partial_cmp(&b) == Some(std::cmp::Ordering::Less) {
239                    (ia, a)
240                } else {
241                    (ib, b)
242                }
243            }).unwrap().0 as u32
244    }
245} 
246
247pub struct ReduceMax;
248impl<D: NumDType> ReduceOp<D> for ReduceMax {
249    type Output = D;
250    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
251        arr.into_iter()
252            .reduce(|a, b| D::maximum(a, b)).unwrap()
253    }
254} 
255
256pub struct ReduceArgMax;
257impl<D: NumDType> ReduceOp<D> for ReduceArgMax {
258    type Output = u32;
259    fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
260        arr.into_iter()
261            .enumerate()
262            .reduce(|(ia, a), (ib, b)| {
263                if a.partial_cmp(&b) == Some(std::cmp::Ordering::Greater) {
264                    (ia, a)
265                } else {
266                    (ib, b)
267                }
268            }).unwrap().0 as u32
269    }
270} 
271
/// A strided 1-D view over one lane of a tensor's storage.
pub struct DimArray<'a, T> {
    src: StorageRef<'a, T>,  // storage reference positioned at the lane's first element
    size: usize,             // number of elements in the lane (the reduced dim's size)
    stride: usize            // storage-index step between consecutive lane elements
}

278impl<'a, T: WithDType> DimArray<'a, T> {
279    pub fn get(&self, index: usize) -> T {
280        self.src.get_unchecked(index * self.stride)
281    }
282
283    #[allow(unused)]
284    pub fn to_vec(&self) -> Vec<T> {
285        let mut v = vec![];
286        for i in 0..self.size {
287            v.push(self.get(i));
288        }
289        v
290    }
291}
292
293impl<'a, T: WithDType> IntoIterator for DimArray<'a, T> {
294    type IntoIter = DimArrayIter<'a, T>;
295    type Item = T;
296    fn into_iter(self) -> Self::IntoIter {
297        DimArrayIter::<'a, T> {
298            array: self,
299            index: 0,
300        }
301    }
302}
303
/// Cursor over a `DimArray`, yielding its `size` elements in order.
/// Supports rewinding via `ResettableIterator` for multi-pass reductions.
pub struct DimArrayIter<'a, T> {
    array: DimArray<'a, T>,
    index: usize,  // position of the next element to yield (0..=array.size)
}

309impl<'a, T: WithDType> Iterator for DimArrayIter<'a, T> {
310    type Item = T;
311    fn next(&mut self) -> Option<T> {
312        if self.index >= self.array.size {
313            None
314        } else {
315            let index = self.index;
316            self.index += 1;
317            Some(self.array.get(index))
318        }
319    }
320}
321
322impl<'a, T: WithDType> ExactSizeIterator for DimArrayIter<'a, T> {
323    fn len(&self) -> usize {
324        self.array.size
325    }
326}
327
impl<'a, T: WithDType> ResettableIterator for DimArrayIter<'a, T> {
    /// Rewind to the first element so the lane can be traversed again
    /// (used by multi-pass reductions such as `ReduceVar`).
    fn reset(&mut self) {
        self.index = 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_matrix_axis0() {
        // [[1, 2, 3],
        //  [3, 4, 5]]
        // sum(axis=0) -> [4, 6, 8]
        let arr = Tensor::new(&[[1, 2, 3], [3, 4, 5]]).unwrap();
        let s = arr.sum(0).unwrap();
        let expected = Tensor::new(&[4, 6, 8]).unwrap();
        assert!(s.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_sum_matrix_axis1() {
        // [[1, 2, 3],
        //  [3, 4, 5]]
        // sum(axis=1) -> [6, 12]
        let arr = Tensor::new(&[[1, 2, 3], [3, 4, 5]]).unwrap();
        let s = arr.sum(1).unwrap();
        let expected = Tensor::new(&[6, 12]).unwrap();
        assert!(s.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_sum_ones_axis() {
        // ones((2, 3)):
        // [[1, 1, 1],
        //  [1, 1, 1]]
        let arr = Tensor::ones((2, 3)).unwrap();
        let s0 = arr.sum(0).unwrap(); // -> [2, 2, 2]
        let s1 = arr.sum(1).unwrap(); // -> [3, 3]

        let expected0 = Tensor::new(&[2, 2, 2]).unwrap();
        let expected1 = Tensor::new(&[3, 3]).unwrap();

        assert!(s0.allclose(&expected0, 1e-5, 1e-8).unwrap());
        assert!(s1.allclose(&expected1, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_min_matrix_axis0() {
        // [[1, 2, 3],
        //  [3, 1, 0]]
        // min(axis=0) -> [1, 1, 0]
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.min(0).unwrap();
        let expected = Tensor::new(&[1, 1, 0]).unwrap();
        assert!(m.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_max_matrix_axis1() {
        // [[1, 2, 3],
        //  [3, 1, 0]]
        // max(axis=1) -> [3, 3]
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.max(1).unwrap();
        let expected = Tensor::new(&[3, 3]).unwrap();
        assert!(m.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    // Renamed from `test_aragmin_matrix_axis0` (typo).
    #[test]
    fn test_argmin_matrix_axis0() {
        // [[1, 2, 3],
        //  [3, 1, 0]]
        // argmin(axis=0) -> [0, 1, 1]
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.argmin(0).unwrap();
        let expected = Tensor::new(&[0, 1, 1]).unwrap();
        assert!(m.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_sum_all() {
        // [[1, 2], [3, 4]] -> 1+2+3+4 = 10
        let arr = Tensor::new(&[[1, 2], [3, 4]]).unwrap();
        let s = arr.sum_all().unwrap();
        let expected = Tensor::new(10).unwrap();
        assert!(s.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_mean_all() {
        // [[1.0, 2.0], [3.0, 4.0]] -> sum = 10.0, count = 4 -> mean = 2.5
        let arr = Tensor::new(&[[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let m = arr.mean_all().unwrap();
        let expected = Tensor::new(2.5).unwrap();
        assert!(m.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_min_max_all() {
        // [[10, 2, 5], [8, 1, 9]]
        // Global Min: 1
        // Global Max: 10
        let arr = Tensor::new(&[[10, 2, 5], [8, 1, 9]]).unwrap();

        let min_val = arr.min_all().unwrap();
        let max_val = arr.max_all().unwrap();

        let expected_min = Tensor::new(1).unwrap();
        let expected_max = Tensor::new(10).unwrap();

        assert!(min_val.allclose(&expected_min, 1e-5, 1e-8).unwrap());
        assert!(max_val.allclose(&expected_max, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_argmax_matrix_axis1() {
        // [[1, 2, 3],
        //  [3, 1, 0]]
        // argmax(axis=1) -> [2, 0]
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.argmax(1).unwrap();
        let expected = Tensor::new(&[2, 0]).unwrap();
        assert!(m.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_reductions_with_negatives() {
        // [[-2.0, 0.0, 2.0]]
        // sum_all = 0.0
        // mean_all = 0.0
        // var_all (biased) = (4 + 0 + 4) / 3 = 2.6667
        let arr = Tensor::new(&[[-2.0, 0.0, 2.0]]).unwrap();

        assert!(arr.sum_all().unwrap().allclose(&Tensor::new(0.0).unwrap(), 1e-5, 1e-8).unwrap());
        assert!(arr.mean_all().unwrap().allclose(&Tensor::new(0.0).unwrap(), 1e-5, 1e-8).unwrap());

        let expected_var = Tensor::new(2.66666666666666666).unwrap();
        assert!(arr.var_all().unwrap().allclose(&expected_var, 1e-5, 1e-8).unwrap());
    }
}