1use super::{
2 ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
3 ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
4};
5use crate::{ReduceDtypes, components::precision::ReducePrecision};
6use cubecl::{
7 ir::{ElemType, FloatKind, IntKind, UIntKind},
8 prelude::*,
9 std::{CubeOption, CubeOptionExpand},
10};
11
/// Runtime-dispatchable reduction: each variant wraps one concrete reduce instruction,
/// and the `ReduceInstruction` implementation for this type forwards every call to the
/// wrapped instruction of the active variant.
#[derive(Debug, CubeType, Clone)]
// NOTE(review): presumably needed because variants are only built through the
// `CubeType`-generated `new_*` constructors (see `from_config`) rather than directly;
// confirm before removing.
#[allow(unused)]
pub(crate) enum ReduceOperation {
    Sum(Sum),
    Prod(Prod),
    Mean(Mean),
    MaxAbs(MaxAbs),
    ArgMax(ArgMax),
    ArgMin(ArgMin),
    Max(Max),
    Min(Min),
}
24
/// Comptime selector for a reduction.
///
/// Each variant maps one-to-one onto the `ReduceOperation` variant of the same name
/// when the instruction is built from its config, and [`ReduceOperationConfig::precision`]
/// uses it to choose the input/output/accumulation element types.
#[derive_cube_comptime]
pub enum ReduceOperationConfig {
    /// Accumulates in a widened type (see [`ReduceOperationConfig::precision`]).
    Sum,
    /// Accumulates in a widened type (see [`ReduceOperationConfig::precision`]).
    Prod,
    /// Built around an inner `Sum` instruction; accumulates in a widened type.
    Mean,
    /// Uses the input element type for output and accumulation.
    MaxAbs,
    /// Requires element coordinates; the output type must be supplied explicitly.
    ArgMax,
    /// Requires element coordinates; the output type must be supplied explicitly.
    ArgMin,
    /// Uses the input element type for output and accumulation.
    Max,
    /// Uses the input element type for output and accumulation.
    Min,
}
36
37impl ReduceOperationConfig {
38 pub fn precision(&self, input: ElemType, output: Option<ElemType>) -> ReduceDtypes {
40 match self {
41 ReduceOperationConfig::Sum
42 | ReduceOperationConfig::Prod
43 | ReduceOperationConfig::Mean => {}
44 ReduceOperationConfig::MaxAbs
46 | ReduceOperationConfig::Max
47 | ReduceOperationConfig::Min => {
48 return ReduceDtypes {
49 input: input.into(),
50 output: input.into(),
51 accumulation: input.into(),
52 };
53 }
54 ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin => {
55 return ReduceDtypes {
56 input: input.into(),
57 output: output
58 .expect("ArgMax and ArgMin must specify output type")
59 .into(),
60 accumulation: input.into(),
61 };
62 }
63 };
64
65 match input {
66 ElemType::Float(kind) => {
67 let acc = match kind {
68 FloatKind::F64 => f64::as_type_native_unchecked(),
69 _ => f32::as_type_native_unchecked(),
70 };
71
72 ReduceDtypes {
73 input: input.into(),
74 output: input.into(),
75 accumulation: acc,
76 }
77 }
78 ElemType::Int(kind) => {
79 let acc = match kind {
80 IntKind::I64 => i64::as_type_native_unchecked(),
81 _ => i32::as_type_native_unchecked(),
82 };
83
84 ReduceDtypes {
85 input: input.into(),
86 output: input.into(),
87 accumulation: acc,
88 }
89 }
90 ElemType::UInt(kind) => {
91 let acc = match kind {
92 UIntKind::U64 => u64::as_type_native_unchecked(),
93 _ => u32::as_type_native_unchecked(),
94 };
95
96 ReduceDtypes {
97 input: input.into(),
98 output: input.into(),
99 accumulation: acc,
100 }
101 }
102 ElemType::Bool => panic!("Can't reduce on booleans"),
103 }
104 }
105}
106
107impl ReduceFamily for ReduceOperation {
108 type Instruction<P: ReducePrecision> = Self;
109 type Config = ReduceOperationConfig;
110}
111
/// Shared-memory accumulator usable by every `ReduceOperation` variant.
///
/// * `elements` - accumulated value lines of type `Line<N>`.
/// * `args` - parallel buffer of `Line<u32>` coordinates. It is `Some` only when the
///   accumulator was allocated with `coordinate = true`, and `None` otherwise.
#[derive(CubeType)]
pub struct DynamicAccumulator<N: Numeric> {
    pub elements: SharedMemory<Line<N>>,
    pub args: CubeOption<SharedMemory<Line<u32>>>,
}
117
/// A single accumulator value, as read from or written to a [`DynamicAccumulator`] slot.
///
/// * `elements` - the accumulated value line.
/// * `args` - matching `u32` coordinates when the operation tracks them (ArgMax/ArgMin),
///   `None` otherwise.
#[derive(CubeType)]
pub struct DynamicAccumulatorItem<N: Numeric> {
    pub elements: Line<N>,
    pub args: CubeOption<Line<u32>>,
}
123
// `SharedAccumulator` backed by `SharedMemory`. The coordinate buffer is optional and
// only exists when `allocate` is called with `coordinate = true`.
#[cube]
impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
    type Item = DynamicAccumulatorItem<In>;

    // Allocates `length` lines of `line_size` for the values. Only when the comptime
    // flag `coordinate` is set does it also allocate a second buffer of the same shape
    // for the `u32` coordinates.
    fn allocate(
        #[comptime] length: usize,
        #[comptime] line_size: LineSize,
        #[comptime] coordinate: bool,
    ) -> Self {
        let elements = SharedMemory::new_lined(length, line_size);
        let args = if coordinate {
            let args = SharedMemory::new_lined(length, line_size);
            CubeOption::new_Some(args)
        } else {
            CubeOption::new_None()
        };

        DynamicAccumulator::<In> { elements, args }
    }

    // Reads slot `index`. The returned item carries coordinates exactly when this
    // accumulator owns a coordinate buffer.
    fn read(accumulator: &Self, index: usize) -> Self::Item {
        let elements = accumulator.elements[index];
        let args = match accumulator.args {
            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
            CubeOption::None => CubeOption::new_None(),
        };

        DynamicAccumulatorItem::<In> { elements, args }
    }

    // Stores `item` into slot `index`. When a coordinate buffer exists, `item.args` is
    // unwrapped, so the item is expected to carry coordinates in that case.
    fn write(accumulator: &mut Self, index: usize, item: Self::Item) {
        accumulator.elements[index] = item.elements;

        let args = &mut accumulator.args;
        match args {
            CubeOption::Some(args) => {
                args[index] = item.args.unwrap();
            }
            CubeOption::None => {}
        };
    }
}
166
// Dispatching `ReduceInstruction` implementation. Every method matches on the active
// `ReduceOperation` variant and forwards to that instruction's own
// `ReduceInstruction<P>` implementation.
//
// Accumulator adaptation, shared by the methods below:
// - Sum, Prod, Mean, MaxAbs, Max and Min only use `accumulator.elements`. Their results
//   are rewrapped with `args: CubeOption::new_None()`.
// - ArgMax and ArgMin use the pair `(elements, args)`. `args` is unwrapped on the way in
//   and rewrapped with `CubeOption::new_Some` on the way out, so it is expected to be
//   `Some` for those variants.
#[cube]
impl<P: ReducePrecision> ReduceInstruction<P> for ReduceOperation {
    // Accumulated values use the precision's accumulation type `P::EA`.
    type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
    type SharedAccumulator = DynamicAccumulator<P::EA>;
    type Config = ReduceOperationConfig;

    // Only ArgMax and ArgMin request element coordinates.
    fn requirements(this: &Self) -> ReduceRequirements {
        let coordinates = match this {
            ReduceOperation::Sum(..) => false,
            ReduceOperation::Prod(..) => false,
            ReduceOperation::Mean(..) => false,
            ReduceOperation::MaxAbs(..) => false,
            ReduceOperation::ArgMax(..) => true,
            ReduceOperation::ArgMin(..) => true,
            ReduceOperation::Max(..) => false,
            ReduceOperation::Min(..) => false,
        };
        ReduceRequirements { coordinates }
    }

    // Builds the variant selected by the comptime `config`. `Mean` is built around an
    // inner `Sum` instruction.
    fn from_config(#[comptime] config: Self::Config) -> Self {
        match config {
            ReduceOperationConfig::Sum => ReduceOperation::new_Sum(Sum {}),
            ReduceOperationConfig::Prod => ReduceOperation::new_Prod(Prod {}),
            ReduceOperationConfig::Mean => ReduceOperation::new_Mean(Mean { sum: Sum {} }),
            ReduceOperationConfig::MaxAbs => ReduceOperation::new_MaxAbs(MaxAbs {}),
            ReduceOperationConfig::ArgMax => ReduceOperation::new_ArgMax(ArgMax {}),
            ReduceOperationConfig::ArgMin => ReduceOperation::new_ArgMin(ArgMin {}),
            ReduceOperationConfig::Max => ReduceOperation::new_Max(Max {}),
            ReduceOperationConfig::Min => ReduceOperation::new_Min(Min {}),
        }
    }

    // Forwards to the active instruction's `null_input`.
    fn null_input(this: &Self, #[comptime] line_size: LineSize) -> Line<P::EI> {
        match this {
            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
            ReduceOperation::Prod(prod) => {
                <Prod as ReduceInstruction<P>>::null_input(prod, line_size)
            }
            ReduceOperation::Mean(mean) => {
                <Mean as ReduceInstruction<P>>::null_input(mean, line_size)
            }
            ReduceOperation::MaxAbs(maxabs) => {
                <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
            }
            ReduceOperation::ArgMax(argmax) => {
                <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
            }
            ReduceOperation::ArgMin(argmin) => {
                <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
            }
            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
        }
    }

    // Initial accumulator of the active instruction. `args` is `Some` only for
    // ArgMax and ArgMin.
    // NOTE(review): the `Mean` and `Prod` arms bind their instruction as `sum`; the
    // name is misleading but harmless.
    fn null_accumulator(this: &Self, #[comptime] line_size: LineSize) -> Self::AccumulatorItem {
        match this {
            ReduceOperation::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Mean(sum) => {
                let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Prod(sum) => {
                let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::MaxAbs(maxabs) => {
                let elements =
                    <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::ArgMax(argmax) => {
                let (elements, args) =
                    <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::ArgMin(argmin) => {
                let (elements, args) =
                    <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Reads the accumulator back as an input line plus coordinate via the active
    // instruction. ArgMax and ArgMin unwrap `accumulator.args`.
    fn read_accumulator(
        this: &Self,
        accumulator: &Self::AccumulatorItem,
    ) -> (Line<P::EI>, ReduceCoordinate) {
        match this {
            ReduceOperation::Sum(sum) => {
                <Sum as ReduceInstruction<P>>::read_accumulator(sum, &accumulator.elements)
            }
            ReduceOperation::Prod(prod) => {
                <Prod as ReduceInstruction<P>>::read_accumulator(prod, &accumulator.elements)
            }
            ReduceOperation::Mean(mean) => {
                <Mean as ReduceInstruction<P>>::read_accumulator(mean, &accumulator.elements)
            }
            ReduceOperation::MaxAbs(maxabs) => {
                <MaxAbs as ReduceInstruction<P>>::read_accumulator(maxabs, &accumulator.elements)
            }
            ReduceOperation::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::read_accumulator(
                argmax,
                &(accumulator.elements, accumulator.args.unwrap()),
            ),
            ReduceOperation::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::read_accumulator(
                argmin,
                &(accumulator.elements, accumulator.args.unwrap()),
            ),
            ReduceOperation::Max(max) => {
                <Max as ReduceInstruction<P>>::read_accumulator(max, &accumulator.elements)
            }
            ReduceOperation::Min(min) => {
                <Min as ReduceInstruction<P>>::read_accumulator(min, &accumulator.elements)
            }
        }
    }

    // Copies `source` into `destination` regardless of the operation. `args` is copied
    // only when `destination` has a coordinate slot; in that case `source.args` is
    // unwrapped, so both items are expected to agree on whether coordinates are present.
    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        destination.elements = source.elements;
        let args = &mut destination.args;
        match args {
            CubeOption::Some(val) => *val = source.args.unwrap(),
            CubeOption::None => {}
        }
    }

    // Combines `item` (at `coordinate`) into `accumulator` via the active instruction
    // and returns the resulting accumulator. `coordinate` and `use_planes` are
    // forwarded unchanged to every instruction.
    // NOTE(review): the `Prod` and `Mean` arms bind their instruction as `sum`.
    fn reduce(
        this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<P::EI>,
        coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        match this {
            ReduceOperation::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Prod(sum) => {
                let elements = <Prod as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Mean(sum) => {
                let elements = <Mean as ReduceInstruction<P>>::reduce(
                    sum,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::MaxAbs(maxabs) => {
                let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
                    maxabs,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::ArgMax(argmax) => {
                let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
                    argmax,
                    &(accumulator.elements, accumulator.args.unwrap()),
                    item,
                    coordinate,
                    use_planes,
                );

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::ArgMin(argmin) => {
                let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
                    argmin,
                    &(accumulator.elements, accumulator.args.unwrap()),
                    item,
                    coordinate,
                    use_planes,
                );

                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::reduce(
                    max,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::reduce(
                    min,
                    &accumulator.elements,
                    item,
                    coordinate,
                    use_planes,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Merges two accumulators into one via the active instruction. ArgMax and ArgMin
    // unwrap `args` on both sides.
    fn fuse_accumulators(
        this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        match this {
            ReduceOperation::Sum(sum) => {
                let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
                    sum,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Prod(prod) => {
                let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
                    prod,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Mean(mean) => {
                let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
                    mean,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::MaxAbs(maxabs) => {
                let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
                    maxabs,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::ArgMax(argmax) => {
                let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
                    argmax,
                    (lhs.elements, lhs.args.unwrap()),
                    (rhs.elements, rhs.args.unwrap()),
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::ArgMin(argmin) => {
                let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
                    argmin,
                    (lhs.elements, lhs.args.unwrap()),
                    (rhs.elements, rhs.args.unwrap()),
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_Some(args),
                }
            }
            ReduceOperation::Max(max) => {
                let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
                    max,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
            ReduceOperation::Min(min) => {
                let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
                    min,
                    lhs.elements,
                    rhs.elements,
                );
                DynamicAccumulatorItem::<P::EA> {
                    elements,
                    args: CubeOption::new_None(),
                }
            }
        }
    }

    // Turns the accumulator line into a single `Out` value via the active instruction.
    // `shape_axis_reduce` is forwarded unchanged (presumably the length of the reduced
    // axis, e.g. for `Mean`; confirm against the instruction implementations).
    fn merge_line<Out: Numeric>(
        this: &Self,
        accumulator: Self::AccumulatorItem,
        shape_axis_reduce: usize,
    ) -> Out {
        match this {
            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
                sum,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceOperation::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
                prod,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceOperation::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
                mean,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceOperation::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
                maxabs,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceOperation::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
                argmax,
                (accumulator.elements, accumulator.args.unwrap()),
                shape_axis_reduce,
            ),
            ReduceOperation::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
                argmin,
                (accumulator.elements, accumulator.args.unwrap()),
                shape_axis_reduce,
            ),
            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
                max,
                accumulator.elements,
                shape_axis_reduce,
            ),
            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
                min,
                accumulator.elements,
                shape_axis_reduce,
            ),
        }
    }

    // Converts the accumulator into a `Line<Out>` via the active instruction.
    // NOTE(review): the ArgMax/ArgMin arms bind their instruction as `args`, which is
    // unrelated to `accumulator.args`.
    fn to_output_perpendicular<Out: Numeric>(
        this: &Self,
        accumulator: Self::AccumulatorItem,
        shape_axis_reduce: usize,
    ) -> Line<Out> {
        match this {
            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<
                Out,
            >(sum, accumulator.elements, shape_axis_reduce),
            ReduceOperation::Prod(prod) => {
                <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    prod,
                    accumulator.elements,
                    shape_axis_reduce,
                )
            }
            ReduceOperation::Mean(mean) => {
                <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    mean,
                    accumulator.elements,
                    shape_axis_reduce,
                )
            }
            ReduceOperation::MaxAbs(maxabs) => {
                <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    maxabs,
                    accumulator.elements,
                    shape_axis_reduce,
                )
            }
            ReduceOperation::ArgMax(args) => {
                <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    args,
                    (accumulator.elements, accumulator.args.unwrap()),
                    shape_axis_reduce,
                )
            }
            ReduceOperation::ArgMin(args) => {
                <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
                    args,
                    (accumulator.elements, accumulator.args.unwrap()),
                    shape_axis_reduce,
                )
            }
            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<
                Out,
            >(max, accumulator.elements, shape_axis_reduce),
            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<
                Out,
            >(min, accumulator.elements, shape_axis_reduce),
        }
    }
}