1use super::{
2 ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
3 ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
4};
5use crate::ReduceDtypes;
6use crate::precision::ReducePrecision;
7use cubecl_core::ir::{FloatKind, IntKind, UIntKind};
8use cubecl_core::prelude::*;
9use cubecl_core::{self as cubecl, ir::ElemType};
10use cubecl_std::{CubeOption, CubeOptionExpand};
11
/// Reduction whose operation is chosen by value rather than by type parameter.
///
/// Each variant wraps one of the concrete reduce instructions. The `ReduceInstruction`
/// implementation for this enum matches on the active variant and forwards every call to
/// the wrapped instruction. Values are built from a [`ReduceFnConfig`] via `from_config`.
#[derive(Debug, CubeType, Clone)]
pub enum ReduceFn {
    Sum(Sum),
    Prod(Prod),
    Mean(Mean),
    MaxAbs(MaxAbs),
    ArgMax(ArgMax),
    ArgMin(ArgMin),
    Max(Max),
    Min(Min),
}
23
/// Compile-time description of which reduction a [`ReduceFn`] performs.
///
/// Each variant maps to the `ReduceFn` variant of the same name (see `from_config`, which
/// takes this as a `#[comptime]` argument). [`ReduceFnConfig::precision`] uses it to pick
/// the input/output/accumulation element types.
#[derive_cube_comptime]
pub enum ReduceFnConfig {
    Sum,
    Prod,
    Mean,
    MaxAbs,
    ArgMax,
    ArgMin,
    Max,
    Min,
}
35
36impl ReduceFnConfig {
37 pub fn precision(&self, input: ElemType) -> ReduceDtypes {
39 match self {
40 ReduceFnConfig::Sum | ReduceFnConfig::Prod | ReduceFnConfig::Mean => {}
41 ReduceFnConfig::MaxAbs | ReduceFnConfig::Max | ReduceFnConfig::Min => {
43 return ReduceDtypes {
44 input: input.into(),
45 output: input.into(),
46 accumulation: input.into(),
47 };
48 }
49 ReduceFnConfig::ArgMax | ReduceFnConfig::ArgMin => {
50 return ReduceDtypes {
51 input: input.into(),
52 output: i32::as_type_native_unchecked(),
53 accumulation: input.into(),
54 };
55 }
56 };
57
58 match input {
59 ElemType::Float(kind) => {
60 let acc = match kind {
61 FloatKind::F64 => f64::as_type_native_unchecked(),
62 _ => f32::as_type_native_unchecked(),
63 };
64
65 ReduceDtypes {
66 input: input.into(),
67 output: input.into(),
68 accumulation: acc,
69 }
70 }
71 ElemType::Int(kind) => {
72 let acc = match kind {
73 IntKind::I64 => i64::as_type_native_unchecked(),
74 _ => i32::as_type_native_unchecked(),
75 };
76
77 ReduceDtypes {
78 input: input.into(),
79 output: input.into(),
80 accumulation: acc,
81 }
82 }
83 ElemType::UInt(kind) => {
84 let acc = match kind {
85 UIntKind::U64 => u64::as_type_native_unchecked(),
86 _ => u32::as_type_native_unchecked(),
87 };
88
89 ReduceDtypes {
90 input: input.into(),
91 output: input.into(),
92 accumulation: acc,
93 }
94 }
95 ElemType::Bool => panic!("Can't reduce on booleans"),
96 }
97 }
98}
99
100impl ReduceFamily for ReduceFn {
101 type Instruction<P: ReducePrecision> = Self;
102 type Config = ReduceFnConfig;
103}
104
/// Shared-memory accumulator used by [`ReduceFn`].
///
/// `elements` holds the accumulated values. `args` holds the matching `u32` coordinates.
/// It is only allocated when the reduction tracks coordinates (the `coordinate` flag of
/// `allocate`), and is `None` otherwise.
#[derive(CubeType)]
pub struct DynamicAccumulator<N: Numeric> {
    pub elements: SharedMemory<Line<N>>,
    pub args: CubeOption<SharedMemory<Line<u32>>>,
}
110
/// One accumulator value used by [`ReduceFn`]: a line of accumulated `elements`, plus the
/// matching `u32` coordinates in `args`.
///
/// `args` is `Some` for `ArgMax`/`ArgMin` and `None` for every other reduction.
#[derive(CubeType)]
pub struct DynamicAccumulatorItem<N: Numeric> {
    pub elements: Line<N>,
    pub args: CubeOption<Line<u32>>,
}
116
// Shared-memory storage for `ReduceFn` accumulators. The coordinate buffer (`args`) exists
// only when `allocate` was called with `coordinate == true`; `read` and `write` check for
// it instead of assuming it.
#[cube]
impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
    type Item = DynamicAccumulatorItem<In>;

    // Allocates `length` lined slots (each `line_size` wide) for the values. A second
    // buffer of the same shape for `u32` coordinates is allocated only when the comptime
    // `coordinate` flag is set.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] coordinate: bool,
    ) -> Self {
        let elements = SharedMemory::new_lined(length, line_size);
        let args = if comptime![coordinate] {
            let args = SharedMemory::new_lined(length, line_size);
            CubeOption::new_Some(args)
        } else {
            CubeOption::new_None()
        };

        DynamicAccumulator::<In> { elements, args }
    }

    // Loads slot `index`. The coordinate is read only when the coordinate buffer exists;
    // otherwise the returned item's `args` is `None`.
    fn read(accumulator: &Self, index: u32) -> Self::Item {
        let elements = accumulator.elements[index];
        let args = match accumulator.args {
            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
            CubeOption::None => CubeOption::new_None(),
        };

        DynamicAccumulatorItem::<In> { elements, args }
    }

    // Stores `item` into slot `index`. When the coordinate buffer exists, `item.args` is
    // unwrapped, so it must be `Some` in that case.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
        accumulator.elements[index] = item.elements;

        let args = &mut accumulator.args;
        match args {
            CubeOption::Some(args) => {
                args[index] = item.args.unwrap();
            }
            CubeOption::None => {}
        };
    }
}
159
// Value-level dispatch for `ReduceFn`: every method matches on the active variant and
// forwards to that instruction's own `ReduceInstruction<P>` implementation.
//
// The wrapped instructions use different accumulator shapes:
// - `ArgMax`/`ArgMin` use an `(elements, coordinates)` tuple.
// - Every other variant uses a bare line.
// The methods below pack and unpack these into `DynamicAccumulatorItem`, whose `args` is
// `Some` for `ArgMax`/`ArgMin` and `None` for the rest. Accumulation is carried in `P::EA`.
#[cube]
impl<P: ReducePrecision> ReduceInstruction<P> for ReduceFn {
    type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
    type SharedAccumulator = DynamicAccumulator<P::EA>;
    type Config = ReduceFnConfig;

    // Only `ArgMax`/`ArgMin` require coordinates to be tracked.
    fn requirements(this: &Self) -> ReduceRequirements {
        let coordinates = match this {
            ReduceFn::Sum(..) => comptime![false],
            ReduceFn::Prod(..) => comptime![false],
            ReduceFn::Mean(..) => comptime![false],
            ReduceFn::MaxAbs(..) => comptime![false],
            ReduceFn::ArgMax(..) => comptime![true],
            ReduceFn::ArgMin(..) => comptime![true],
            ReduceFn::Max(..) => comptime![false],
            ReduceFn::Min(..) => comptime![false],
        };
        ReduceRequirements {
            coordinates: comptime! {coordinates},
        }
    }

    // Builds the variant selected by the comptime `config`. `Mean` is built around an
    // inner `Sum`.
    fn from_config(#[comptime] config: Self::Config) -> Self {
        match config {
            ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
            ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
            ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
            ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
            ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
            ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
            ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
            ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
        }
    }

    // Delegates to the wrapped instruction's `null_input`. This is presumably the neutral
    // input used for padding/out-of-range positions; each instruction defines its own.
    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
        match this {
            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::null_input(prod, line_size),
            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::null_input(mean, line_size),
            ReduceFn::MaxAbs(maxabs) => {
                <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
            }
            ReduceFn::ArgMax(argmax) => {
                <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
            }
            ReduceFn::ArgMin(argmin) => {
                <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
            }
            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
        }
    }

    // Initial accumulator taken from the wrapped instruction. `args` is populated only for
    // `ArgMax`/`ArgMin`, whose null accumulator is an `(elements, coordinates)` pair.
    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
        match this {
            ReduceFn::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Mean(sum) => {
                let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Prod(sum) => {
                let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::MaxAbs(maxabs) => {
                let elements =
                    <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::ArgMax(argmax) => {
                let (elements, args) =
                    <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::ArgMin(argmin) => {
                let (elements, args) =
                    <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Copies `source` into `destination`. Coordinates are copied only when `destination`
    // carries them, and then `source.args` is unwrapped, so it must also be `Some`.
    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        destination.elements = source.elements;
        let args = &mut destination.args;
        match args {
            CubeOption::Some(val) => *val = source.args.unwrap(),
            CubeOption::None => {}
        }
    }

    // Folds the input line `item` (located at `coordinate`) into `accumulator` and returns
    // the updated accumulator. For `ArgMax`/`ArgMin`, `accumulator.args` is unwrapped, so
    // it must be `Some`.
    fn reduce(
        this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<P::EI>,
        coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        match this {
            ReduceFn::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Prod(sum) => {
                let elements = <Prod as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Mean(sum) => {
                let elements = <Mean as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::MaxAbs(maxabs) => {
                let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
                    maxabs,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::ArgMax(argmax) => {
                let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
                    argmax,
                    &(accumulator.elements, accumulator.args.unwrap()),
                    item,
                    coordinate,
                    use_planes,
                );

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::ArgMin(argmin) => {
                let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
                    argmin,
                    &(accumulator.elements, accumulator.args.unwrap()),
                    item,
                    coordinate,
                    use_planes,
                );

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::reduce(
                    max,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::reduce(
                    min,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Combines two partial accumulators using the wrapped instruction's rule. For
    // `ArgMax`/`ArgMin`, both `lhs.args` and `rhs.args` are unwrapped, so both must be `Some`.
    fn fuse_accumulators(
        this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        match this {
            ReduceFn::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
                    sum,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Prod(prod) => {
                let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
                    prod,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Mean(mean) => {
                let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
                    mean,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::MaxAbs(maxabs) => {
                let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
                    maxabs,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::ArgMax(argmax) => {
                let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
                    argmax,
                    (lhs.elements, lhs.args.unwrap()),
                    (rhs.elements, rhs.args.unwrap()),
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::ArgMin(argmin) => {
                let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
                    argmin,
                    (lhs.elements, lhs.args.unwrap()),
                    (rhs.elements, rhs.args.unwrap()),
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceFn::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
                    max,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceFn::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
                    min,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Reduces the accumulator's whole line to a single `Out` value via the wrapped
    // instruction; `shape_axis_reduce` is forwarded unchanged.
    fn merge_line<Out: Numeric>(
        this: &Self,
        accumulator: Self::AccumulatorItem,
        shape_axis_reduce: u32,
    ) -> Out {
        match this {
            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
                sum,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
                prod,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
                mean,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
                maxabs,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
                argmax,
                (accumulator.elements, accumulator.args.unwrap()),
                shape_axis_reduce,
            ),
            ReduceFn::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
                argmin,
                (accumulator.elements, accumulator.args.unwrap()),
                shape_axis_reduce,
            ),
            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
                max,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
                min,
                accumulator.elements,
                shape_axis_reduce,
            ),
        }
    }

    // Produces a whole `Line<Out>` from the accumulator via the wrapped instruction. This
    // differs from `merge_line`, which returns a single `Out`. `shape_axis_reduce` is
    // forwarded unchanged.
    fn to_output_perpendicular<Out: Numeric>(
        this: &Self,
        accumulator: Self::AccumulatorItem,
        shape_axis_reduce: u32,
    ) -> Line<Out> {
        match this {
            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                sum,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                prod,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                mean,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::MaxAbs(maxabs) => {
                <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    maxabs,
                    accumulator.elements,
                    shape_axis_reduce,
                )
            }
            ReduceFn::ArgMax(args) => {
                <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    args,
                    (accumulator.elements, accumulator.args.unwrap()),
                    shape_axis_reduce,
                )
            }
            ReduceFn::ArgMin(args) => {
                <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    args,
                    (accumulator.elements, accumulator.args.unwrap()),
                    shape_axis_reduce,
                )
            }
            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                max,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                min,
                accumulator.elements,
                shape_axis_reduce,
            ),
        }
    }
}