cubecl_reduce/instructions/
sum.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements};
5
6#[derive(Debug, CubeType, Clone)]
7pub struct Sum {}
8
9impl ReduceFamily for Sum {
10 type Instruction<In: Numeric> = Self;
11 type Config = ();
12}
13
14#[cube]
15impl<In: Numeric> ReduceInstruction<In> for Sum {
16 type AccumulatorItem = Line<In>;
17 type SharedAccumulator = SharedMemory<Line<In>>;
18 type Config = ();
19
20 fn requirements(_this: &Self) -> ReduceRequirements {
21 ReduceRequirements { coordinates: false }
22 }
23
24 fn from_config(_config: Self::Config) -> Self {
25 Sum {}
26 }
27 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
28 Line::empty(line_size).fill(In::from_int(0))
29 }
30
31 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
32 Self::null_input(this, line_size)
33 }
34
35 fn assign_accumulator(
36 _this: &Self,
37 destination: &mut Self::AccumulatorItem,
38 source: &Self::AccumulatorItem,
39 ) {
40 *destination = *source;
41 }
42
43 fn reduce(
44 _this: &Self,
45 accumulator: &Self::AccumulatorItem,
46 item: Line<In>,
47 _coordinate: ReduceCoordinate,
48 #[comptime] use_planes: bool,
49 ) -> Self::AccumulatorItem {
50 if use_planes {
51 *accumulator + plane_sum(item)
52 } else {
53 *accumulator + item
54 }
55 }
56
57 fn fuse_accumulators(
58 _this: &Self,
59 lhs: Self::AccumulatorItem,
60 rhs: Self::AccumulatorItem,
61 ) -> Self::AccumulatorItem {
62 lhs + rhs
63 }
64
65 fn merge_line<Out: Numeric>(
66 _this: &Self,
67 accumulator: Self::AccumulatorItem,
68 _shape_axis_reduce: u32,
69 ) -> Out {
70 let mut sum = In::from_int(0);
71 #[unroll]
72 for k in 0..accumulator.size() {
73 sum += accumulator[k];
74 }
75 Out::cast_from(sum)
76 }
77
78 fn to_output_perpendicular<Out: Numeric>(
79 _this: &Self,
80 accumulator: Self::AccumulatorItem,
81 _shape_axis_reduce: u32,
82 ) -> Line<Out> {
83 Line::cast_from(accumulator)
84 }
85}