cubecl_reduce/instructions/
mixed.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand};
4
5use super::{
6    ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
7    ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
8};
9
10#[derive(Debug, CubeType, Clone)]
11pub enum ReduceFn {
12    Sum(Sum),
13    Prod(Prod),
14    Mean(Mean),
15    MaxAbs(MaxAbs),
16    ArgMax(ArgMax),
17    ArgMin(ArgMin),
18    Max(Max),
19    Min(Min),
20}
21
22#[derive_cube_comptime]
23pub enum ReduceFnConfig {
24    Sum,
25    Prod,
26    Mean,
27    MaxAbs,
28    ArgMax,
29    ArgMin,
30    Max,
31    Min,
32}
33
34impl ReduceFamily for ReduceFn {
35    type Instruction<In: Numeric> = Self;
36    type Config = ReduceFnConfig;
37}
38
39#[derive(CubeType)]
40pub struct DynamicAccumulator<N: Numeric> {
41    pub elements: SharedMemory<Line<N>>,
42    pub args: CubeOption<SharedMemory<Line<u32>>>,
43}
44
45#[derive(CubeType)]
46pub struct DynamicAccumulatorItem<N: Numeric> {
47    pub elements: Line<N>,
48    pub args: CubeOption<Line<u32>>,
49}
50
51#[cube]
52impl<In: Numeric> SharedAccumulator<In> for DynamicAccumulator<In> {
53    type Item = DynamicAccumulatorItem<In>;
54
55    fn allocate(
56        #[comptime] length: u32,
57        #[comptime] line_size: u32,
58        #[comptime] coordinate: bool,
59    ) -> Self {
60        let elements = SharedMemory::new_lined(length, line_size);
61        let args = if comptime![coordinate] {
62            let args = SharedMemory::new_lined(length, line_size);
63            CubeOption::new_Some(args)
64        } else {
65            CubeOption::new_None()
66        };
67
68        DynamicAccumulator::<In> { elements, args }
69    }
70
71    fn read(accumulator: &Self, index: u32) -> Self::Item {
72        let elements = accumulator.elements[index];
73        let args = match accumulator.args {
74            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
75            CubeOption::None => CubeOption::new_None(),
76        };
77
78        DynamicAccumulatorItem::<In> { elements, args }
79    }
80
81    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
82        accumulator.elements[index] = item.elements;
83
84        let args = &mut accumulator.args;
85        match args {
86            CubeOption::Some(args) => {
87                args[index] = item.args.unwrap();
88            }
89            CubeOption::None => {}
90        };
91    }
92}
93
94#[cube]
95impl<In: Numeric> ReduceInstruction<In> for ReduceFn {
96    type AccumulatorItem = DynamicAccumulatorItem<In>;
97    type SharedAccumulator = DynamicAccumulator<In>;
98    type Config = ReduceFnConfig;
99
100    fn requirements(this: &Self) -> ReduceRequirements {
101        let coordinates = match this {
102            ReduceFn::Sum(..) => comptime![false],
103            ReduceFn::Prod(..) => comptime![false],
104            ReduceFn::Mean(..) => comptime![false],
105            ReduceFn::MaxAbs(..) => comptime![false],
106            ReduceFn::ArgMax(..) => comptime![true],
107            ReduceFn::ArgMin(..) => comptime![true],
108            ReduceFn::Max(..) => comptime![false],
109            ReduceFn::Min(..) => comptime![false],
110        };
111        ReduceRequirements {
112            coordinates: comptime! {coordinates},
113        }
114    }
115
116    fn from_config(#[comptime] config: Self::Config) -> Self {
117        match config {
118            ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
119            ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
120            ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
121            ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
122            ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
123            ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
124            ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
125            ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
126        }
127    }
128
129    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In> {
130        match this {
131            ReduceFn::Sum(sum) => Sum::null_input(sum, line_size),
132            ReduceFn::Prod(prod) => Prod::null_input(prod, line_size),
133            ReduceFn::Mean(mean) => Mean::null_input(mean, line_size),
134            ReduceFn::MaxAbs(maxabs) => MaxAbs::null_input(maxabs, line_size),
135            ReduceFn::ArgMax(argmax) => ArgMax::null_input(argmax, line_size),
136            ReduceFn::ArgMin(argmin) => ArgMin::null_input(argmin, line_size),
137            ReduceFn::Max(max) => Max::null_input(max, line_size),
138            ReduceFn::Min(min) => Min::null_input(min, line_size),
139        }
140    }
141
142    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
143        match this {
144            ReduceFn::Sum(sum) => {
145                let elements = Sum::null_accumulator(sum, line_size);
146
147                DynamicAccumulatorItem::<In> {
148                    elements,
149                    args: CubeOption::new_None(),
150                }
151            }
152            ReduceFn::Mean(sum) => {
153                let elements = Mean::null_accumulator(sum, line_size);
154
155                DynamicAccumulatorItem::<In> {
156                    elements,
157                    args: CubeOption::new_None(),
158                }
159            }
160            ReduceFn::Prod(sum) => {
161                let elements = Prod::null_accumulator(sum, line_size);
162
163                DynamicAccumulatorItem::<In> {
164                    elements,
165                    args: CubeOption::new_None(),
166                }
167            }
168            ReduceFn::MaxAbs(maxabs) => {
169                let elements = MaxAbs::null_accumulator(maxabs, line_size);
170
171                DynamicAccumulatorItem::<In> {
172                    elements,
173                    args: CubeOption::new_None(),
174                }
175            }
176            ReduceFn::ArgMax(argmax) => {
177                let (elements, args) = ArgMax::null_accumulator(argmax, line_size);
178
179                DynamicAccumulatorItem::<In> {
180                    elements,
181                    args: CubeOption::new_Some(args),
182                }
183            }
184            ReduceFn::ArgMin(argmin) => {
185                let (elements, args) = ArgMin::null_accumulator(argmin, line_size);
186
187                DynamicAccumulatorItem::<In> {
188                    elements,
189                    args: CubeOption::new_Some(args),
190                }
191            }
192            ReduceFn::Max(max) => {
193                let elements = Max::null_accumulator(max, line_size);
194
195                DynamicAccumulatorItem::<In> {
196                    elements,
197                    args: CubeOption::new_None(),
198                }
199            }
200            ReduceFn::Min(min) => {
201                let elements = Min::null_accumulator(min, line_size);
202
203                DynamicAccumulatorItem::<In> {
204                    elements,
205                    args: CubeOption::new_None(),
206                }
207            }
208        }
209    }
210
211    fn assign_accumulator(
212        _this: &Self,
213        destination: &mut Self::AccumulatorItem,
214        source: &Self::AccumulatorItem,
215    ) {
216        destination.elements = source.elements;
217        let args = &mut destination.args;
218        match args {
219            CubeOption::Some(val) => *val = source.args.unwrap(),
220            CubeOption::None => {}
221        }
222    }
223
224    fn reduce(
225        this: &Self,
226        accumulator: &Self::AccumulatorItem,
227        item: Line<In>,
228        coordinate: ReduceCoordinate,
229        #[comptime] use_planes: bool,
230    ) -> Self::AccumulatorItem {
231        match this {
232            ReduceFn::Sum(sum) => {
233                let elements =
234                    Sum::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
235                DynamicAccumulatorItem::<In> {
236                    elements,
237                    args: CubeOption::new_None(),
238                }
239            }
240            ReduceFn::Prod(sum) => {
241                let elements =
242                    Prod::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
243                DynamicAccumulatorItem::<In> {
244                    elements,
245                    args: CubeOption::new_None(),
246                }
247            }
248            ReduceFn::Mean(sum) => {
249                let elements =
250                    Mean::reduce(sum, &accumulator.elements, item, coordinate, use_planes);
251                DynamicAccumulatorItem::<In> {
252                    elements,
253                    args: CubeOption::new_None(),
254                }
255            }
256            ReduceFn::MaxAbs(maxabs) => {
257                let elements =
258                    MaxAbs::reduce(maxabs, &accumulator.elements, item, coordinate, use_planes);
259                DynamicAccumulatorItem::<In> {
260                    elements,
261                    args: CubeOption::new_None(),
262                }
263            }
264            ReduceFn::ArgMax(argmax) => {
265                let (elements, args) = ArgMax::reduce(
266                    argmax,
267                    &(accumulator.elements, accumulator.args.unwrap()),
268                    item,
269                    coordinate,
270                    use_planes,
271                );
272
273                DynamicAccumulatorItem::<In> {
274                    elements,
275                    args: CubeOption::new_Some(args),
276                }
277            }
278            ReduceFn::ArgMin(argmin) => {
279                let (elements, args) = ArgMin::reduce(
280                    argmin,
281                    &(accumulator.elements, accumulator.args.unwrap()),
282                    item,
283                    coordinate,
284                    use_planes,
285                );
286
287                DynamicAccumulatorItem::<In> {
288                    elements,
289                    args: CubeOption::new_Some(args),
290                }
291            }
292            ReduceFn::Max(max) => {
293                let elements =
294                    Max::reduce(max, &accumulator.elements, item, coordinate, use_planes);
295                DynamicAccumulatorItem::<In> {
296                    elements,
297                    args: CubeOption::new_None(),
298                }
299            }
300            ReduceFn::Min(min) => {
301                let elements =
302                    Min::reduce(min, &accumulator.elements, item, coordinate, use_planes);
303                DynamicAccumulatorItem::<In> {
304                    elements,
305                    args: CubeOption::new_None(),
306                }
307            }
308        }
309    }
310
311    fn fuse_accumulators(
312        this: &Self,
313        lhs: Self::AccumulatorItem,
314        rhs: Self::AccumulatorItem,
315    ) -> Self::AccumulatorItem {
316        match this {
317            ReduceFn::Sum(sum) => {
318                let elements = Sum::fuse_accumulators(sum, lhs.elements, rhs.elements);
319                DynamicAccumulatorItem::<In> {
320                    elements,
321                    args: CubeOption::new_None(),
322                }
323            }
324            ReduceFn::Prod(prod) => {
325                let elements = Prod::fuse_accumulators(prod, lhs.elements, rhs.elements);
326                DynamicAccumulatorItem::<In> {
327                    elements,
328                    args: CubeOption::new_None(),
329                }
330            }
331            ReduceFn::Mean(mean) => {
332                let elements = Mean::fuse_accumulators(mean, lhs.elements, rhs.elements);
333                DynamicAccumulatorItem::<In> {
334                    elements,
335                    args: CubeOption::new_None(),
336                }
337            }
338            ReduceFn::MaxAbs(maxabs) => {
339                let elements = MaxAbs::fuse_accumulators(maxabs, lhs.elements, rhs.elements);
340                DynamicAccumulatorItem::<In> {
341                    elements,
342                    args: CubeOption::new_None(),
343                }
344            }
345            ReduceFn::ArgMax(argmax) => {
346                let (elements, args) = ArgMax::fuse_accumulators(
347                    argmax,
348                    (lhs.elements, lhs.args.unwrap()),
349                    (rhs.elements, rhs.args.unwrap()),
350                );
351                DynamicAccumulatorItem::<In> {
352                    elements,
353                    args: CubeOption::new_Some(args),
354                }
355            }
356            ReduceFn::ArgMin(argmin) => {
357                let (elements, args) = ArgMin::fuse_accumulators(
358                    argmin,
359                    (lhs.elements, lhs.args.unwrap()),
360                    (rhs.elements, rhs.args.unwrap()),
361                );
362                DynamicAccumulatorItem::<In> {
363                    elements,
364                    args: CubeOption::new_Some(args),
365                }
366            }
367            ReduceFn::Max(max) => {
368                let elements = Max::fuse_accumulators(max, lhs.elements, rhs.elements);
369                DynamicAccumulatorItem::<In> {
370                    elements,
371                    args: CubeOption::new_None(),
372                }
373            }
374            ReduceFn::Min(min) => {
375                let elements = Min::fuse_accumulators(min, lhs.elements, rhs.elements);
376                DynamicAccumulatorItem::<In> {
377                    elements,
378                    args: CubeOption::new_None(),
379                }
380            }
381        }
382    }
383
384    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
385    //      Then, an instruction like Dynamic can be implemented by fusing a Sum reduction and a element-wise division.
386    fn merge_line<Out: Numeric>(
387        this: &Self,
388        accumulator: Self::AccumulatorItem,
389        shape_axis_reduce: u32,
390    ) -> Out {
391        match this {
392            ReduceFn::Sum(sum) => {
393                Sum::merge_line::<Out>(sum, accumulator.elements, shape_axis_reduce)
394            }
395            ReduceFn::Prod(prod) => {
396                Prod::merge_line::<Out>(prod, accumulator.elements, shape_axis_reduce)
397            }
398            ReduceFn::Mean(mean) => {
399                Mean::merge_line::<Out>(mean, accumulator.elements, shape_axis_reduce)
400            }
401            ReduceFn::MaxAbs(maxabs) => {
402                MaxAbs::merge_line::<Out>(maxabs, accumulator.elements, shape_axis_reduce)
403            }
404            ReduceFn::ArgMax(argmax) => ArgMax::merge_line::<Out>(
405                argmax,
406                (accumulator.elements, accumulator.args.unwrap()),
407                shape_axis_reduce,
408            ),
409            ReduceFn::ArgMin(argmin) => ArgMin::merge_line::<Out>(
410                argmin,
411                (accumulator.elements, accumulator.args.unwrap()),
412                shape_axis_reduce,
413            ),
414            ReduceFn::Max(max) => {
415                Max::merge_line::<Out>(max, accumulator.elements, shape_axis_reduce)
416            }
417            ReduceFn::Min(min) => {
418                Min::merge_line::<Out>(min, accumulator.elements, shape_axis_reduce)
419            }
420        }
421    }
422
423    fn to_output_perpendicular<Out: Numeric>(
424        this: &Self,
425        accumulator: Self::AccumulatorItem,
426        shape_axis_reduce: u32,
427    ) -> Line<Out> {
428        match this {
429            ReduceFn::Sum(sum) => {
430                Sum::to_output_perpendicular::<Out>(sum, accumulator.elements, shape_axis_reduce)
431            }
432            ReduceFn::Prod(prod) => {
433                Prod::to_output_perpendicular::<Out>(prod, accumulator.elements, shape_axis_reduce)
434            }
435            ReduceFn::Mean(mean) => {
436                Mean::to_output_perpendicular::<Out>(mean, accumulator.elements, shape_axis_reduce)
437            }
438            ReduceFn::MaxAbs(maxabs) => MaxAbs::to_output_perpendicular::<Out>(
439                maxabs,
440                accumulator.elements,
441                shape_axis_reduce,
442            ),
443            ReduceFn::ArgMax(args) => ArgMax::to_output_perpendicular::<Out>(
444                args,
445                (accumulator.elements, accumulator.args.unwrap()),
446                shape_axis_reduce,
447            ),
448            ReduceFn::ArgMin(args) => ArgMin::to_output_perpendicular::<Out>(
449                args,
450                (accumulator.elements, accumulator.args.unwrap()),
451                shape_axis_reduce,
452            ),
453            ReduceFn::Max(max) => {
454                Max::to_output_perpendicular::<Out>(max, accumulator.elements, shape_axis_reduce)
455            }
456            ReduceFn::Min(min) => {
457                Min::to_output_perpendicular::<Out>(min, accumulator.elements, shape_axis_reduce)
458            }
459        }
460    }
461}