cubecl_reduce/instructions/
sum.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{Reduce, ReduceInstruction};
5
/// Sum reduction: combines input elements with `+`, starting from zero.
///
/// `Sum` is both the [`Reduce`] marker and its own [`ReduceInstruction`]
/// for every numeric input type.
#[derive(Debug)]
pub struct Sum;
8
9impl Reduce for Sum {
10    type Instruction<In: Numeric> = Self;
11}
12
13#[cube]
14impl<In: Numeric> ReduceInstruction<In> for Sum {
15    type AccumulatorItem = Line<In>;
16    type SharedAccumulator = SharedMemory<Line<In>>;
17
18    fn null_input(#[comptime] line_size: u32) -> Line<In> {
19        Line::empty(line_size).fill(In::from_int(0))
20    }
21
22    fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
23        Self::null_input(line_size)
24    }
25
26    fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
27        *destination = *source;
28    }
29
30    fn reduce(
31        accumulator: &Self::AccumulatorItem,
32        item: Line<In>,
33        _coordinate: Line<u32>,
34        #[comptime] use_planes: bool,
35    ) -> Self::AccumulatorItem {
36        if use_planes {
37            *accumulator + plane_sum(item)
38        } else {
39            *accumulator + item
40        }
41    }
42
43    fn fuse_accumulators(
44        lhs: Self::AccumulatorItem,
45        rhs: Self::AccumulatorItem,
46    ) -> Self::AccumulatorItem {
47        lhs + rhs
48    }
49
50    fn merge_line<Out: Numeric>(
51        accumulator: Self::AccumulatorItem,
52        _shape_axis_reduce: u32,
53    ) -> Out {
54        let mut sum = In::from_int(0);
55        #[unroll]
56        for k in 0..accumulator.size() {
57            sum += accumulator[k];
58        }
59        Out::cast_from(sum)
60    }
61
62    fn to_output_perpendicular<Out: Numeric>(
63        accumulator: Self::AccumulatorItem,
64        _shape_axis_reduce: u32,
65    ) -> Line<Out> {
66        Line::cast_from(accumulator)
67    }
68}