cubecl_reduce/instructions/
mixed.rs

1use super::{
2    ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
3    ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
4};
5use crate::ReduceDtypes;
6use crate::precision::ReducePrecision;
7use cubecl_core::ir::{FloatKind, IntKind, UIntKind};
8use cubecl_core::prelude::*;
9use cubecl_core::{self as cubecl, ir::ElemType};
10use cubecl_std::{CubeOption, CubeOptionExpand};
11
/// Reduction instruction selected through a [`ReduceFnConfig`] value.
///
/// There is one variant per supported reduction; each variant wraps the value of the
/// corresponding concrete instruction type.
#[derive(Debug, CubeType, Clone)]
pub enum ReduceFn {
    /// Wraps the [`Sum`] instruction.
    Sum(Sum),
    /// Wraps the [`Prod`] instruction.
    Prod(Prod),
    /// Wraps the [`Mean`] instruction.
    Mean(Mean),
    /// Wraps the [`MaxAbs`] instruction.
    MaxAbs(MaxAbs),
    /// Wraps the [`ArgMax`] instruction.
    ArgMax(ArgMax),
    /// Wraps the [`ArgMin`] instruction.
    ArgMin(ArgMin),
    /// Wraps the [`Max`] instruction.
    Max(Max),
    /// Wraps the [`Min`] instruction.
    Min(Min),
}
23
/// Payload-free selector naming which [`ReduceFn`] variant should be built.
///
/// The variants mirror those of [`ReduceFn`] one-to-one.
#[derive_cube_comptime]
pub enum ReduceFnConfig {
    /// Selects [`ReduceFn::Sum`].
    Sum,
    /// Selects [`ReduceFn::Prod`].
    Prod,
    /// Selects [`ReduceFn::Mean`].
    Mean,
    /// Selects [`ReduceFn::MaxAbs`].
    MaxAbs,
    /// Selects [`ReduceFn::ArgMax`].
    ArgMax,
    /// Selects [`ReduceFn::ArgMin`].
    ArgMin,
    /// Selects [`ReduceFn::Max`].
    Max,
    /// Selects [`ReduceFn::Min`].
    Min,
}
35
36impl ReduceFnConfig {
37    /// Computes the best case precision for the given config.
38    pub fn precision(&self, input: ElemType) -> ReduceDtypes {
39        match self {
40            ReduceFnConfig::Sum | ReduceFnConfig::Prod | ReduceFnConfig::Mean => {}
41            // No benefit to mixed precision accumulation.
42            ReduceFnConfig::MaxAbs | ReduceFnConfig::Max | ReduceFnConfig::Min => {
43                return ReduceDtypes {
44                    input: input.into(),
45                    output: input.into(),
46                    accumulation: input.into(),
47                };
48            }
49            ReduceFnConfig::ArgMax | ReduceFnConfig::ArgMin => {
50                return ReduceDtypes {
51                    input: input.into(),
52                    output: i32::as_type_native_unchecked(),
53                    accumulation: input.into(),
54                };
55            }
56        };
57
58        match input {
59            ElemType::Float(kind) => {
60                let acc = match kind {
61                    FloatKind::F64 => f64::as_type_native_unchecked(),
62                    _ => f32::as_type_native_unchecked(),
63                };
64
65                ReduceDtypes {
66                    input: input.into(),
67                    output: input.into(),
68                    accumulation: acc,
69                }
70            }
71            ElemType::Int(kind) => {
72                let acc = match kind {
73                    IntKind::I64 => i64::as_type_native_unchecked(),
74                    _ => i32::as_type_native_unchecked(),
75                };
76
77                ReduceDtypes {
78                    input: input.into(),
79                    output: input.into(),
80                    accumulation: acc,
81                }
82            }
83            ElemType::UInt(kind) => {
84                let acc = match kind {
85                    UIntKind::U64 => u64::as_type_native_unchecked(),
86                    _ => u32::as_type_native_unchecked(),
87                };
88
89                ReduceDtypes {
90                    input: input.into(),
91                    output: input.into(),
92                    accumulation: acc,
93                }
94            }
95            ElemType::Bool => panic!("Can't reduce on booleans"),
96        }
97    }
98}
99
/// Declares [`ReduceFn`] as a reduce family: the instruction type is `ReduceFn` itself
/// for every precision `P`, and the family is configured by a [`ReduceFnConfig`].
impl ReduceFamily for ReduceFn {
    type Instruction<P: ReducePrecision> = Self;
    type Config = ReduceFnConfig;
}
104
/// Shared-memory accumulator made of a buffer of value lines and an optional buffer of
/// `u32` lines.
#[derive(CubeType)]
pub struct DynamicAccumulator<N: Numeric> {
    /// Shared memory holding the accumulated value lines.
    pub elements: SharedMemory<Line<N>>,
    /// Shared memory holding the `u32` lines paired with `elements`, when present.
    pub args: CubeOption<SharedMemory<Line<u32>>>,
}
110
/// A single accumulator entry: one line of values plus an optional line of `u32`.
#[derive(CubeType)]
pub struct DynamicAccumulatorItem<N: Numeric> {
    /// The accumulated value line.
    pub elements: Line<N>,
    /// The `u32` line paired with `elements`, when present.
    pub args: CubeOption<Line<u32>>,
}
116
117#[cube]
118impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
119    type Item = DynamicAccumulatorItem<In>;
120
121    fn allocate(
122        #[comptime] length: u32,
123        #[comptime] line_size: u32,
124        #[comptime] coordinate: bool,
125    ) -> Self {
126        let elements = SharedMemory::new_lined(length, line_size);
127        let args = if comptime![coordinate] {
128            let args = SharedMemory::new_lined(length, line_size);
129            CubeOption::new_Some(args)
130        } else {
131            CubeOption::new_None()
132        };
133
134        DynamicAccumulator::<In> { elements, args }
135    }
136
137    fn read(accumulator: &Self, index: u32) -> Self::Item {
138        let elements = accumulator.elements[index];
139        let args = match accumulator.args {
140            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
141            CubeOption::None => CubeOption::new_None(),
142        };
143
144        DynamicAccumulatorItem::<In> { elements, args }
145    }
146
147    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
148        accumulator.elements[index] = item.elements;
149
150        let args = &mut accumulator.args;
151        match args {
152            CubeOption::Some(args) => {
153                args[index] = item.args.unwrap();
154            }
155            CubeOption::None => {}
156        };
157    }
158}
159
160#[cube]
161impl<P: ReducePrecision> ReduceInstruction<P> for ReduceFn {
162    type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
163    type SharedAccumulator = DynamicAccumulator<P::EA>;
164    type Config = ReduceFnConfig;
165
166    fn requirements(this: &Self) -> ReduceRequirements {
167        let coordinates = match this {
168            ReduceFn::Sum(..) => comptime![false],
169            ReduceFn::Prod(..) => comptime![false],
170            ReduceFn::Mean(..) => comptime![false],
171            ReduceFn::MaxAbs(..) => comptime![false],
172            ReduceFn::ArgMax(..) => comptime![true],
173            ReduceFn::ArgMin(..) => comptime![true],
174            ReduceFn::Max(..) => comptime![false],
175            ReduceFn::Min(..) => comptime![false],
176        };
177        ReduceRequirements {
178            coordinates: comptime! {coordinates},
179        }
180    }
181
182    fn from_config(#[comptime] config: Self::Config) -> Self {
183        match config {
184            ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
185            ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
186            ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
187            ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
188            ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
189            ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
190            ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
191            ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
192        }
193    }
194
195    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
196        match this {
197            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
198            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::null_input(prod, line_size),
199            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::null_input(mean, line_size),
200            ReduceFn::MaxAbs(maxabs) => {
201                <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
202            }
203            ReduceFn::ArgMax(argmax) => {
204                <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
205            }
206            ReduceFn::ArgMin(argmin) => {
207                <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
208            }
209            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
210            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
211        }
212    }
213
214    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
215        match this {
216            ReduceFn::Sum(sum) => {
217                let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);
218
219                DynamicAccumulatorItem::<P::EA> {
220                    elements,
221                    args: CubeOption::new_None(),
222                }
223            }
224            ReduceFn::Mean(sum) => {
225                let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);
226
227                DynamicAccumulatorItem::<P::EA> {
228                    elements,
229                    args: CubeOption::new_None(),
230                }
231            }
232            ReduceFn::Prod(sum) => {
233                let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);
234
235                DynamicAccumulatorItem::<P::EA> {
236                    elements,
237                    args: CubeOption::new_None(),
238                }
239            }
240            ReduceFn::MaxAbs(maxabs) => {
241                let elements =
242                    <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);
243
244                DynamicAccumulatorItem::<P::EA> {
245                    elements,
246                    args: CubeOption::new_None(),
247                }
248            }
249            ReduceFn::ArgMax(argmax) => {
250                let (elements, args) =
251                    <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);
252
253                DynamicAccumulatorItem::<P::EA> {
254                    elements,
255                    args: CubeOption::new_Some(args),
256                }
257            }
258            ReduceFn::ArgMin(argmin) => {
259                let (elements, args) =
260                    <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);
261
262                DynamicAccumulatorItem::<P::EA> {
263                    elements,
264                    args: CubeOption::new_Some(args),
265                }
266            }
267            ReduceFn::Max(max) => {
268                let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);
269
270                DynamicAccumulatorItem::<P::EA> {
271                    elements,
272                    args: CubeOption::new_None(),
273                }
274            }
275            ReduceFn::Min(min) => {
276                let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);
277
278                DynamicAccumulatorItem::<P::EA> {
279                    elements,
280                    args: CubeOption::new_None(),
281                }
282            }
283        }
284    }
285
286    fn assign_accumulator(
287        _this: &Self,
288        destination: &mut Self::AccumulatorItem,
289        source: &Self::AccumulatorItem,
290    ) {
291        destination.elements = source.elements;
292        let args = &mut destination.args;
293        match args {
294            CubeOption::Some(val) => *val = source.args.unwrap(),
295            CubeOption::None => {}
296        }
297    }
298
299    fn reduce(
300        this: &Self,
301        accumulator: &Self::AccumulatorItem,
302        item: Line<P::EI>,
303        coordinate: ReduceCoordinate,
304        #[comptime] use_planes: bool,
305    ) -> Self::AccumulatorItem {
306        match this {
307            ReduceFn::Sum(sum) => {
308                let elements = <Sum as ReduceInstruction<P>>::reduce(
309                    sum,
310                    &accumulator.elements,
311                    item,
312                    coordinate,
313                    use_planes,
314                );
315                DynamicAccumulatorItem::<P::EA> {
316                    elements,
317                    args: CubeOption::new_None(),
318                }
319            }
320            ReduceFn::Prod(sum) => {
321                let elements = <Prod as ReduceInstruction<P>>::reduce(
322                    sum,
323                    &accumulator.elements,
324                    item,
325                    coordinate,
326                    use_planes,
327                );
328                DynamicAccumulatorItem::<P::EA> {
329                    elements,
330                    args: CubeOption::new_None(),
331                }
332            }
333            ReduceFn::Mean(sum) => {
334                let elements = <Mean as ReduceInstruction<P>>::reduce(
335                    sum,
336                    &accumulator.elements,
337                    item,
338                    coordinate,
339                    use_planes,
340                );
341                DynamicAccumulatorItem::<P::EA> {
342                    elements,
343                    args: CubeOption::new_None(),
344                }
345            }
346            ReduceFn::MaxAbs(maxabs) => {
347                let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
348                    maxabs,
349                    &accumulator.elements,
350                    item,
351                    coordinate,
352                    use_planes,
353                );
354                DynamicAccumulatorItem::<P::EA> {
355                    elements,
356                    args: CubeOption::new_None(),
357                }
358            }
359            ReduceFn::ArgMax(argmax) => {
360                let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
361                    argmax,
362                    &(accumulator.elements, accumulator.args.unwrap()),
363                    item,
364                    coordinate,
365                    use_planes,
366                );
367
368                DynamicAccumulatorItem::<P::EA> {
369                    elements,
370                    args: CubeOption::new_Some(args),
371                }
372            }
373            ReduceFn::ArgMin(argmin) => {
374                let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
375                    argmin,
376                    &(accumulator.elements, accumulator.args.unwrap()),
377                    item,
378                    coordinate,
379                    use_planes,
380                );
381
382                DynamicAccumulatorItem::<P::EA> {
383                    elements,
384                    args: CubeOption::new_Some(args),
385                }
386            }
387            ReduceFn::Max(max) => {
388                let elements = <Max as ReduceInstruction<P>>::reduce(
389                    max,
390                    &accumulator.elements,
391                    item,
392                    coordinate,
393                    use_planes,
394                );
395                DynamicAccumulatorItem::<P::EA> {
396                    elements,
397                    args: CubeOption::new_None(),
398                }
399            }
400            ReduceFn::Min(min) => {
401                let elements = <Min as ReduceInstruction<P>>::reduce(
402                    min,
403                    &accumulator.elements,
404                    item,
405                    coordinate,
406                    use_planes,
407                );
408                DynamicAccumulatorItem::<P::EA> {
409                    elements,
410                    args: CubeOption::new_None(),
411                }
412            }
413        }
414    }
415
416    fn fuse_accumulators(
417        this: &Self,
418        lhs: Self::AccumulatorItem,
419        rhs: Self::AccumulatorItem,
420    ) -> Self::AccumulatorItem {
421        match this {
422            ReduceFn::Sum(sum) => {
423                let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
424                    sum,
425                    lhs.elements,
426                    rhs.elements,
427                );
428                DynamicAccumulatorItem::<P::EA> {
429                    elements,
430                    args: CubeOption::new_None(),
431                }
432            }
433            ReduceFn::Prod(prod) => {
434                let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
435                    prod,
436                    lhs.elements,
437                    rhs.elements,
438                );
439                DynamicAccumulatorItem::<P::EA> {
440                    elements,
441                    args: CubeOption::new_None(),
442                }
443            }
444            ReduceFn::Mean(mean) => {
445                let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
446                    mean,
447                    lhs.elements,
448                    rhs.elements,
449                );
450                DynamicAccumulatorItem::<P::EA> {
451                    elements,
452                    args: CubeOption::new_None(),
453                }
454            }
455            ReduceFn::MaxAbs(maxabs) => {
456                let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
457                    maxabs,
458                    lhs.elements,
459                    rhs.elements,
460                );
461                DynamicAccumulatorItem::<P::EA> {
462                    elements,
463                    args: CubeOption::new_None(),
464                }
465            }
466            ReduceFn::ArgMax(argmax) => {
467                let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
468                    argmax,
469                    (lhs.elements, lhs.args.unwrap()),
470                    (rhs.elements, rhs.args.unwrap()),
471                );
472                DynamicAccumulatorItem::<P::EA> {
473                    elements,
474                    args: CubeOption::new_Some(args),
475                }
476            }
477            ReduceFn::ArgMin(argmin) => {
478                let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
479                    argmin,
480                    (lhs.elements, lhs.args.unwrap()),
481                    (rhs.elements, rhs.args.unwrap()),
482                );
483                DynamicAccumulatorItem::<P::EA> {
484                    elements,
485                    args: CubeOption::new_Some(args),
486                }
487            }
488            ReduceFn::Max(max) => {
489                let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
490                    max,
491                    lhs.elements,
492                    rhs.elements,
493                );
494                DynamicAccumulatorItem::<P::EA> {
495                    elements,
496                    args: CubeOption::new_None(),
497                }
498            }
499            ReduceFn::Min(min) => {
500                let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
501                    min,
502                    lhs.elements,
503                    rhs.elements,
504                );
505                DynamicAccumulatorItem::<P::EA> {
506                    elements,
507                    args: CubeOption::new_None(),
508                }
509            }
510        }
511    }
512
    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
    //      Then, an instruction like Dynamic can be implemented by fusing a Sum reduction and an element-wise division.
515    fn merge_line<Out: Numeric>(
516        this: &Self,
517        accumulator: Self::AccumulatorItem,
518        shape_axis_reduce: u32,
519    ) -> Out {
520        match this {
521            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
522                sum,
523                accumulator.elements,
524                shape_axis_reduce,
525            ),
526            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
527                prod,
528                accumulator.elements,
529                shape_axis_reduce,
530            ),
531            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
532                mean,
533                accumulator.elements,
534                shape_axis_reduce,
535            ),
536            ReduceFn::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
537                maxabs,
538                accumulator.elements,
539                shape_axis_reduce,
540            ),
541            ReduceFn::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
542                argmax,
543                (accumulator.elements, accumulator.args.unwrap()),
544                shape_axis_reduce,
545            ),
546            ReduceFn::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
547                argmin,
548                (accumulator.elements, accumulator.args.unwrap()),
549                shape_axis_reduce,
550            ),
551            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
552                max,
553                accumulator.elements,
554                shape_axis_reduce,
555            ),
556            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
557                min,
558                accumulator.elements,
559                shape_axis_reduce,
560            ),
561        }
562    }
563
564    fn to_output_perpendicular<Out: Numeric>(
565        this: &Self,
566        accumulator: Self::AccumulatorItem,
567        shape_axis_reduce: u32,
568    ) -> Line<Out> {
569        match this {
570            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
571                sum,
572                accumulator.elements,
573                shape_axis_reduce,
574            ),
575            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
576                prod,
577                accumulator.elements,
578                shape_axis_reduce,
579            ),
580            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
581                mean,
582                accumulator.elements,
583                shape_axis_reduce,
584            ),
585            ReduceFn::MaxAbs(maxabs) => {
586                <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
587                    maxabs,
588                    accumulator.elements,
589                    shape_axis_reduce,
590                )
591            }
592            ReduceFn::ArgMax(args) => {
593                <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
594                    args,
595                    (accumulator.elements, accumulator.args.unwrap()),
596                    shape_axis_reduce,
597                )
598            }
599            ReduceFn::ArgMin(args) => {
600                <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
601                    args,
602                    (accumulator.elements, accumulator.args.unwrap()),
603                    shape_axis_reduce,
604                )
605            }
606            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
607                max,
608                accumulator.elements,
609                shape_axis_reduce,
610            ),
611            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
612                min,
613                accumulator.elements,
614                shape_axis_reduce,
615            ),
616        }
617    }
618}