1use crate::{AutogradMetaT, Dim, NumDType, Result, Storage, StorageRef, Tensor, WithDType};
2use paste::paste;
3
4use super::ResettableIterator;
5
/// Generates the public flavours of one axis reduction for `Tensor<T>`:
///
/// * `<name>(axis)` – reduces `axis` and squeezes it out of the result,
/// * `<name>_keepdim(axis)` – reduces `axis` but keeps it with length 1,
/// * `<name>_all()` – flattens the tensor and reduces the single remaining axis.
///
/// Arguments: the base method name, the type whose `op` function reduces one
/// lane, and the `crate::ReduceOp` variant reported to the autograd hook.
macro_rules! reduce_impl {
    ($fn_name:ident, $reduce:ident, $op:ident) => {
        paste! {
            /// Reduces along `axis` and removes that axis from the result.
            pub fn $fn_name<D: Dim>(&self, axis: D) -> Result<Self> {
                let op_name = stringify!($fn_name);
                let (data, kept_dims) = self.compute_reduec_axis_op(axis, $reduce::op, op_name)?;
                let grad_meta = T::AutogradMeta::on_reduce_op(self, &kept_dims, crate::ReduceOp::$op);
                // The raw result still carries the reduced axis (length 1); drop it.
                Self::from_storage(data, kept_dims, grad_meta).squeeze(axis)
            }

            /// Reduces along `axis`, keeping that axis in the result with length 1.
            pub fn [< $fn_name _keepdim >]<D: Dim>(&self, axis: D) -> Result<Self> {
                let op_name = stringify!([< $fn_name _keepdim >]);
                let (data, kept_dims) = self.compute_reduec_axis_op(axis, $reduce::op, op_name)?;
                let grad_meta = T::AutogradMeta::on_reduce_op(self, &kept_dims, crate::ReduceOp::$op);
                let reduced = Self::from_storage(data, kept_dims, grad_meta);
                Ok(reduced)
            }

            /// Reduces over every element by flattening first and reducing axis 0.
            pub fn [< $fn_name _all >](&self) -> Result<Self> {
                let flat = self.flatten_all()?;
                flat.$fn_name(0)
            }
        }
    };
}
28
29impl<T: NumDType> Tensor<T> {
30 reduce_impl!(sum, ReduceSum, Sum);
31 reduce_impl!(min, ReduceMin, Min);
32 reduce_impl!(max, ReduceMax, Max);
33 reduce_impl!(mean, ReduceMean, Mean);
34
35 pub fn var_keepdim<D: Dim>(&self, axis: D) -> Result<Self> {
36 let mean = self.mean_keepdim(axis)?; let delta = self.broadcast_sub(&mean)?; let delta_pow = &delta * δ delta_pow.mean_keepdim(axis)
41 }
42
43 pub fn var_unbiased_keepdim<D: Dim>(&self, axis: D) -> Result<Self> {
44 let n = T::from_usize(self.dim(axis)?);
45 let biased_var = self.var_keepdim(axis)?;
46
47 let correction = n / (n - T::one());
49 Ok(correction * biased_var)
50 }
51
52 pub fn var<D: Dim>(&self, axis: D) -> Result<Self> {
53 let v = self.var_keepdim(axis)?;
54 let v = v.squeeze(axis)?;
55 Ok(v)
56 }
57
58 pub fn var_unbiased<D: Dim>(&self, axis: D) -> Result<Self> {
59 let v = self.var_unbiased_keepdim(axis)?;
60 let v = v.squeeze(axis)?;
61 Ok(v)
62 }
63
64 pub fn var_all(&self) -> Result<Self> {
65 self.flatten_all()?.var(0)
66 }
67
68 pub fn var_unbiased_all(&self) -> Result<Self> {
69 self.flatten_all()?.var_unbiased(0)
70 }
71
72 pub fn argmin_keepdim<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
73 let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMin::op, "argmin")?;
74 Ok(Tensor::from_storage(storage, dims, Default::default()))
75 }
76
77 pub fn argmin<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
78 let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMin::op, "argmin_keepdim")?;
79 let res = Tensor::from_storage(storage, dims, Default::default());
80 res.squeeze(axis)
81 }
82
83 pub fn argmax_keepdim<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
84 let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMax::op, "argmax")?;
85 Ok(Tensor::from_storage(storage, dims, Default::default()))
86 }
87
88 pub fn argmax<D: Dim>(&self, axis: D) -> Result<Tensor<u32>> {
89 let (storage, dims) = self.compute_reduec_axis_op(axis, ReduceArgMax::op, "argmax_keepdim")?;
90 let res = Tensor::from_storage(storage, dims, Default::default());
91 res.squeeze(axis)
92 }
93}
94
95impl Tensor<bool> {
96 pub fn all(&self) -> crate::Result<bool> {
97 self.iter().map(|mut i| i.all(|a| a))
98 }
99
100 pub fn any(&self) -> crate::Result<bool> {
101 self.iter().map(|mut i| i.any(|a| a))
102 }
103
104 pub fn all_axis<D: Dim>(&self, axis: D) -> Result<Tensor<bool>> {
105 self.reduec_axis_op(axis, ReduceAll::op, Default::default(), "all")
106 }
107
108 pub fn any_axis<D: Dim>(&self, axis: D) -> Result<Tensor<bool>> {
109 self.reduec_axis_op(axis, ReduceAny::op, Default::default(), "any")
110 }
111}
112
113impl<T: WithDType> Tensor<T> {
114 fn reduec_axis_op<'a, F, R: WithDType, D: Dim>(&'a self, reduce_dim: D, f: F, meta: R::AutogradMeta, op_name: &'static str) -> Result<Tensor<R>>
115 where
116 F: Fn(&mut DimArrayIter<'a, T>) -> R
117 {
118 let (storage, shape) = self.compute_reduec_axis_op(reduce_dim, f, op_name)?;
119 Ok(Tensor::<R>::from_storage(storage, shape, meta))
120 }
121
122 fn compute_reduec_axis_op<'a, F, R: WithDType, D: Dim>(&'a self, reduce_dim: D, f: F, op_name: &'static str) -> Result<(Storage<R>, Vec<usize>)>
123 where
124 F: Fn(&mut DimArrayIter<'a, T>) -> R
125 {
126 let reduce_dim = reduce_dim.to_index(self.shape(), op_name)?;
127 assert!(reduce_dim < self.layout().dims().len());
128 let reduce_dim_stride = self.layout().stride()[reduce_dim];
129 let reduce_dim_size = self.layout().dims()[reduce_dim];
130
131 let dst_len = self.layout().element_count() / reduce_dim_size;
132 let mut dst: Vec<R> = Vec::with_capacity(dst_len);
133 let dst_to_set = dst.spare_capacity_mut();
134
135 let layout = self.layout().narrow(reduce_dim, 0, 1)?;
136 for (dst_index, src_index) in layout.storage_indices().enumerate() {
137 let arr: DimArray<'_, T> = DimArray {
138 src: self.storage_ref(src_index)?,
139 size: reduce_dim_size,
140 stride: reduce_dim_stride
141 };
142 let mut iter: DimArrayIter<'_, T> = arr.into_iter();
143 dst_to_set[dst_index].write(f(&mut iter));
144 }
145 unsafe { dst.set_len(dst_len) };
146
147 let storage = Storage::new(dst);
148 let mut shape = self.dims().to_vec();
149 shape[reduce_dim] = 1;
151
152 Ok((storage, shape))
153 }
154}
155
156pub trait ReduceOp<D: WithDType> {
157 type Output: WithDType;
158 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output;
159}
160
161pub struct ReduceAll;
162impl ReduceOp<bool> for ReduceAll {
163 type Output = bool;
164 fn op(arr: &mut DimArrayIter<'_, bool>) -> Self::Output {
165 arr.into_iter().all(|b| b)
166 }
167}
168
169pub struct ReduceAny;
170impl ReduceOp<bool> for ReduceAny {
171 type Output = bool;
172 fn op(arr: &mut DimArrayIter<'_, bool>) -> Self::Output {
173 arr.into_iter().any(|b| b)
174 }
175}
176
177pub struct ReduceSum;
178impl<D: NumDType> ReduceOp<D> for ReduceSum {
179 type Output = D;
180 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
181 arr.into_iter().sum::<D>()
182 }
183}
184
185pub struct ReduceMean;
186impl<D: NumDType> ReduceOp<D> for ReduceMean {
187 type Output = D;
188 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
189 let len = arr.len();
190 arr.into_iter().sum::<D>() / D::from_usize(len)
191 }
192}
193
194pub struct ReduceVar;
195impl<D: NumDType> ReduceOp<D> for ReduceVar {
196 type Output = D;
197 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
198 let len = arr.len();
199 if len == 0 { return D::zero(); }
200
201 let mean = ReduceMean::op(arr);
202
203 arr.reset();
204 let mut sum_sq_diff = D::zero();
205 while let Some(v) = arr.next() {
206 let diff = v - mean;
207 sum_sq_diff += diff * diff;
208 }
209
210 sum_sq_diff / D::from_usize(len)
211 }
212}
213
214pub struct ReduceProduct;
215impl<D: NumDType> ReduceOp<D> for ReduceProduct {
216 type Output = D;
217 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
218 arr.into_iter().product::<D>()
219 }
220}
221
222pub struct ReduceMin;
223impl<D: NumDType> ReduceOp<D> for ReduceMin {
224 type Output = D;
225 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
226 arr.into_iter()
227 .reduce(|a, b| D::minimum(a, b)).unwrap()
228 }
229}
230
231pub struct ReduceArgMin;
232impl<D: NumDType> ReduceOp<D> for ReduceArgMin {
233 type Output = u32;
234 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
235 arr.into_iter()
236 .enumerate()
237 .reduce(|(ia, a), (ib, b)| {
238 if a.partial_cmp(&b) == Some(std::cmp::Ordering::Less) {
239 (ia, a)
240 } else {
241 (ib, b)
242 }
243 }).unwrap().0 as u32
244 }
245}
246
247pub struct ReduceMax;
248impl<D: NumDType> ReduceOp<D> for ReduceMax {
249 type Output = D;
250 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
251 arr.into_iter()
252 .reduce(|a, b| D::maximum(a, b)).unwrap()
253 }
254}
255
256pub struct ReduceArgMax;
257impl<D: NumDType> ReduceOp<D> for ReduceArgMax {
258 type Output = u32;
259 fn op(arr: &mut DimArrayIter<'_, D>) -> Self::Output {
260 arr.into_iter()
261 .enumerate()
262 .reduce(|(ia, a), (ib, b)| {
263 if a.partial_cmp(&b) == Some(std::cmp::Ordering::Greater) {
264 (ia, a)
265 } else {
266 (ib, b)
267 }
268 }).unwrap().0 as u32
269 }
270}
271
272pub struct DimArray<'a, T> {
273 src: StorageRef<'a, T>,
274 size: usize,
275 stride: usize
276}
277
278impl<'a, T: WithDType> DimArray<'a, T> {
279 pub fn get(&self, index: usize) -> T {
280 self.src.get_unchecked(index * self.stride)
281 }
282
283 #[allow(unused)]
284 pub fn to_vec(&self) -> Vec<T> {
285 let mut v = vec![];
286 for i in 0..self.size {
287 v.push(self.get(i));
288 }
289 v
290 }
291}
292
293impl<'a, T: WithDType> IntoIterator for DimArray<'a, T> {
294 type IntoIter = DimArrayIter<'a, T>;
295 type Item = T;
296 fn into_iter(self) -> Self::IntoIter {
297 DimArrayIter::<'a, T> {
298 array: self,
299 index: 0,
300 }
301 }
302}
303
304pub struct DimArrayIter<'a, T> {
305 array: DimArray<'a, T>,
306 index: usize,
307}
308
309impl<'a, T: WithDType> Iterator for DimArrayIter<'a, T> {
310 type Item = T;
311 fn next(&mut self) -> Option<T> {
312 if self.index >= self.array.size {
313 None
314 } else {
315 let index = self.index;
316 self.index += 1;
317 Some(self.array.get(index))
318 }
319 }
320}
321
322impl<'a, T: WithDType> ExactSizeIterator for DimArrayIter<'a, T> {
323 fn len(&self) -> usize {
324 self.array.size
325 }
326}
327
328impl<'a, T: WithDType> ResettableIterator for DimArrayIter<'a, T> {
329 fn reset(&mut self) {
330 self.index = 0
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two tensors agree within the tolerances used throughout this suite.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!($actual.allclose(&$expected, 1e-5, 1e-8).unwrap())
        };
    }

    #[test]
    fn test_sum_matrix_axis0() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 4, 5]]).unwrap();
        let s = arr.sum(0).unwrap();
        let expected = Tensor::new(&[4, 6, 8]).unwrap();
        assert_close!(s, expected);
    }

    #[test]
    fn test_sum_matrix_axis1() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 4, 5]]).unwrap();
        let s = arr.sum(1).unwrap();
        let expected = Tensor::new(&[6, 12]).unwrap();
        assert_close!(s, expected);
    }

    #[test]
    fn test_sum_ones_axis() {
        let arr = Tensor::ones((2, 3)).unwrap();
        let s0 = arr.sum(0).unwrap();
        let s1 = arr.sum(1).unwrap();
        let expected0 = Tensor::new(&[2, 2, 2]).unwrap();
        let expected1 = Tensor::new(&[3, 3]).unwrap();

        assert_close!(s0, expected0);
        assert_close!(s1, expected1);
    }

    #[test]
    fn test_min_matrix_axis0() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.min(0).unwrap();
        let expected = Tensor::new(&[1, 1, 0]).unwrap();
        assert_close!(m, expected);
    }

    #[test]
    fn test_max_matrix_axis1() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.max(1).unwrap();
        let expected = Tensor::new(&[3, 3]).unwrap();
        assert_close!(m, expected);
    }

    #[test]
    fn test_argmin_matrix_axis0() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.argmin(0).unwrap();
        let expected = Tensor::new(&[0, 1, 1]).unwrap();
        assert_close!(m, expected);
    }

    #[test]
    fn test_sum_all() {
        let arr = Tensor::new(&[[1, 2], [3, 4]]).unwrap();
        let s = arr.sum_all().unwrap();
        let expected = Tensor::new(10).unwrap();
        assert_close!(s, expected);
    }

    #[test]
    fn test_mean_all() {
        let arr = Tensor::new(&[[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let m = arr.mean_all().unwrap();
        let expected = Tensor::new(2.5).unwrap();
        assert_close!(m, expected);
    }

    #[test]
    fn test_min_max_all() {
        let arr = Tensor::new(&[[10, 2, 5], [8, 1, 9]]).unwrap();

        let min_val = arr.min_all().unwrap();
        let max_val = arr.max_all().unwrap();

        let expected_min = Tensor::new(1).unwrap();
        let expected_max = Tensor::new(10).unwrap();

        assert_close!(min_val, expected_min);
        assert_close!(max_val, expected_max);
    }

    #[test]
    fn test_argmax_matrix_axis1() {
        let arr = Tensor::new(&[[1, 2, 3], [3, 1, 0]]).unwrap();
        let m = arr.argmax(1).unwrap();
        let expected = Tensor::new(&[2, 0]).unwrap();
        assert_close!(m, expected);
    }

    #[test]
    fn test_reductions_with_negatives() {
        let arr = Tensor::new(&[[-2.0, 0.0, 2.0]]).unwrap();

        assert_close!(arr.sum_all().unwrap(), Tensor::new(0.0).unwrap());
        assert_close!(arr.mean_all().unwrap(), Tensor::new(0.0).unwrap());

        // Biased variance of [-2, 0, 2]: (4 + 0 + 4) / 3.
        let expected_var = Tensor::new(8.0 / 3.0).unwrap();
        assert_close!(arr.var_all().unwrap(), expected_var);
    }
}