// ferrite/tensor/ops/reduction.rs

1use std::rc::Rc;
2use crate::{DeviceStorage, MeanGrad, ProductGrad, Storage, SumGrad, Tensor, match_storage, match_storage_assign};
3
/// Reduction operations that collapse the elements of a value.
///
/// In this module the trait is implemented for both `Storage` and `Tensor`.
/// Every method borrows `self` and returns a freshly built value of the same
/// type; nothing is reduced in place.
pub trait ReductionOps {
  /// Sum reduction over all elements.
  fn sum(&self) -> Self;

  /// Sum reduction along dimension `axis`.
  ///
  /// Whether the reduced axis is kept or dropped is decided by the
  /// implementor — NOTE(review): confirm against the storage backends.
  fn sum_axis(&self, axis: usize) -> Self;

  /// Product reduction over all elements.
  fn product(&self) -> Self;

  /// Mean reduction over all elements.
  fn mean(&self) -> Self;
}
12impl ReductionOps for Storage {
13  fn sum(&self) -> Self {
14    match_storage!(unary self, sum)
15  }
16
17  fn sum_axis(&self, axis: usize) -> Self {
18    match_storage!(unary self, sum_axis, axis)
19  }
20
21  fn product(&self) -> Self {
22    match_storage!(unary self, product)
23  }
24
25  fn mean(&self) -> Self {
26    match_storage!(unary self, mean)
27  }
28}
29
30impl ReductionOps for Tensor {
31
32  fn sum(&self) -> Self {
33    let tensor = self.tensor().sum();
34    let requires_grad = *self.requires_grad();
35    let mut result = Tensor::new(tensor, self.device(), requires_grad);
36    
37    if requires_grad {
38      result.set_grad_fn(Some(Rc::new(SumGrad::new(self, &result))));
39    }
40    
41    result
42  }
43
44  fn sum_axis(&self, axis: usize) -> Self {
45    let tensor = self.tensor().sum_axis(axis);
46    let requires_grad = *self.requires_grad();
47    let mut result = Tensor::new(tensor, self.device(), requires_grad);
48    
49    result
50  }
51
52  fn mean(&self) -> Self {
53    let tensor = self.tensor().mean();
54    let requires_grad = *self.requires_grad();
55    let mut result = Tensor::new(tensor, self.device(), requires_grad);
56    
57    if requires_grad {
58      result.set_grad_fn(Some(Rc::new(MeanGrad::new(self, &result))));
59    }
60    
61    result
62  }
63
64  fn product(&self) -> Self {
65    let tensor = self.tensor().product();
66    let requires_grad = *self.requires_grad();
67    let mut result = Tensor::new(tensor, self.device(), requires_grad);
68    
69    if requires_grad {
70      result.set_grad_fn(Some(Rc::new(ProductGrad::new(self, &result))));
71    }
72    
73    result
74  }
75
76}