cubecl_reduce/instructions/
sum.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{Reduce, ReduceInstruction};
5
/// Marker type selecting the sum reduction.
///
/// Carries no data; the behavior lives in the trait implementations on this type.
#[derive(Debug)]
pub struct Sum;
8
9impl Reduce for Sum {
10 type Instruction<In: Numeric> = Self;
11}
12
13#[cube]
14impl<In: Numeric> ReduceInstruction<In> for Sum {
15 type AccumulatorItem = Line<In>;
16 type SharedAccumulator = SharedMemory<Line<In>>;
17
18 fn null_input(#[comptime] line_size: u32) -> Line<In> {
19 Line::empty(line_size).fill(In::from_int(0))
20 }
21
22 fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
23 Self::null_input(line_size)
24 }
25
26 fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
27 *destination = *source;
28 }
29
30 fn reduce(
31 accumulator: &Self::AccumulatorItem,
32 item: Line<In>,
33 _coordinate: Line<u32>,
34 #[comptime] use_planes: bool,
35 ) -> Self::AccumulatorItem {
36 if use_planes {
37 *accumulator + plane_sum(item)
38 } else {
39 *accumulator + item
40 }
41 }
42
43 fn fuse_accumulators(
44 lhs: Self::AccumulatorItem,
45 rhs: Self::AccumulatorItem,
46 ) -> Self::AccumulatorItem {
47 lhs + rhs
48 }
49
50 fn merge_line<Out: Numeric>(
51 accumulator: Self::AccumulatorItem,
52 _shape_axis_reduce: u32,
53 ) -> Out {
54 let mut sum = In::from_int(0);
55 #[unroll]
56 for k in 0..accumulator.size() {
57 sum += accumulator[k];
58 }
59 Out::cast_from(sum)
60 }
61
62 fn to_output_perpendicular<Out: Numeric>(
63 accumulator: Self::AccumulatorItem,
64 _shape_axis_reduce: u32,
65 ) -> Line<Out> {
66 Line::cast_from(accumulator)
67 }
68}