1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand};
4
5use crate::precision::ReducePrecision;
6
7use super::{
8 ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
9 ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
10};
11
/// Runtime-selectable reduce instruction.
///
/// Each variant wraps one concrete reduce instruction. The `ReduceInstruction` impl for this
/// enum matches on the variant and forwards every call to the wrapped instruction. A value
/// can be built from its comptime selector, [`ReduceFnConfig`], via `from_config`.
#[derive(Debug, CubeType, Clone)]
pub enum ReduceFn {
    /// Wraps the [`Sum`] instruction.
    Sum(Sum),
    /// Wraps the [`Prod`] instruction.
    Prod(Prod),
    /// Wraps the [`Mean`] instruction.
    Mean(Mean),
    /// Wraps the [`MaxAbs`] instruction.
    MaxAbs(MaxAbs),
    /// Wraps the [`ArgMax`] instruction; requires coordinates to be tracked.
    ArgMax(ArgMax),
    /// Wraps the [`ArgMin`] instruction; requires coordinates to be tracked.
    ArgMin(ArgMin),
    /// Wraps the [`Max`] instruction.
    Max(Max),
    /// Wraps the [`Min`] instruction.
    Min(Min),
}
23
/// Comptime selector for a [`ReduceFn`] variant.
///
/// There is exactly one entry per `ReduceFn` variant, with the same name. `from_config`
/// turns a selector into the matching `ReduceFn` value.
#[derive_cube_comptime]
pub enum ReduceFnConfig {
    Sum,
    Prod,
    Mean,
    MaxAbs,
    ArgMax,
    ArgMin,
    Max,
    Min,
}
35
36impl ReduceFamily for ReduceFn {
37 type Instruction<P: ReducePrecision> = Self;
38 type Config = ReduceFnConfig;
39}
40
/// Shared-memory accumulator used by [`ReduceFn`].
///
/// Values are always stored. Coordinates are stored only when the accumulator was
/// allocated with `coordinate = true`.
#[derive(CubeType)]
pub struct DynamicAccumulator<N: Numeric> {
    /// Accumulated values, one `Line` per slot.
    pub elements: SharedMemory<Line<N>>,
    /// Coordinates parallel to `elements`; `None` when coordinates were not requested.
    pub args: CubeOption<SharedMemory<Line<u32>>>,
}
46
/// A single accumulator slot used by [`ReduceFn`]: a line of values, plus optionally the
/// matching line of `u32` coordinates.
#[derive(CubeType)]
pub struct DynamicAccumulatorItem<N: Numeric> {
    /// Accumulated values.
    pub elements: Line<N>,
    /// Coordinates matching `elements`; `Some` only for the arg-reductions
    /// (`ArgMax` / `ArgMin`), `None` otherwise.
    pub args: CubeOption<Line<u32>>,
}
52
// Shared-memory storage for `DynamicAccumulatorItem`s. The value buffer always exists; the
// coordinate buffer exists only when `coordinate` was requested at allocation time.
#[cube]
impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
    type Item = DynamicAccumulatorItem<In>;

    // Allocates `length` lines of width `line_size` for the values and, only when
    // `coordinate` is true, a parallel buffer of `u32` coordinate lines of the same shape.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] coordinate: bool,
    ) -> Self {
        let elements = SharedMemory::new_lined(length, line_size);
        // The condition is comptime, so no coordinate buffer is allocated when unused.
        let args = if comptime![coordinate] {
            let args = SharedMemory::new_lined(length, line_size);
            CubeOption::new_Some(args)
        } else {
            CubeOption::new_None()
        };

        DynamicAccumulator::<In> { elements, args }
    }

    // Loads slot `index`; coordinates are read only if this accumulator stores them.
    fn read(accumulator: &Self, index: u32) -> Self::Item {
        let elements = accumulator.elements[index];
        let args = match accumulator.args {
            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
            CubeOption::None => CubeOption::new_None(),
        };

        DynamicAccumulatorItem::<In> { elements, args }
    }

    // Stores `item` into slot `index`. When this accumulator stores coordinates,
    // `item.args` is unwrapped and must be `Some`; otherwise `item.args` is ignored.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
        accumulator.elements[index] = item.elements;

        let args = &mut accumulator.args;
        match args {
            CubeOption::Some(args) => {
                args[index] = item.args.unwrap();
            }
            CubeOption::None => {}
        };
    }
}
95
96#[cube]
97impl<P: ReducePrecision> ReduceInstruction<P> for ReduceFn {
98 type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
99 type SharedAccumulator = DynamicAccumulator<P::EA>;
100 type Config = ReduceFnConfig;
101
102 fn requirements(this: &Self) -> ReduceRequirements {
103 let coordinates = match this {
104 ReduceFn::Sum(..) => comptime![false],
105 ReduceFn::Prod(..) => comptime![false],
106 ReduceFn::Mean(..) => comptime![false],
107 ReduceFn::MaxAbs(..) => comptime![false],
108 ReduceFn::ArgMax(..) => comptime![true],
109 ReduceFn::ArgMin(..) => comptime![true],
110 ReduceFn::Max(..) => comptime![false],
111 ReduceFn::Min(..) => comptime![false],
112 };
113 ReduceRequirements {
114 coordinates: comptime! {coordinates},
115 }
116 }
117
118 fn from_config(#[comptime] config: Self::Config) -> Self {
119 match config {
120 ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
121 ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
122 ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
123 ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
124 ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
125 ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
126 ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
127 ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
128 }
129 }
130
131 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
132 match this {
133 ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
134 ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::null_input(prod, line_size),
135 ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::null_input(mean, line_size),
136 ReduceFn::MaxAbs(maxabs) => {
137 <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
138 }
139 ReduceFn::ArgMax(argmax) => {
140 <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
141 }
142 ReduceFn::ArgMin(argmin) => {
143 <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
144 }
145 ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
146 ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
147 }
148 }
149
150 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
151 match this {
152 ReduceFn::Sum(sum) => {
153 let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);
154
155 DynamicAccumulatorItem::<P::EA> {
156 elements,
157 args: CubeOption::new_None(),
158 }
159 }
160 ReduceFn::Mean(sum) => {
161 let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);
162
163 DynamicAccumulatorItem::<P::EA> {
164 elements,
165 args: CubeOption::new_None(),
166 }
167 }
168 ReduceFn::Prod(sum) => {
169 let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);
170
171 DynamicAccumulatorItem::<P::EA> {
172 elements,
173 args: CubeOption::new_None(),
174 }
175 }
176 ReduceFn::MaxAbs(maxabs) => {
177 let elements =
178 <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);
179
180 DynamicAccumulatorItem::<P::EA> {
181 elements,
182 args: CubeOption::new_None(),
183 }
184 }
185 ReduceFn::ArgMax(argmax) => {
186 let (elements, args) =
187 <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);
188
189 DynamicAccumulatorItem::<P::EA> {
190 elements,
191 args: CubeOption::new_Some(args),
192 }
193 }
194 ReduceFn::ArgMin(argmin) => {
195 let (elements, args) =
196 <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);
197
198 DynamicAccumulatorItem::<P::EA> {
199 elements,
200 args: CubeOption::new_Some(args),
201 }
202 }
203 ReduceFn::Max(max) => {
204 let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);
205
206 DynamicAccumulatorItem::<P::EA> {
207 elements,
208 args: CubeOption::new_None(),
209 }
210 }
211 ReduceFn::Min(min) => {
212 let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);
213
214 DynamicAccumulatorItem::<P::EA> {
215 elements,
216 args: CubeOption::new_None(),
217 }
218 }
219 }
220 }
221
222 fn assign_accumulator(
223 _this: &Self,
224 destination: &mut Self::AccumulatorItem,
225 source: &Self::AccumulatorItem,
226 ) {
227 destination.elements = source.elements;
228 let args = &mut destination.args;
229 match args {
230 CubeOption::Some(val) => *val = source.args.unwrap(),
231 CubeOption::None => {}
232 }
233 }
234
235 fn reduce(
236 this: &Self,
237 accumulator: &Self::AccumulatorItem,
238 item: Line<P::EI>,
239 coordinate: ReduceCoordinate,
240 #[comptime] use_planes: bool,
241 ) -> Self::AccumulatorItem {
242 match this {
243 ReduceFn::Sum(sum) => {
244 let elements = <Sum as ReduceInstruction<P>>::reduce(
245 sum,
246 &accumulator.elements,
247 item,
248 coordinate,
249 use_planes,
250 );
251 DynamicAccumulatorItem::<P::EA> {
252 elements,
253 args: CubeOption::new_None(),
254 }
255 }
256 ReduceFn::Prod(sum) => {
257 let elements = <Prod as ReduceInstruction<P>>::reduce(
258 sum,
259 &accumulator.elements,
260 item,
261 coordinate,
262 use_planes,
263 );
264 DynamicAccumulatorItem::<P::EA> {
265 elements,
266 args: CubeOption::new_None(),
267 }
268 }
269 ReduceFn::Mean(sum) => {
270 let elements = <Mean as ReduceInstruction<P>>::reduce(
271 sum,
272 &accumulator.elements,
273 item,
274 coordinate,
275 use_planes,
276 );
277 DynamicAccumulatorItem::<P::EA> {
278 elements,
279 args: CubeOption::new_None(),
280 }
281 }
282 ReduceFn::MaxAbs(maxabs) => {
283 let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
284 maxabs,
285 &accumulator.elements,
286 item,
287 coordinate,
288 use_planes,
289 );
290 DynamicAccumulatorItem::<P::EA> {
291 elements,
292 args: CubeOption::new_None(),
293 }
294 }
295 ReduceFn::ArgMax(argmax) => {
296 let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
297 argmax,
298 &(accumulator.elements, accumulator.args.unwrap()),
299 item,
300 coordinate,
301 use_planes,
302 );
303
304 DynamicAccumulatorItem::<P::EA> {
305 elements,
306 args: CubeOption::new_Some(args),
307 }
308 }
309 ReduceFn::ArgMin(argmin) => {
310 let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
311 argmin,
312 &(accumulator.elements, accumulator.args.unwrap()),
313 item,
314 coordinate,
315 use_planes,
316 );
317
318 DynamicAccumulatorItem::<P::EA> {
319 elements,
320 args: CubeOption::new_Some(args),
321 }
322 }
323 ReduceFn::Max(max) => {
324 let elements = <Max as ReduceInstruction<P>>::reduce(
325 max,
326 &accumulator.elements,
327 item,
328 coordinate,
329 use_planes,
330 );
331 DynamicAccumulatorItem::<P::EA> {
332 elements,
333 args: CubeOption::new_None(),
334 }
335 }
336 ReduceFn::Min(min) => {
337 let elements = <Min as ReduceInstruction<P>>::reduce(
338 min,
339 &accumulator.elements,
340 item,
341 coordinate,
342 use_planes,
343 );
344 DynamicAccumulatorItem::<P::EA> {
345 elements,
346 args: CubeOption::new_None(),
347 }
348 }
349 }
350 }
351
352 fn fuse_accumulators(
353 this: &Self,
354 lhs: Self::AccumulatorItem,
355 rhs: Self::AccumulatorItem,
356 ) -> Self::AccumulatorItem {
357 match this {
358 ReduceFn::Sum(sum) => {
359 let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
360 sum,
361 lhs.elements,
362 rhs.elements,
363 );
364 DynamicAccumulatorItem::<P::EA> {
365 elements,
366 args: CubeOption::new_None(),
367 }
368 }
369 ReduceFn::Prod(prod) => {
370 let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
371 prod,
372 lhs.elements,
373 rhs.elements,
374 );
375 DynamicAccumulatorItem::<P::EA> {
376 elements,
377 args: CubeOption::new_None(),
378 }
379 }
380 ReduceFn::Mean(mean) => {
381 let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
382 mean,
383 lhs.elements,
384 rhs.elements,
385 );
386 DynamicAccumulatorItem::<P::EA> {
387 elements,
388 args: CubeOption::new_None(),
389 }
390 }
391 ReduceFn::MaxAbs(maxabs) => {
392 let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
393 maxabs,
394 lhs.elements,
395 rhs.elements,
396 );
397 DynamicAccumulatorItem::<P::EA> {
398 elements,
399 args: CubeOption::new_None(),
400 }
401 }
402 ReduceFn::ArgMax(argmax) => {
403 let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
404 argmax,
405 (lhs.elements, lhs.args.unwrap()),
406 (rhs.elements, rhs.args.unwrap()),
407 );
408 DynamicAccumulatorItem::<P::EA> {
409 elements,
410 args: CubeOption::new_Some(args),
411 }
412 }
413 ReduceFn::ArgMin(argmin) => {
414 let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
415 argmin,
416 (lhs.elements, lhs.args.unwrap()),
417 (rhs.elements, rhs.args.unwrap()),
418 );
419 DynamicAccumulatorItem::<P::EA> {
420 elements,
421 args: CubeOption::new_Some(args),
422 }
423 }
424 ReduceFn::Max(max) => {
425 let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
426 max,
427 lhs.elements,
428 rhs.elements,
429 );
430 DynamicAccumulatorItem::<P::EA> {
431 elements,
432 args: CubeOption::new_None(),
433 }
434 }
435 ReduceFn::Min(min) => {
436 let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
437 min,
438 lhs.elements,
439 rhs.elements,
440 );
441 DynamicAccumulatorItem::<P::EA> {
442 elements,
443 args: CubeOption::new_None(),
444 }
445 }
446 }
447 }
448
449 fn merge_line<Out: Numeric>(
452 this: &Self,
453 accumulator: Self::AccumulatorItem,
454 shape_axis_reduce: u32,
455 ) -> Out {
456 match this {
457 ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
458 sum,
459 accumulator.elements,
460 shape_axis_reduce,
461 ),
462 ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
463 prod,
464 accumulator.elements,
465 shape_axis_reduce,
466 ),
467 ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
468 mean,
469 accumulator.elements,
470 shape_axis_reduce,
471 ),
472 ReduceFn::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
473 maxabs,
474 accumulator.elements,
475 shape_axis_reduce,
476 ),
477 ReduceFn::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
478 argmax,
479 (accumulator.elements, accumulator.args.unwrap()),
480 shape_axis_reduce,
481 ),
482 ReduceFn::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
483 argmin,
484 (accumulator.elements, accumulator.args.unwrap()),
485 shape_axis_reduce,
486 ),
487 ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
488 max,
489 accumulator.elements,
490 shape_axis_reduce,
491 ),
492 ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
493 min,
494 accumulator.elements,
495 shape_axis_reduce,
496 ),
497 }
498 }
499
500 fn to_output_perpendicular<Out: Numeric>(
501 this: &Self,
502 accumulator: Self::AccumulatorItem,
503 shape_axis_reduce: u32,
504 ) -> Line<Out> {
505 match this {
506 ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
507 sum,
508 accumulator.elements,
509 shape_axis_reduce,
510 ),
511 ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
512 prod,
513 accumulator.elements,
514 shape_axis_reduce,
515 ),
516 ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
517 mean,
518 accumulator.elements,
519 shape_axis_reduce,
520 ),
521 ReduceFn::MaxAbs(maxabs) => {
522 <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
523 maxabs,
524 accumulator.elements,
525 shape_axis_reduce,
526 )
527 }
528 ReduceFn::ArgMax(args) => {
529 <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
530 args,
531 (accumulator.elements, accumulator.args.unwrap()),
532 shape_axis_reduce,
533 )
534 }
535 ReduceFn::ArgMin(args) => {
536 <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
537 args,
538 (accumulator.elements, accumulator.args.unwrap()),
539 shape_axis_reduce,
540 )
541 }
542 ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
543 max,
544 accumulator.elements,
545 shape_axis_reduce,
546 ),
547 ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
548 min,
549 accumulator.elements,
550 shape_axis_reduce,
551 ),
552 }
553 }
554}