// cubek_reduce/components/instructions/mixed.rs

use super::{
    ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
    ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
};
use crate::{ReduceDtypes, components::precision::ReducePrecision};
use cubecl::{
    ir::{ElemType, FloatKind, IntKind, UIntKind},
    prelude::*,
    std::{CubeOption, CubeOptionExpand},
};
11
12#[derive(Debug, CubeType, Clone)]
13#[allow(unused)]
14pub(crate) enum ReduceOperation {
15    Sum(Sum),
16    Prod(Prod),
17    Mean(Mean),
18    MaxAbs(MaxAbs),
19    ArgMax(ArgMax),
20    ArgMin(ArgMin),
21    Max(Max),
22    Min(Min),
23}
24
25#[derive_cube_comptime]
26pub enum ReduceOperationConfig {
27    Sum,
28    Prod,
29    Mean,
30    MaxAbs,
31    ArgMax,
32    ArgMin,
33    Max,
34    Min,
35}
36
37impl ReduceOperationConfig {
38    /// Computes the best case precision for the given config.
39    pub fn precision(&self, input: ElemType, output: Option<ElemType>) -> ReduceDtypes {
40        match self {
41            ReduceOperationConfig::Sum
42            | ReduceOperationConfig::Prod
43            | ReduceOperationConfig::Mean => {}
44            // No benefit to mixed precision accumulation.
45            ReduceOperationConfig::MaxAbs
46            | ReduceOperationConfig::Max
47            | ReduceOperationConfig::Min => {
48                return ReduceDtypes {
49                    input: input.into(),
50                    output: input.into(),
51                    accumulation: input.into(),
52                };
53            }
54            ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin => {
55                return ReduceDtypes {
56                    input: input.into(),
57                    output: output
58                        .expect("ArgMax and ArgMin must specify output type")
59                        .into(),
60                    accumulation: input.into(),
61                };
62            }
63        };
64
65        match input {
66            ElemType::Float(kind) => {
67                let acc = match kind {
68                    FloatKind::F64 => f64::as_type_native_unchecked(),
69                    _ => f32::as_type_native_unchecked(),
70                };
71
72                ReduceDtypes {
73                    input: input.into(),
74                    output: input.into(),
75                    accumulation: acc,
76                }
77            }
78            ElemType::Int(kind) => {
79                let acc = match kind {
80                    IntKind::I64 => i64::as_type_native_unchecked(),
81                    _ => i32::as_type_native_unchecked(),
82                };
83
84                ReduceDtypes {
85                    input: input.into(),
86                    output: input.into(),
87                    accumulation: acc,
88                }
89            }
90            ElemType::UInt(kind) => {
91                let acc = match kind {
92                    UIntKind::U64 => u64::as_type_native_unchecked(),
93                    _ => u32::as_type_native_unchecked(),
94                };
95
96                ReduceDtypes {
97                    input: input.into(),
98                    output: input.into(),
99                    accumulation: acc,
100                }
101            }
102            ElemType::Bool => panic!("Can't reduce on booleans"),
103        }
104    }
105}
106
107impl ReduceFamily for ReduceOperation {
108    type Instruction<P: ReducePrecision> = Self;
109    type Config = ReduceOperationConfig;
110}
111
112#[derive(CubeType)]
113pub struct DynamicAccumulator<N: Numeric> {
114    pub elements: SharedMemory<Line<N>>,
115    pub args: CubeOption<SharedMemory<Line<u32>>>,
116}
117
118#[derive(CubeType)]
119pub struct DynamicAccumulatorItem<N: Numeric> {
120    pub elements: Line<N>,
121    pub args: CubeOption<Line<u32>>,
122}
123
124#[cube]
125impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
126    type Item = DynamicAccumulatorItem<In>;
127
128    fn allocate(
129        #[comptime] length: usize,
130        #[comptime] line_size: LineSize,
131        #[comptime] coordinate: bool,
132    ) -> Self {
133        let elements = SharedMemory::new_lined(length, line_size);
134        let args = if coordinate {
135            let args = SharedMemory::new_lined(length, line_size);
136            CubeOption::new_Some(args)
137        } else {
138            CubeOption::new_None()
139        };
140
141        DynamicAccumulator::<In> { elements, args }
142    }
143
144    fn read(accumulator: &Self, index: usize) -> Self::Item {
145        let elements = accumulator.elements[index];
146        let args = match accumulator.args {
147            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
148            CubeOption::None => CubeOption::new_None(),
149        };
150
151        DynamicAccumulatorItem::<In> { elements, args }
152    }
153
154    fn write(accumulator: &mut Self, index: usize, item: Self::Item) {
155        accumulator.elements[index] = item.elements;
156
157        let args = &mut accumulator.args;
158        match args {
159            CubeOption::Some(args) => {
160                args[index] = item.args.unwrap();
161            }
162            CubeOption::None => {}
163        };
164    }
165}
166
167#[cube]
168impl<P: ReducePrecision> ReduceInstruction<P> for ReduceOperation {
169    type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
170    type SharedAccumulator = DynamicAccumulator<P::EA>;
171    type Config = ReduceOperationConfig;
172
173    fn requirements(this: &Self) -> ReduceRequirements {
174        let coordinates = match this {
175            ReduceOperation::Sum(..) => false,
176            ReduceOperation::Prod(..) => false,
177            ReduceOperation::Mean(..) => false,
178            ReduceOperation::MaxAbs(..) => false,
179            ReduceOperation::ArgMax(..) => true,
180            ReduceOperation::ArgMin(..) => true,
181            ReduceOperation::Max(..) => false,
182            ReduceOperation::Min(..) => false,
183        };
184        ReduceRequirements { coordinates }
185    }
186
187    fn from_config(#[comptime] config: Self::Config) -> Self {
188        match config {
189            ReduceOperationConfig::Sum => ReduceOperation::new_Sum(Sum {}),
190            ReduceOperationConfig::Prod => ReduceOperation::new_Prod(Prod {}),
191            ReduceOperationConfig::Mean => ReduceOperation::new_Mean(Mean { sum: Sum {} }),
192            ReduceOperationConfig::MaxAbs => ReduceOperation::new_MaxAbs(MaxAbs {}),
193            ReduceOperationConfig::ArgMax => ReduceOperation::new_ArgMax(ArgMax {}),
194            ReduceOperationConfig::ArgMin => ReduceOperation::new_ArgMin(ArgMin {}),
195            ReduceOperationConfig::Max => ReduceOperation::new_Max(Max {}),
196            ReduceOperationConfig::Min => ReduceOperation::new_Min(Min {}),
197        }
198    }
199
200    fn null_input(this: &Self, #[comptime] line_size: LineSize) -> Line<P::EI> {
201        match this {
202            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
203            ReduceOperation::Prod(prod) => {
204                <Prod as ReduceInstruction<P>>::null_input(prod, line_size)
205            }
206            ReduceOperation::Mean(mean) => {
207                <Mean as ReduceInstruction<P>>::null_input(mean, line_size)
208            }
209            ReduceOperation::MaxAbs(maxabs) => {
210                <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
211            }
212            ReduceOperation::ArgMax(argmax) => {
213                <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
214            }
215            ReduceOperation::ArgMin(argmin) => {
216                <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
217            }
218            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
219            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
220        }
221    }
222
223    fn null_accumulator(this: &Self, #[comptime] line_size: LineSize) -> Self::AccumulatorItem {
224        match this {
225            ReduceOperation::Sum(sum) => {
226                let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);
227
228                DynamicAccumulatorItem::<P::EA> {
229                    elements,
230                    args: CubeOption::new_None(),
231                }
232            }
233            ReduceOperation::Mean(sum) => {
234                let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);
235
236                DynamicAccumulatorItem::<P::EA> {
237                    elements,
238                    args: CubeOption::new_None(),
239                }
240            }
241            ReduceOperation::Prod(sum) => {
242                let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);
243
244                DynamicAccumulatorItem::<P::EA> {
245                    elements,
246                    args: CubeOption::new_None(),
247                }
248            }
249            ReduceOperation::MaxAbs(maxabs) => {
250                let elements =
251                    <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);
252
253                DynamicAccumulatorItem::<P::EA> {
254                    elements,
255                    args: CubeOption::new_None(),
256                }
257            }
258            ReduceOperation::ArgMax(argmax) => {
259                let (elements, args) =
260                    <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);
261
262                DynamicAccumulatorItem::<P::EA> {
263                    elements,
264                    args: CubeOption::new_Some(args),
265                }
266            }
267            ReduceOperation::ArgMin(argmin) => {
268                let (elements, args) =
269                    <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);
270
271                DynamicAccumulatorItem::<P::EA> {
272                    elements,
273                    args: CubeOption::new_Some(args),
274                }
275            }
276            ReduceOperation::Max(max) => {
277                let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);
278
279                DynamicAccumulatorItem::<P::EA> {
280                    elements,
281                    args: CubeOption::new_None(),
282                }
283            }
284            ReduceOperation::Min(min) => {
285                let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);
286
287                DynamicAccumulatorItem::<P::EA> {
288                    elements,
289                    args: CubeOption::new_None(),
290                }
291            }
292        }
293    }
294
295    fn read_accumulator(
296        this: &Self,
297        accumulator: &Self::AccumulatorItem,
298    ) -> (Line<P::EI>, ReduceCoordinate) {
299        match this {
300            ReduceOperation::Sum(sum) => {
301                <Sum as ReduceInstruction<P>>::read_accumulator(sum, &accumulator.elements)
302            }
303            ReduceOperation::Prod(prod) => {
304                <Prod as ReduceInstruction<P>>::read_accumulator(prod, &accumulator.elements)
305            }
306            ReduceOperation::Mean(mean) => {
307                <Mean as ReduceInstruction<P>>::read_accumulator(mean, &accumulator.elements)
308            }
309            ReduceOperation::MaxAbs(maxabs) => {
310                <MaxAbs as ReduceInstruction<P>>::read_accumulator(maxabs, &accumulator.elements)
311            }
312            ReduceOperation::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::read_accumulator(
313                argmax,
314                &(accumulator.elements, accumulator.args.unwrap()),
315            ),
316            ReduceOperation::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::read_accumulator(
317                argmin,
318                &(accumulator.elements, accumulator.args.unwrap()),
319            ),
320            ReduceOperation::Max(max) => {
321                <Max as ReduceInstruction<P>>::read_accumulator(max, &accumulator.elements)
322            }
323            ReduceOperation::Min(min) => {
324                <Min as ReduceInstruction<P>>::read_accumulator(min, &accumulator.elements)
325            }
326        }
327    }
328
329    fn assign_accumulator(
330        _this: &Self,
331        destination: &mut Self::AccumulatorItem,
332        source: &Self::AccumulatorItem,
333    ) {
334        destination.elements = source.elements;
335        let args = &mut destination.args;
336        match args {
337            CubeOption::Some(val) => *val = source.args.unwrap(),
338            CubeOption::None => {}
339        }
340    }
341
342    fn reduce(
343        this: &Self,
344        accumulator: &Self::AccumulatorItem,
345        item: Line<P::EI>,
346        coordinate: ReduceCoordinate,
347        #[comptime] use_planes: bool,
348    ) -> Self::AccumulatorItem {
349        match this {
350            ReduceOperation::Sum(sum) => {
351                let elements = <Sum as ReduceInstruction<P>>::reduce(
352                    sum,
353                    &accumulator.elements,
354                    item,
355                    coordinate,
356                    use_planes,
357                );
358                DynamicAccumulatorItem::<P::EA> {
359                    elements,
360                    args: CubeOption::new_None(),
361                }
362            }
363            ReduceOperation::Prod(sum) => {
364                let elements = <Prod as ReduceInstruction<P>>::reduce(
365                    sum,
366                    &accumulator.elements,
367                    item,
368                    coordinate,
369                    use_planes,
370                );
371                DynamicAccumulatorItem::<P::EA> {
372                    elements,
373                    args: CubeOption::new_None(),
374                }
375            }
376            ReduceOperation::Mean(sum) => {
377                let elements = <Mean as ReduceInstruction<P>>::reduce(
378                    sum,
379                    &accumulator.elements,
380                    item,
381                    coordinate,
382                    use_planes,
383                );
384                DynamicAccumulatorItem::<P::EA> {
385                    elements,
386                    args: CubeOption::new_None(),
387                }
388            }
389            ReduceOperation::MaxAbs(maxabs) => {
390                let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
391                    maxabs,
392                    &accumulator.elements,
393                    item,
394                    coordinate,
395                    use_planes,
396                );
397                DynamicAccumulatorItem::<P::EA> {
398                    elements,
399                    args: CubeOption::new_None(),
400                }
401            }
402            ReduceOperation::ArgMax(argmax) => {
403                let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
404                    argmax,
405                    &(accumulator.elements, accumulator.args.unwrap()),
406                    item,
407                    coordinate,
408                    use_planes,
409                );
410
411                DynamicAccumulatorItem::<P::EA> {
412                    elements,
413                    args: CubeOption::new_Some(args),
414                }
415            }
416            ReduceOperation::ArgMin(argmin) => {
417                let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
418                    argmin,
419                    &(accumulator.elements, accumulator.args.unwrap()),
420                    item,
421                    coordinate,
422                    use_planes,
423                );
424
425                DynamicAccumulatorItem::<P::EA> {
426                    elements,
427                    args: CubeOption::new_Some(args),
428                }
429            }
430            ReduceOperation::Max(max) => {
431                let elements = <Max as ReduceInstruction<P>>::reduce(
432                    max,
433                    &accumulator.elements,
434                    item,
435                    coordinate,
436                    use_planes,
437                );
438                DynamicAccumulatorItem::<P::EA> {
439                    elements,
440                    args: CubeOption::new_None(),
441                }
442            }
443            ReduceOperation::Min(min) => {
444                let elements = <Min as ReduceInstruction<P>>::reduce(
445                    min,
446                    &accumulator.elements,
447                    item,
448                    coordinate,
449                    use_planes,
450                );
451                DynamicAccumulatorItem::<P::EA> {
452                    elements,
453                    args: CubeOption::new_None(),
454                }
455            }
456        }
457    }
458
459    fn fuse_accumulators(
460        this: &Self,
461        lhs: Self::AccumulatorItem,
462        rhs: Self::AccumulatorItem,
463    ) -> Self::AccumulatorItem {
464        match this {
465            ReduceOperation::Sum(sum) => {
466                let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
467                    sum,
468                    lhs.elements,
469                    rhs.elements,
470                );
471                DynamicAccumulatorItem::<P::EA> {
472                    elements,
473                    args: CubeOption::new_None(),
474                }
475            }
476            ReduceOperation::Prod(prod) => {
477                let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
478                    prod,
479                    lhs.elements,
480                    rhs.elements,
481                );
482                DynamicAccumulatorItem::<P::EA> {
483                    elements,
484                    args: CubeOption::new_None(),
485                }
486            }
487            ReduceOperation::Mean(mean) => {
488                let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
489                    mean,
490                    lhs.elements,
491                    rhs.elements,
492                );
493                DynamicAccumulatorItem::<P::EA> {
494                    elements,
495                    args: CubeOption::new_None(),
496                }
497            }
498            ReduceOperation::MaxAbs(maxabs) => {
499                let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
500                    maxabs,
501                    lhs.elements,
502                    rhs.elements,
503                );
504                DynamicAccumulatorItem::<P::EA> {
505                    elements,
506                    args: CubeOption::new_None(),
507                }
508            }
509            ReduceOperation::ArgMax(argmax) => {
510                let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
511                    argmax,
512                    (lhs.elements, lhs.args.unwrap()),
513                    (rhs.elements, rhs.args.unwrap()),
514                );
515                DynamicAccumulatorItem::<P::EA> {
516                    elements,
517                    args: CubeOption::new_Some(args),
518                }
519            }
520            ReduceOperation::ArgMin(argmin) => {
521                let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
522                    argmin,
523                    (lhs.elements, lhs.args.unwrap()),
524                    (rhs.elements, rhs.args.unwrap()),
525                );
526                DynamicAccumulatorItem::<P::EA> {
527                    elements,
528                    args: CubeOption::new_Some(args),
529                }
530            }
531            ReduceOperation::Max(max) => {
532                let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
533                    max,
534                    lhs.elements,
535                    rhs.elements,
536                );
537                DynamicAccumulatorItem::<P::EA> {
538                    elements,
539                    args: CubeOption::new_None(),
540                }
541            }
542            ReduceOperation::Min(min) => {
543                let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
544                    min,
545                    lhs.elements,
546                    rhs.elements,
547                );
548                DynamicAccumulatorItem::<P::EA> {
549                    elements,
550                    args: CubeOption::new_None(),
551                }
552            }
553        }
554    }
555
556    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
557    //      Then, an instruction like Dynamic can be implemented by fusing a Sum reduction and a element-wise division.
558    fn merge_line<Out: Numeric>(
559        this: &Self,
560        accumulator: Self::AccumulatorItem,
561        shape_axis_reduce: usize,
562    ) -> Out {
563        match this {
564            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
565                sum,
566                accumulator.elements,
567                shape_axis_reduce,
568            ),
569            ReduceOperation::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
570                prod,
571                accumulator.elements,
572                shape_axis_reduce,
573            ),
574            ReduceOperation::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
575                mean,
576                accumulator.elements,
577                shape_axis_reduce,
578            ),
579            ReduceOperation::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
580                maxabs,
581                accumulator.elements,
582                shape_axis_reduce,
583            ),
584            ReduceOperation::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
585                argmax,
586                (accumulator.elements, accumulator.args.unwrap()),
587                shape_axis_reduce,
588            ),
589            ReduceOperation::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
590                argmin,
591                (accumulator.elements, accumulator.args.unwrap()),
592                shape_axis_reduce,
593            ),
594            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
595                max,
596                accumulator.elements,
597                shape_axis_reduce,
598            ),
599            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
600                min,
601                accumulator.elements,
602                shape_axis_reduce,
603            ),
604        }
605    }
606
607    fn to_output_perpendicular<Out: Numeric>(
608        this: &Self,
609        accumulator: Self::AccumulatorItem,
610        shape_axis_reduce: usize,
611    ) -> Line<Out> {
612        match this {
613            ReduceOperation::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<
614                Out,
615            >(sum, accumulator.elements, shape_axis_reduce),
616            ReduceOperation::Prod(prod) => {
617                <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
618                    prod,
619                    accumulator.elements,
620                    shape_axis_reduce,
621                )
622            }
623            ReduceOperation::Mean(mean) => {
624                <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
625                    mean,
626                    accumulator.elements,
627                    shape_axis_reduce,
628                )
629            }
630            ReduceOperation::MaxAbs(maxabs) => {
631                <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
632                    maxabs,
633                    accumulator.elements,
634                    shape_axis_reduce,
635                )
636            }
637            ReduceOperation::ArgMax(args) => {
638                <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
639                    args,
640                    (accumulator.elements, accumulator.args.unwrap()),
641                    shape_axis_reduce,
642                )
643            }
644            ReduceOperation::ArgMin(args) => {
645                <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
646                    args,
647                    (accumulator.elements, accumulator.args.unwrap()),
648                    shape_axis_reduce,
649                )
650            }
651            ReduceOperation::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<
652                Out,
653            >(max, accumulator.elements, shape_axis_reduce),
654            ReduceOperation::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<
655                Out,
656            >(min, accumulator.elements, shape_axis_reduce),
657        }
658    }
659}