cubecl_reduce/instructions/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6pub trait ReduceFamily: Send + Sync + 'static + std::fmt::Debug {
7 type Instruction<P: ReducePrecision>: ReduceInstruction<P, Config = Self::Config>;
8 type Config: CubeComptime + Send + Sync;
9}
10
11#[derive(CubeType)]
12pub struct ReduceRequirements {
13 #[cube(comptime)]
14 pub coordinates: bool,
15}
16
17#[cube]
26pub trait ReduceInstruction<P: ReducePrecision>:
27 Send + Sync + 'static + std::fmt::Debug + CubeType
28{
29 type Config: CubeComptime + Send + Sync;
30
31 fn requirements(this: &Self) -> ReduceRequirements;
33
34 type AccumulatorItem: CubeType;
37
38 type SharedAccumulator: SharedAccumulator<Item = Self::AccumulatorItem>;
42
43 fn from_config(#[comptime] config: Self::Config) -> Self;
44 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI>;
47
48 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem;
51
52 fn assign_accumulator(
56 this: &Self,
57 destination: &mut Self::AccumulatorItem,
58 source: &Self::AccumulatorItem,
59 );
60
61 fn reduce(
64 this: &Self,
65 accumulator: &Self::AccumulatorItem,
66 item: Line<P::EI>,
67 coordinate: ReduceCoordinate,
68 #[comptime] use_planes: bool,
69 ) -> Self::AccumulatorItem;
70
71 fn fuse_accumulators(
73 this: &Self,
74 lhs: Self::AccumulatorItem,
75 rhs: Self::AccumulatorItem,
76 ) -> Self::AccumulatorItem;
77
78 fn merge_line<Out: Numeric>(
80 this: &Self,
81 accumulator: Self::AccumulatorItem,
82 shape_axis_reduce: u32,
83 ) -> Out;
84
85 fn to_output_perpendicular<Out: Numeric>(
87 this: &Self,
88 accumulator: Self::AccumulatorItem,
89 shape_axis_reduce: u32,
90 ) -> Line<Out>;
91}
92
93#[derive(CubeType)]
94pub enum ReduceCoordinate {
95 Required(Line<u32>),
96 NotRequired,
97}
98
99#[cube]
101pub trait SharedAccumulator: CubeType + Send + Sync + 'static {
102 type Item: CubeType;
103
104 fn allocate(
105 #[comptime] length: u32,
106 #[comptime] line_size: u32,
107 #[comptime] _coordinate: bool,
108 ) -> Self;
109
110 fn read(accumulator: &Self, index: u32) -> Self::Item;
111
112 fn write(accumulator: &mut Self, index: u32, item: Self::Item);
113}
114
115#[cube]
116impl<In: Numeric> SharedAccumulator for SharedMemory<Line<In>> {
117 type Item = Line<In>;
118
119 fn allocate(
120 #[comptime] length: u32,
121 #[comptime] line_size: u32,
122 #[comptime] _coordinate: bool,
123 ) -> Self {
124 SharedMemory::new_lined(length, line_size)
125 }
126
127 fn read(accumulator: &Self, index: u32) -> Self::Item {
128 accumulator[index]
129 }
130
131 fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
132 accumulator[index] = item;
133 }
134}
135
136#[derive(CubeType)]
138pub struct ArgAccumulator<N: Numeric> {
139 pub elements: SharedMemory<Line<N>>,
140 pub args: SharedMemory<Line<u32>>,
141}
142
143#[cube]
144impl<In: Numeric> SharedAccumulator for ArgAccumulator<In> {
145 type Item = (Line<In>, Line<u32>);
146
147 fn allocate(
148 #[comptime] length: u32,
149 #[comptime] line_size: u32,
150 #[comptime] _coordinate: bool,
151 ) -> Self {
152 ArgAccumulator::<In> {
153 elements: SharedMemory::new_lined(length, line_size),
154 args: SharedMemory::new_lined(length, line_size),
155 }
156 }
157
158 fn read(accumulator: &Self, index: u32) -> Self::Item {
159 (accumulator.elements[index], accumulator.args[index])
160 }
161
162 fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
163 accumulator.elements[index] = item.0;
164 accumulator.args[index] = item.1;
165 }
166}
167
168#[cube]
169pub fn reduce_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
170 inst: &R,
171 accumulator: &mut R::AccumulatorItem,
172 item: Line<P::EI>,
173 coordinate: ReduceCoordinate,
174 #[comptime] use_planes: bool,
175) {
176 let reduction = &R::reduce(inst, accumulator, item, coordinate, use_planes);
177 R::assign_accumulator(inst, accumulator, reduction);
178}
179
180#[cube]
181pub fn reduce_shared_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
182 inst: &R,
183 accumulator: &mut R::SharedAccumulator,
184 index: u32,
185 item: Line<P::EI>,
186 coordinate: ReduceCoordinate,
187 #[comptime] use_planes: bool,
188) {
189 let acc_item = R::SharedAccumulator::read(accumulator, index);
190 let reduction = R::reduce(inst, &acc_item, item, coordinate, use_planes);
191 R::SharedAccumulator::write(accumulator, index, reduction);
192}
193
194#[cube]
195pub fn fuse_accumulator_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
196 inst: &R,
197 accumulator: &mut R::SharedAccumulator,
198 destination: u32,
199 origin: u32,
200) {
201 let fused = R::fuse_accumulators(
202 inst,
203 R::SharedAccumulator::read(accumulator, destination),
204 R::SharedAccumulator::read(accumulator, origin),
205 );
206 R::SharedAccumulator::write(accumulator, destination, fused);
207}