// cubecl_reduce/instructions/sum.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements};
5
6#[derive(Debug, CubeType, Clone)]
7pub struct Sum {}
8
9impl ReduceFamily for Sum {
10    type Instruction<In: Numeric> = Self;
11    type Config = ();
12}
13
14#[cube]
15impl<In: Numeric> ReduceInstruction<In> for Sum {
16    type AccumulatorItem = Line<In>;
17    type SharedAccumulator = SharedMemory<Line<In>>;
18    type Config = ();
19
20    fn requirements(_this: &Self) -> ReduceRequirements {
21        ReduceRequirements { coordinates: false }
22    }
23
24    fn from_config(_config: Self::Config) -> Self {
25        Sum {}
26    }
27    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
28        Line::empty(line_size).fill(In::from_int(0))
29    }
30
31    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
32        Self::null_input(this, line_size)
33    }
34
35    fn assign_accumulator(
36        _this: &Self,
37        destination: &mut Self::AccumulatorItem,
38        source: &Self::AccumulatorItem,
39    ) {
40        *destination = *source;
41    }
42
43    fn reduce(
44        _this: &Self,
45        accumulator: &Self::AccumulatorItem,
46        item: Line<In>,
47        _coordinate: ReduceCoordinate,
48        #[comptime] use_planes: bool,
49    ) -> Self::AccumulatorItem {
50        if use_planes {
51            *accumulator + plane_sum(item)
52        } else {
53            *accumulator + item
54        }
55    }
56
57    fn fuse_accumulators(
58        _this: &Self,
59        lhs: Self::AccumulatorItem,
60        rhs: Self::AccumulatorItem,
61    ) -> Self::AccumulatorItem {
62        lhs + rhs
63    }
64
65    fn merge_line<Out: Numeric>(
66        _this: &Self,
67        accumulator: Self::AccumulatorItem,
68        _shape_axis_reduce: u32,
69    ) -> Out {
70        let mut sum = In::from_int(0);
71        #[unroll]
72        for k in 0..accumulator.size() {
73            sum += accumulator[k];
74        }
75        Out::cast_from(sum)
76    }
77
78    fn to_output_perpendicular<Out: Numeric>(
79        _this: &Self,
80        accumulator: Self::AccumulatorItem,
81        _shape_axis_reduce: u32,
82    ) -> Line<Out> {
83        Line::cast_from(accumulator)
84    }
85}