// cubecl_reduce/instructions/base.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
3
4pub trait ReduceFamily: Send + Sync + 'static + std::fmt::Debug {
5 type Instruction<In: Numeric>: ReduceInstruction<In, Config = Self::Config>;
6 type Config: CubeComptime + Send + Sync;
7}
8
9#[derive(CubeType)]
10pub struct ReduceRequirements {
11 #[cube(comptime)]
12 pub coordinates: bool,
13}
14
15#[cube]
24pub trait ReduceInstruction<In: Numeric>:
25 Send + Sync + 'static + std::fmt::Debug + CubeType
26{
27 type Config: CubeComptime + Send + Sync;
28
29 fn requirements(this: &Self) -> ReduceRequirements;
31
32 type AccumulatorItem: CubeType;
35
36 type SharedAccumulator: SharedAccumulator<In, Item = Self::AccumulatorItem>;
40
41 fn from_config(#[comptime] config: Self::Config) -> Self;
42 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In>;
45
46 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem;
49
50 fn assign_accumulator(
54 this: &Self,
55 destination: &mut Self::AccumulatorItem,
56 source: &Self::AccumulatorItem,
57 );
58
59 fn reduce(
62 this: &Self,
63 accumulator: &Self::AccumulatorItem,
64 item: Line<In>,
65 coordinate: ReduceCoordinate,
66 #[comptime] use_planes: bool,
67 ) -> Self::AccumulatorItem;
68
69 fn fuse_accumulators(
71 this: &Self,
72 lhs: Self::AccumulatorItem,
73 rhs: Self::AccumulatorItem,
74 ) -> Self::AccumulatorItem;
75
76 fn merge_line<Out: Numeric>(
78 this: &Self,
79 accumulator: Self::AccumulatorItem,
80 shape_axis_reduce: u32,
81 ) -> Out;
82
83 fn to_output_perpendicular<Out: Numeric>(
85 this: &Self,
86 accumulator: Self::AccumulatorItem,
87 shape_axis_reduce: u32,
88 ) -> Line<Out>;
89}
90
91#[derive(CubeType)]
92pub enum ReduceCoordinate {
93 Required(Line<u32>),
94 NotRequired,
95}
96
97#[cube]
99pub trait SharedAccumulator<In: Numeric>: CubeType + Send + Sync + 'static {
100 type Item: CubeType;
101
102 fn allocate(
103 #[comptime] length: u32,
104 #[comptime] line_size: u32,
105 #[comptime] _coordinate: bool,
106 ) -> Self;
107
108 fn read(accumulator: &Self, index: u32) -> Self::Item;
109
110 fn write(accumulator: &mut Self, index: u32, item: Self::Item);
111}
112
113#[cube]
114impl<In: Numeric> SharedAccumulator<In> for SharedMemory<Line<In>> {
115 type Item = Line<In>;
116
117 fn allocate(
118 #[comptime] length: u32,
119 #[comptime] line_size: u32,
120 #[comptime] _coordinate: bool,
121 ) -> Self {
122 SharedMemory::new_lined(length, line_size)
123 }
124
125 fn read(accumulator: &Self, index: u32) -> Self::Item {
126 accumulator[index]
127 }
128
129 fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
130 accumulator[index] = item;
131 }
132}
133
134#[derive(CubeType)]
136pub struct ArgAccumulator<N: Numeric> {
137 pub elements: SharedMemory<Line<N>>,
138 pub args: SharedMemory<Line<u32>>,
139}
140
141#[cube]
142impl<In: Numeric> SharedAccumulator<In> for ArgAccumulator<In> {
143 type Item = (Line<In>, Line<u32>);
144
145 fn allocate(
146 #[comptime] length: u32,
147 #[comptime] line_size: u32,
148 #[comptime] _coordinate: bool,
149 ) -> Self {
150 ArgAccumulator::<In> {
151 elements: SharedMemory::new_lined(length, line_size),
152 args: SharedMemory::new_lined(length, line_size),
153 }
154 }
155
156 fn read(accumulator: &Self, index: u32) -> Self::Item {
157 (accumulator.elements[index], accumulator.args[index])
158 }
159
160 fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
161 accumulator.elements[index] = item.0;
162 accumulator.args[index] = item.1;
163 }
164}
165
166#[cube]
167pub fn reduce_inplace<In: Numeric, R: ReduceInstruction<In>>(
168 inst: &R,
169 accumulator: &mut R::AccumulatorItem,
170 item: Line<In>,
171 coordinate: ReduceCoordinate,
172 #[comptime] use_planes: bool,
173) {
174 let reduction = &R::reduce(inst, accumulator, item, coordinate, use_planes);
175 R::assign_accumulator(inst, accumulator, reduction);
176}
177
178#[cube]
179pub fn reduce_shared_inplace<In: Numeric, R: ReduceInstruction<In>>(
180 inst: &R,
181 accumulator: &mut R::SharedAccumulator,
182 index: u32,
183 item: Line<In>,
184 coordinate: ReduceCoordinate,
185 #[comptime] use_planes: bool,
186) {
187 let acc_item = R::SharedAccumulator::read(accumulator, index);
188 let reduction = R::reduce(inst, &acc_item, item, coordinate, use_planes);
189 R::SharedAccumulator::write(accumulator, index, reduction);
190}
191
192#[cube]
193pub fn fuse_accumulator_inplace<In: Numeric, R: ReduceInstruction<In>>(
194 inst: &R,
195 accumulator: &mut R::SharedAccumulator,
196 destination: u32,
197 origin: u32,
198) {
199 let fused = R::fuse_accumulators(
200 inst,
201 R::SharedAccumulator::read(accumulator, destination),
202 R::SharedAccumulator::read(accumulator, origin),
203 );
204 R::SharedAccumulator::write(accumulator, destination, fused);
205}