ferrite/tensor/ops/
reduction.rs1use std::rc::Rc;
2use crate::{DeviceStorage, MeanGrad, ProductGrad, Storage, SumGrad, Tensor, match_storage, match_storage_assign};
3
/// Reduction operations, implemented in this module for both the raw
/// `Storage` backend and the autograd-aware `Tensor` wrapper.
///
/// Every method only borrows `self` and hands back a freshly built value of
/// the implementing type.
pub trait ReductionOps {
    /// Summation reduction over the whole value.
    ///
    /// NOTE(review): whether the output is a scalar or keeps the reduced
    /// dimensions is decided by the implementor's backend — confirm there.
    fn sum(&self) -> Self;

    /// Summation reduction along the single dimension `axis`.
    fn sum_axis(&self, axis: usize) -> Self;

    /// Multiplicative (product) reduction over the whole value.
    fn product(&self) -> Self;

    /// Mean reduction over the whole value.
    fn mean(&self) -> Self;
}

12impl ReductionOps for Storage {
13 fn sum(&self) -> Self {
14 match_storage!(unary self, sum)
15 }
16
17 fn sum_axis(&self, axis: usize) -> Self {
18 match_storage!(unary self, sum_axis, axis)
19 }
20
21 fn product(&self) -> Self {
22 match_storage!(unary self, product)
23 }
24
25 fn mean(&self) -> Self {
26 match_storage!(unary self, mean)
27 }
28}
29
30impl ReductionOps for Tensor {
31
32 fn sum(&self) -> Self {
33 let tensor = self.tensor().sum();
34 let requires_grad = *self.requires_grad();
35 let mut result = Tensor::new(tensor, self.device(), requires_grad);
36
37 if requires_grad {
38 result.set_grad_fn(Some(Rc::new(SumGrad::new(self, &result))));
39 }
40
41 result
42 }
43
44 fn sum_axis(&self, axis: usize) -> Self {
45 let tensor = self.tensor().sum_axis(axis);
46 let requires_grad = *self.requires_grad();
47 let mut result = Tensor::new(tensor, self.device(), requires_grad);
48
49 result
50 }
51
52 fn mean(&self) -> Self {
53 let tensor = self.tensor().mean();
54 let requires_grad = *self.requires_grad();
55 let mut result = Tensor::new(tensor, self.device(), requires_grad);
56
57 if requires_grad {
58 result.set_grad_fn(Some(Rc::new(MeanGrad::new(self, &result))));
59 }
60
61 result
62 }
63
64 fn product(&self) -> Self {
65 let tensor = self.tensor().product();
66 let requires_grad = *self.requires_grad();
67 let mut result = Tensor::new(tensor, self.device(), requires_grad);
68
69 if requires_grad {
70 result.set_grad_fn(Some(Rc::new(ProductGrad::new(self, &result))));
71 }
72
73 result
74 }
75
76}