// ferrite/tensor/device/cpu/kernels/reduction.rs

1use crate::*;
2use ndarray::prelude::*;
3
4
5impl ReductionOps for CpuStorage {
6  fn sum(&self) -> Self {
7    let data: f32 = self.data().read().unwrap().iter().sum();
8    CpuStorage::from_ndarray(&array![data], None, None)
9  }
10
11  fn sum_axis(&self, axis: usize) -> Self {
12    // Compute the sizes of the "outer" dimensions, the dimension to sum over,
13    // and the "inner" (trailing) dimensions.
14    let outer: usize = self.shape()[..axis].iter().product();
15    let axis_len: usize = self.shape()[axis];
16    let trailing: usize = self.shape()[axis+1..].iter().product();
17    
18    // The new shape is the original shape with the summing axis removed.
19    let mut new_shape = self.shape().clone();
20    new_shape.remove(axis);
21
22    // Prepare a vector for the summed data.
23    let mut new_data = vec![0.0; outer * trailing];
24    
25    // Borrow the underlying data.
26    let binding = self.data();
27    let data_ref = binding.read().unwrap();
28    
29    // Iterate over the "outer" blocks and the "inner" trailing dimensions.
30    // For each such location, sum over the elements along the `axis`.
31    for i in 0..outer {
32      for k in 0..trailing {
33        let mut sum = 0.0;
34        for j in 0..axis_len {
35          // In contiguous (row-major) layout, the index is computed as:
36          //   index = offset + i * (axis_len * trailing) + j * trailing + k
37          let index = self.offset() + i * (axis_len * trailing) + j * trailing + k;
38          sum += data_ref[index];
39        }
40        new_data[i * trailing + k] = sum;
41      }
42    }
43    
44    // Construct a new CpuStorage with the summed data and the new shape.
45    CpuStorage::new(new_data, new_shape)
46  }
47
48  fn product(&self) -> Self {
49    let data: f32 = self.data().read().unwrap().iter().sum();
50    CpuStorage::from_ndarray(&array![data], None, None)
51  }
52
53  fn mean(&self) -> Self {
54    let data: f32 = self.data().read().unwrap().iter().sum::<f32>() / self.data().read().unwrap().len() as f32;
55    CpuStorage::from_ndarray(&array![data], None, None)
56  }
57
58  
59}