1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand};
4
5use super::{
6 ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
7 ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
8};
9
/// Sum type over the concrete reduce instructions of this module.
///
/// Its `ReduceInstruction` implementation matches on the variant and forwards
/// every call to the wrapped instruction, so a single type can stand in for any
/// of them. Values are normally produced by `ReduceInstruction::from_config`
/// from a [`ReduceFnConfig`].
#[derive(Debug, CubeType, Clone)]
pub enum ReduceFn {
    /// Dispatches to [`Sum`].
    Sum(Sum),
    /// Dispatches to [`Prod`].
    Prod(Prod),
    /// Dispatches to [`Mean`] (built around a [`Sum`] by `from_config`).
    Mean(Mean),
    /// Dispatches to [`MaxAbs`].
    MaxAbs(MaxAbs),
    /// Dispatches to [`ArgMax`]; this variant requires coordinates.
    ArgMax(ArgMax),
    /// Dispatches to [`ArgMin`]; this variant requires coordinates.
    ArgMin(ArgMin),
    /// Dispatches to [`Max`].
    Max(Max),
    /// Dispatches to [`Min`].
    Min(Min),
}
21
/// Selects which [`ReduceFn`] variant `ReduceInstruction::from_config` builds.
///
/// It is passed to `from_config` as a `#[comptime]` argument.
#[derive_cube_comptime]
pub enum ReduceFnConfig {
    /// Builds [`ReduceFn::Sum`].
    Sum,
    /// Builds [`ReduceFn::Prod`].
    Prod,
    /// Builds [`ReduceFn::Mean`].
    Mean,
    /// Builds [`ReduceFn::MaxAbs`].
    MaxAbs,
    /// Builds [`ReduceFn::ArgMax`].
    ArgMax,
    /// Builds [`ReduceFn::ArgMin`].
    ArgMin,
    /// Builds [`ReduceFn::Max`].
    Max,
    /// Builds [`ReduceFn::Min`].
    Min,
}
33
34impl ReduceFamily for ReduceFn {
35 type Instruction<In: Numeric> = Self;
36 type Config = ReduceFnConfig;
37}
38
/// Shared-memory accumulator used by [`ReduceFn`].
#[derive(CubeType)]
pub struct DynamicAccumulator<N: Numeric> {
    /// Partially reduced values, one `Line<N>` per slot.
    pub elements: SharedMemory<Line<N>>,
    /// Coordinates that match `elements`, one `Line<u32>` per slot.
    ///
    /// `allocate` only creates this buffer when its `coordinate` flag is true;
    /// otherwise it is `None`.
    pub args: CubeOption<SharedMemory<Line<u32>>>,
}
44
/// A single accumulator entry produced and consumed by [`ReduceFn`].
///
/// The `ArgMax`/`ArgMin` variants populate `args`; every other variant leaves it `None`.
#[derive(CubeType)]
pub struct DynamicAccumulatorItem<N: Numeric> {
    /// The reduced values for one line.
    pub elements: Line<N>,
    /// The coordinates paired with `elements`, when the reduction tracks them.
    pub args: CubeOption<Line<u32>>,
}
50
51#[cube]
52impl<In: Numeric> SharedAccumulator<In> for DynamicAccumulator<In> {
53 type Item = DynamicAccumulatorItem<In>;
54
55 fn allocate(
56 #[comptime] length: u32,
57 #[comptime] line_size: u32,
58 #[comptime] coordinate: bool,
59 ) -> Self {
60 let elements = SharedMemory::new_lined(length, line_size);
61 let args = if comptime![coordinate] {
62 let args = SharedMemory::new_lined(length, line_size);
63 CubeOption::new_Some(args)
64 } else {
65 CubeOption::new_None()
66 };
67
68 DynamicAccumulator::<In> { elements, args }
69 }
70
71 fn read(accumulator: &Self, index: u32) -> Self::Item {
72 let elements = accumulator.elements[index];
73 let args = match accumulator.args {
74 CubeOption::Some(args) => CubeOption::new_Some(args[index]),
75 CubeOption::None => CubeOption::new_None(),
76 };
77
78 DynamicAccumulatorItem::<In> { elements, args }
79 }
80
81 fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
82 accumulator.elements[index] = item.elements;
83
84 let args = &mut accumulator.args;
85 match args {
86 CubeOption::Some(args) => {
87 args[index] = item.args.unwrap();
88 }
89 CubeOption::None => {}
90 };
91 }
92}
93
94#[cube]
95impl<In: Numeric> ReduceInstruction<In> for ReduceFn {
96 type AccumulatorItem = DynamicAccumulatorItem<In>;
97 type SharedAccumulator = DynamicAccumulator<In>;
98 type Config = ReduceFnConfig;
99
100 fn requirements(this: &Self) -> ReduceRequirements {
101 let coordinates = match this {
102 ReduceFn::Sum(..) => comptime![false],
103 ReduceFn::Prod(..) => comptime![false],
104 ReduceFn::Mean(..) => comptime![false],
105 ReduceFn::MaxAbs(..) => comptime![false],
106 ReduceFn::ArgMax(..) => comptime![true],
107 ReduceFn::ArgMin(..) => comptime![true],
108 ReduceFn::Max(..) => comptime![false],
109 ReduceFn::Min(..) => comptime![false],
110 };
111 ReduceRequirements {
112 coordinates: comptime! {coordinates},
113 }
114 }
115
116 fn from_config(#[comptime] config: Self::Config) -> Self {
117 match config {
118 ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
119 ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
120 ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
121 ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
122 ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
123 ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
124 ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
125 ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
126 }
127 }
128
129 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In> {
130 match this {
131 ReduceFn::Sum(sum) => Sum::null_input(sum, line_size),
132 ReduceFn::Prod(prod) => Prod::null_input(prod, line_size),
133 ReduceFn::Mean(mean) => Mean::null_input(mean, line_size),
134 ReduceFn::MaxAbs(maxabs) => MaxAbs::null_input(maxabs, line_size),
135 ReduceFn::ArgMax(argmax) => ArgMax::null_input(argmax, line_size),
136 ReduceFn::ArgMin(argmin) => ArgMin::null_input(argmin, line_size),
137 ReduceFn::Max(max) => Max::null_input(max, line_size),
138 ReduceFn::Min(min) => Min::null_input(min, line_size),
139 }
140 }
141
142 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
143 match this {
144 ReduceFn::Sum(sum) => {
145 let elements = Sum::null_accumulator(sum, line_size);
146
147 DynamicAccumulatorItem::<In> {
148 elements,
149 args: CubeOption::new_None(),
150 }
151 }
152 ReduceFn::Mean(sum) => {
153 let elements = Mean::null_accumulator(sum, line_size);
154
155 DynamicAccumulatorItem::<In> {
156 elements,
157 args: CubeOption::new_None(),
158 }
159 }
160 ReduceFn::Prod(sum) => {
161 let elements = Prod::null_accumulator(sum, line_size);
162
163 DynamicAccumulatorItem::<In> {
164 elements,
165 args: CubeOption::new_None(),
166 }
167 }
168 ReduceFn::MaxAbs(maxabs) => {
169 let elements = MaxAbs::null_accumulator(maxabs, line_size);
170
171 DynamicAccumulatorItem::<In> {
172 elements,
173 args: CubeOption::new_None(),
174 }
175 }
176 ReduceFn::ArgMax(argmax) => {
177 let (elements, args) = ArgMax::null_accumulator(argmax, line_size);
178
179 DynamicAccumulatorItem::<In> {
180 elements,
181 args: CubeOption::new_Some(args),
182 }
183 }
184 ReduceFn::ArgMin(argmin) => {
185 let (elements, args) = ArgMin::null_accumulator(argmin, line_size);
186
187 DynamicAccumulatorItem::<In> {
188 elements,
189 args: CubeOption::new_Some(args),
190 }
191 }
192 ReduceFn::Max(max) => {
193 let elements = Max::null_accumulator(max, line_size);
194
195 DynamicAccumulatorItem::<In> {
196 elements,
197 args: CubeOption::new_None(),
198 }
199 }
200 ReduceFn::Min(min) => {
201 let elements = Min::null_accumulator(min, line_size);
202
203 DynamicAccumulatorItem::<In> {
204 elements,
205 args: CubeOption::new_None(),
206 }
207 }
208 }
209 }
210
211 fn assign_accumulator(
212 _this: &Self,
213 destination: &mut Self::AccumulatorItem,
214 source: &Self::AccumulatorItem,
215 ) {
216 destination.elements = source.elements;
217 let args = &mut destination.args;
218 match args {
219 CubeOption::Some(val) => *val = source.args.unwrap(),
220 CubeOption::None => {}
221 }
222 }
223
224 fn reduce(
225 this: &Self,
226 accumulator: &Self::AccumulatorItem,
227 item: Line<In>,
228 coordinate: ReduceCoordinate,
229 #[comptime] use_planes: bool,
230 ) -> Self::AccumulatorItem {
231 match this {
232 ReduceFn::Sum(sum) => {
233 let elements =
234 Sum::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
235 DynamicAccumulatorItem::<In> {
236 elements,
237 args: CubeOption::new_None(),
238 }
239 }
240 ReduceFn::Prod(sum) => {
241 let elements =
242 Prod::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
243 DynamicAccumulatorItem::<In> {
244 elements,
245 args: CubeOption::new_None(),
246 }
247 }
248 ReduceFn::Mean(sum) => {
249 let elements =
250 Mean::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
251 DynamicAccumulatorItem::<In> {
252 elements,
253 args: CubeOption::new_None(),
254 }
255 }
256 ReduceFn::MaxAbs(maxabs) => {
257 let elements =
258 MaxAbs::reduce(maxabs, &accumulator.elements, item, coordinate, use_planes);
259 DynamicAccumulatorItem::<In> {
260 elements,
261 args: CubeOption::new_None(),
262 }
263 }
264 ReduceFn::ArgMax(argmax) => {
265 let (elements, args) = ArgMax::reduce(
266 argmax,
267 &(accumulator.elements, accumulator.args.unwrap()),
268 item,
269 coordinate,
270 use_planes,
271 );
272
273 DynamicAccumulatorItem::<In> {
274 elements,
275 args: CubeOption::new_Some(args),
276 }
277 }
278 ReduceFn::ArgMin(argmin) => {
279 let (elements, args) = ArgMin::reduce(
280 argmin,
281 &(accumulator.elements, accumulator.args.unwrap()),
282 item,
283 coordinate,
284 use_planes,
285 );
286
287 DynamicAccumulatorItem::<In> {
288 elements,
289 args: CubeOption::new_Some(args),
290 }
291 }
292 ReduceFn::Max(max) => {
293 let elements =
294 Max::reduce(max, &accumulator.elements, item, coordinate, use_planes);
295 DynamicAccumulatorItem::<In> {
296 elements,
297 args: CubeOption::new_None(),
298 }
299 }
300 ReduceFn::Min(min) => {
301 let elements =
302 Min::reduce(min, &accumulator.elements, item, coordinate, use_planes);
303 DynamicAccumulatorItem::<In> {
304 elements,
305 args: CubeOption::new_None(),
306 }
307 }
308 }
309 }
310
311 fn fuse_accumulators(
312 this: &Self,
313 lhs: Self::AccumulatorItem,
314 rhs: Self::AccumulatorItem,
315 ) -> Self::AccumulatorItem {
316 match this {
317 ReduceFn::Sum(sum) => {
318 let elements = Sum::fuse_accumulators(sum, lhs.elements, rhs.elements);
319 DynamicAccumulatorItem::<In> {
320 elements,
321 args: CubeOption::new_None(),
322 }
323 }
324 ReduceFn::Prod(prod) => {
325 let elements = Prod::fuse_accumulators(prod, lhs.elements, rhs.elements);
326 DynamicAccumulatorItem::<In> {
327 elements,
328 args: CubeOption::new_None(),
329 }
330 }
331 ReduceFn::Mean(mean) => {
332 let elements = Mean::fuse_accumulators(mean, lhs.elements, rhs.elements);
333 DynamicAccumulatorItem::<In> {
334 elements,
335 args: CubeOption::new_None(),
336 }
337 }
338 ReduceFn::MaxAbs(maxabs) => {
339 let elements = MaxAbs::fuse_accumulators(maxabs, lhs.elements, rhs.elements);
340 DynamicAccumulatorItem::<In> {
341 elements,
342 args: CubeOption::new_None(),
343 }
344 }
345 ReduceFn::ArgMax(argmax) => {
346 let (elements, args) = ArgMax::fuse_accumulators(
347 argmax,
348 (lhs.elements, lhs.args.unwrap()),
349 (rhs.elements, rhs.args.unwrap()),
350 );
351 DynamicAccumulatorItem::<In> {
352 elements,
353 args: CubeOption::new_Some(args),
354 }
355 }
356 ReduceFn::ArgMin(argmin) => {
357 let (elements, args) = ArgMin::fuse_accumulators(
358 argmin,
359 (lhs.elements, lhs.args.unwrap()),
360 (rhs.elements, rhs.args.unwrap()),
361 );
362 DynamicAccumulatorItem::<In> {
363 elements,
364 args: CubeOption::new_Some(args),
365 }
366 }
367 ReduceFn::Max(max) => {
368 let elements = Max::fuse_accumulators(max, lhs.elements, rhs.elements);
369 DynamicAccumulatorItem::<In> {
370 elements,
371 args: CubeOption::new_None(),
372 }
373 }
374 ReduceFn::Min(min) => {
375 let elements = Min::fuse_accumulators(min, lhs.elements, rhs.elements);
376 DynamicAccumulatorItem::<In> {
377 elements,
378 args: CubeOption::new_None(),
379 }
380 }
381 }
382 }
383
384 fn merge_line<Out: Numeric>(
387 this: &Self,
388 accumulator: Self::AccumulatorItem,
389 shape_axis_reduce: u32,
390 ) -> Out {
391 match this {
392 ReduceFn::Sum(sum) => {
393 Sum::merge_line::<Out>(sum, accumulator.elements, shape_axis_reduce)
394 }
395 ReduceFn::Prod(prod) => {
396 Prod::merge_line::<Out>(prod, accumulator.elements, shape_axis_reduce)
397 }
398 ReduceFn::Mean(mean) => {
399 Mean::merge_line::<Out>(mean, accumulator.elements, shape_axis_reduce)
400 }
401 ReduceFn::MaxAbs(maxabs) => {
402 MaxAbs::merge_line::<Out>(maxabs, accumulator.elements, shape_axis_reduce)
403 }
404 ReduceFn::ArgMax(argmax) => ArgMax::merge_line::<Out>(
405 argmax,
406 (accumulator.elements, accumulator.args.unwrap()),
407 shape_axis_reduce,
408 ),
409 ReduceFn::ArgMin(argmin) => ArgMin::merge_line::<Out>(
410 argmin,
411 (accumulator.elements, accumulator.args.unwrap()),
412 shape_axis_reduce,
413 ),
414 ReduceFn::Max(max) => {
415 Max::merge_line::<Out>(max, accumulator.elements, shape_axis_reduce)
416 }
417 ReduceFn::Min(min) => {
418 Min::merge_line::<Out>(min, accumulator.elements, shape_axis_reduce)
419 }
420 }
421 }
422
423 fn to_output_perpendicular<Out: Numeric>(
424 this: &Self,
425 accumulator: Self::AccumulatorItem,
426 shape_axis_reduce: u32,
427 ) -> Line<Out> {
428 match this {
429 ReduceFn::Sum(sum) => {
430 Sum::to_output_perpendicular::<Out>(sum, accumulator.elements, shape_axis_reduce)
431 }
432 ReduceFn::Prod(prod) => {
433 Prod::to_output_perpendicular::<Out>(prod, accumulator.elements, shape_axis_reduce)
434 }
435 ReduceFn::Mean(mean) => {
436 Mean::to_output_perpendicular::<Out>(mean, accumulator.elements, shape_axis_reduce)
437 }
438 ReduceFn::MaxAbs(maxabs) => MaxAbs::to_output_perpendicular::<Out>(
439 maxabs,
440 accumulator.elements,
441 shape_axis_reduce,
442 ),
443 ReduceFn::ArgMax(args) => ArgMax::to_output_perpendicular::<Out>(
444 args,
445 (accumulator.elements, accumulator.args.unwrap()),
446 shape_axis_reduce,
447 ),
448 ReduceFn::ArgMin(args) => ArgMin::to_output_perpendicular::<Out>(
449 args,
450 (accumulator.elements, accumulator.args.unwrap()),
451 shape_axis_reduce,
452 ),
453 ReduceFn::Max(max) => {
454 Max::to_output_perpendicular::<Out>(max, accumulator.elements, shape_axis_reduce)
455 }
456 ReduceFn::Min(min) => {
457 Min::to_output_perpendicular::<Out>(min, accumulator.elements, shape_axis_reduce)
458 }
459 }
460 }
461}