cubecl_reduce/instructions/
mixed.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand};
4
5use crate::precision::ReducePrecision;
6
7use super::{
8    ArgMax, ArgMin, Max, MaxAbs, Mean, Min, Prod, ReduceCoordinate, ReduceFamily,
9    ReduceInstruction, ReduceRequirements, SharedAccumulator, Sum,
10};
11
12#[derive(Debug, CubeType, Clone)]
13pub enum ReduceFn {
14    Sum(Sum),
15    Prod(Prod),
16    Mean(Mean),
17    MaxAbs(MaxAbs),
18    ArgMax(ArgMax),
19    ArgMin(ArgMin),
20    Max(Max),
21    Min(Min),
22}
23
24#[derive_cube_comptime]
25pub enum ReduceFnConfig {
26    Sum,
27    Prod,
28    Mean,
29    MaxAbs,
30    ArgMax,
31    ArgMin,
32    Max,
33    Min,
34}
35
36impl ReduceFamily for ReduceFn {
37    type Instruction<P: ReducePrecision> = Self;
38    type Config = ReduceFnConfig;
39}
40
41#[derive(CubeType)]
42pub struct DynamicAccumulator<N: Numeric> {
43    pub elements: SharedMemory<Line<N>>,
44    pub args: CubeOption<SharedMemory<Line<u32>>>,
45}
46
47#[derive(CubeType)]
48pub struct DynamicAccumulatorItem<N: Numeric> {
49    pub elements: Line<N>,
50    pub args: CubeOption<Line<u32>>,
51}
52
53#[cube]
54impl<In: Numeric> SharedAccumulator for DynamicAccumulator<In> {
55    type Item = DynamicAccumulatorItem<In>;
56
57    fn allocate(
58        #[comptime] length: u32,
59        #[comptime] line_size: u32,
60        #[comptime] coordinate: bool,
61    ) -> Self {
62        let elements = SharedMemory::new_lined(length, line_size);
63        let args = if comptime![coordinate] {
64            let args = SharedMemory::new_lined(length, line_size);
65            CubeOption::new_Some(args)
66        } else {
67            CubeOption::new_None()
68        };
69
70        DynamicAccumulator::<In> { elements, args }
71    }
72
73    fn read(accumulator: &Self, index: u32) -> Self::Item {
74        let elements = accumulator.elements[index];
75        let args = match accumulator.args {
76            CubeOption::Some(args) => CubeOption::new_Some(args[index]),
77            CubeOption::None => CubeOption::new_None(),
78        };
79
80        DynamicAccumulatorItem::<In> { elements, args }
81    }
82
83    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
84        accumulator.elements[index] = item.elements;
85
86        let args = &mut accumulator.args;
87        match args {
88            CubeOption::Some(args) => {
89                args[index] = item.args.unwrap();
90            }
91            CubeOption::None => {}
92        };
93    }
94}
95
96#[cube]
97impl<P: ReducePrecision> ReduceInstruction<P> for ReduceFn {
98    type AccumulatorItem = DynamicAccumulatorItem<P::EA>;
99    type SharedAccumulator = DynamicAccumulator<P::EA>;
100    type Config = ReduceFnConfig;
101
102    fn requirements(this: &Self) -> ReduceRequirements {
103        let coordinates = match this {
104            ReduceFn::Sum(..) => comptime![false],
105            ReduceFn::Prod(..) => comptime![false],
106            ReduceFn::Mean(..) => comptime![false],
107            ReduceFn::MaxAbs(..) => comptime![false],
108            ReduceFn::ArgMax(..) => comptime![true],
109            ReduceFn::ArgMin(..) => comptime![true],
110            ReduceFn::Max(..) => comptime![false],
111            ReduceFn::Min(..) => comptime![false],
112        };
113        ReduceRequirements {
114            coordinates: comptime! {coordinates},
115        }
116    }
117
118    fn from_config(#[comptime] config: Self::Config) -> Self {
119        match config {
120            ReduceFnConfig::Sum => ReduceFn::new_Sum(Sum {}),
121            ReduceFnConfig::Prod => ReduceFn::new_Prod(Prod {}),
122            ReduceFnConfig::Mean => ReduceFn::new_Mean(Mean { sum: Sum {} }),
123            ReduceFnConfig::MaxAbs => ReduceFn::new_MaxAbs(MaxAbs {}),
124            ReduceFnConfig::ArgMax => ReduceFn::new_ArgMax(ArgMax {}),
125            ReduceFnConfig::ArgMin => ReduceFn::new_ArgMin(ArgMin {}),
126            ReduceFnConfig::Max => ReduceFn::new_Max(Max {}),
127            ReduceFnConfig::Min => ReduceFn::new_Min(Min {}),
128        }
129    }
130
131    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
132        match this {
133            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::null_input(sum, line_size),
134            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::null_input(prod, line_size),
135            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::null_input(mean, line_size),
136            ReduceFn::MaxAbs(maxabs) => {
137                <MaxAbs as ReduceInstruction<P>>::null_input(maxabs, line_size)
138            }
139            ReduceFn::ArgMax(argmax) => {
140                <ArgMax as ReduceInstruction<P>>::null_input(argmax, line_size)
141            }
142            ReduceFn::ArgMin(argmin) => {
143                <ArgMin as ReduceInstruction<P>>::null_input(argmin, line_size)
144            }
145            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::null_input(max, line_size),
146            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::null_input(min, line_size),
147        }
148    }
149
150    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
151        match this {
152            ReduceFn::Sum(sum) => {
153                let elements = <Sum as ReduceInstruction<P>>::null_accumulator(sum, line_size);
154
155                DynamicAccumulatorItem::<P::EA> {
156                    elements,
157                    args: CubeOption::new_None(),
158                }
159            }
160            ReduceFn::Mean(sum) => {
161                let elements = <Mean as ReduceInstruction<P>>::null_accumulator(sum, line_size);
162
163                DynamicAccumulatorItem::<P::EA> {
164                    elements,
165                    args: CubeOption::new_None(),
166                }
167            }
168            ReduceFn::Prod(sum) => {
169                let elements = <Prod as ReduceInstruction<P>>::null_accumulator(sum, line_size);
170
171                DynamicAccumulatorItem::<P::EA> {
172                    elements,
173                    args: CubeOption::new_None(),
174                }
175            }
176            ReduceFn::MaxAbs(maxabs) => {
177                let elements =
178                    <MaxAbs as ReduceInstruction<P>>::null_accumulator(maxabs, line_size);
179
180                DynamicAccumulatorItem::<P::EA> {
181                    elements,
182                    args: CubeOption::new_None(),
183                }
184            }
185            ReduceFn::ArgMax(argmax) => {
186                let (elements, args) =
187                    <ArgMax as ReduceInstruction<P>>::null_accumulator(argmax, line_size);
188
189                DynamicAccumulatorItem::<P::EA> {
190                    elements,
191                    args: CubeOption::new_Some(args),
192                }
193            }
194            ReduceFn::ArgMin(argmin) => {
195                let (elements, args) =
196                    <ArgMin as ReduceInstruction<P>>::null_accumulator(argmin, line_size);
197
198                DynamicAccumulatorItem::<P::EA> {
199                    elements,
200                    args: CubeOption::new_Some(args),
201                }
202            }
203            ReduceFn::Max(max) => {
204                let elements = <Max as ReduceInstruction<P>>::null_accumulator(max, line_size);
205
206                DynamicAccumulatorItem::<P::EA> {
207                    elements,
208                    args: CubeOption::new_None(),
209                }
210            }
211            ReduceFn::Min(min) => {
212                let elements = <Min as ReduceInstruction<P>>::null_accumulator(min, line_size);
213
214                DynamicAccumulatorItem::<P::EA> {
215                    elements,
216                    args: CubeOption::new_None(),
217                }
218            }
219        }
220    }
221
222    fn assign_accumulator(
223        _this: &Self,
224        destination: &mut Self::AccumulatorItem,
225        source: &Self::AccumulatorItem,
226    ) {
227        destination.elements = source.elements;
228        let args = &mut destination.args;
229        match args {
230            CubeOption::Some(val) => *val = source.args.unwrap(),
231            CubeOption::None => {}
232        }
233    }
234
235    fn reduce(
236        this: &Self,
237        accumulator: &Self::AccumulatorItem,
238        item: Line<P::EI>,
239        coordinate: ReduceCoordinate,
240        #[comptime] use_planes: bool,
241    ) -> Self::AccumulatorItem {
242        match this {
243            ReduceFn::Sum(sum) => {
244                let elements = <Sum as ReduceInstruction<P>>::reduce(
245                    sum,
246                    &accumulator.elements,
247                    item,
248                    coordinate,
249                    use_planes,
250                );
251                DynamicAccumulatorItem::<P::EA> {
252                    elements,
253                    args: CubeOption::new_None(),
254                }
255            }
256            ReduceFn::Prod(sum) => {
257                let elements = <Prod as ReduceInstruction<P>>::reduce(
258                    sum,
259                    &accumulator.elements,
260                    item,
261                    coordinate,
262                    use_planes,
263                );
264                DynamicAccumulatorItem::<P::EA> {
265                    elements,
266                    args: CubeOption::new_None(),
267                }
268            }
269            ReduceFn::Mean(sum) => {
270                let elements = <Mean as ReduceInstruction<P>>::reduce(
271                    sum,
272                    &accumulator.elements,
273                    item,
274                    coordinate,
275                    use_planes,
276                );
277                DynamicAccumulatorItem::<P::EA> {
278                    elements,
279                    args: CubeOption::new_None(),
280                }
281            }
282            ReduceFn::MaxAbs(maxabs) => {
283                let elements = <MaxAbs as ReduceInstruction<P>>::reduce(
284                    maxabs,
285                    &accumulator.elements,
286                    item,
287                    coordinate,
288                    use_planes,
289                );
290                DynamicAccumulatorItem::<P::EA> {
291                    elements,
292                    args: CubeOption::new_None(),
293                }
294            }
295            ReduceFn::ArgMax(argmax) => {
296                let (elements, args) = <ArgMax as ReduceInstruction<P>>::reduce(
297                    argmax,
298                    &(accumulator.elements, accumulator.args.unwrap()),
299                    item,
300                    coordinate,
301                    use_planes,
302                );
303
304                DynamicAccumulatorItem::<P::EA> {
305                    elements,
306                    args: CubeOption::new_Some(args),
307                }
308            }
309            ReduceFn::ArgMin(argmin) => {
310                let (elements, args) = <ArgMin as ReduceInstruction<P>>::reduce(
311                    argmin,
312                    &(accumulator.elements, accumulator.args.unwrap()),
313                    item,
314                    coordinate,
315                    use_planes,
316                );
317
318                DynamicAccumulatorItem::<P::EA> {
319                    elements,
320                    args: CubeOption::new_Some(args),
321                }
322            }
323            ReduceFn::Max(max) => {
324                let elements = <Max as ReduceInstruction<P>>::reduce(
325                    max,
326                    &accumulator.elements,
327                    item,
328                    coordinate,
329                    use_planes,
330                );
331                DynamicAccumulatorItem::<P::EA> {
332                    elements,
333                    args: CubeOption::new_None(),
334                }
335            }
336            ReduceFn::Min(min) => {
337                let elements = <Min as ReduceInstruction<P>>::reduce(
338                    min,
339                    &accumulator.elements,
340                    item,
341                    coordinate,
342                    use_planes,
343                );
344                DynamicAccumulatorItem::<P::EA> {
345                    elements,
346                    args: CubeOption::new_None(),
347                }
348            }
349        }
350    }
351
352    fn fuse_accumulators(
353        this: &Self,
354        lhs: Self::AccumulatorItem,
355        rhs: Self::AccumulatorItem,
356    ) -> Self::AccumulatorItem {
357        match this {
358            ReduceFn::Sum(sum) => {
359                let elements = <Sum as ReduceInstruction<P>>::fuse_accumulators(
360                    sum,
361                    lhs.elements,
362                    rhs.elements,
363                );
364                DynamicAccumulatorItem::<P::EA> {
365                    elements,
366                    args: CubeOption::new_None(),
367                }
368            }
369            ReduceFn::Prod(prod) => {
370                let elements = <Prod as ReduceInstruction<P>>::fuse_accumulators(
371                    prod,
372                    lhs.elements,
373                    rhs.elements,
374                );
375                DynamicAccumulatorItem::<P::EA> {
376                    elements,
377                    args: CubeOption::new_None(),
378                }
379            }
380            ReduceFn::Mean(mean) => {
381                let elements = <Mean as ReduceInstruction<P>>::fuse_accumulators(
382                    mean,
383                    lhs.elements,
384                    rhs.elements,
385                );
386                DynamicAccumulatorItem::<P::EA> {
387                    elements,
388                    args: CubeOption::new_None(),
389                }
390            }
391            ReduceFn::MaxAbs(maxabs) => {
392                let elements = <MaxAbs as ReduceInstruction<P>>::fuse_accumulators(
393                    maxabs,
394                    lhs.elements,
395                    rhs.elements,
396                );
397                DynamicAccumulatorItem::<P::EA> {
398                    elements,
399                    args: CubeOption::new_None(),
400                }
401            }
402            ReduceFn::ArgMax(argmax) => {
403                let (elements, args) = <ArgMax as ReduceInstruction<P>>::fuse_accumulators(
404                    argmax,
405                    (lhs.elements, lhs.args.unwrap()),
406                    (rhs.elements, rhs.args.unwrap()),
407                );
408                DynamicAccumulatorItem::<P::EA> {
409                    elements,
410                    args: CubeOption::new_Some(args),
411                }
412            }
413            ReduceFn::ArgMin(argmin) => {
414                let (elements, args) = <ArgMin as ReduceInstruction<P>>::fuse_accumulators(
415                    argmin,
416                    (lhs.elements, lhs.args.unwrap()),
417                    (rhs.elements, rhs.args.unwrap()),
418                );
419                DynamicAccumulatorItem::<P::EA> {
420                    elements,
421                    args: CubeOption::new_Some(args),
422                }
423            }
424            ReduceFn::Max(max) => {
425                let elements = <Max as ReduceInstruction<P>>::fuse_accumulators(
426                    max,
427                    lhs.elements,
428                    rhs.elements,
429                );
430                DynamicAccumulatorItem::<P::EA> {
431                    elements,
432                    args: CubeOption::new_None(),
433                }
434            }
435            ReduceFn::Min(min) => {
436                let elements = <Min as ReduceInstruction<P>>::fuse_accumulators(
437                    min,
438                    lhs.elements,
439                    rhs.elements,
440                );
441                DynamicAccumulatorItem::<P::EA> {
442                    elements,
443                    args: CubeOption::new_None(),
444                }
445            }
446        }
447    }
448
449    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
450    //      Then, an instruction like Dynamic can be implemented by fusing a Sum reduction and a element-wise division.
451    fn merge_line<Out: Numeric>(
452        this: &Self,
453        accumulator: Self::AccumulatorItem,
454        shape_axis_reduce: u32,
455    ) -> Out {
456        match this {
457            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::merge_line::<Out>(
458                sum,
459                accumulator.elements,
460                shape_axis_reduce,
461            ),
462            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::merge_line::<Out>(
463                prod,
464                accumulator.elements,
465                shape_axis_reduce,
466            ),
467            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::merge_line::<Out>(
468                mean,
469                accumulator.elements,
470                shape_axis_reduce,
471            ),
472            ReduceFn::MaxAbs(maxabs) => <MaxAbs as ReduceInstruction<P>>::merge_line::<Out>(
473                maxabs,
474                accumulator.elements,
475                shape_axis_reduce,
476            ),
477            ReduceFn::ArgMax(argmax) => <ArgMax as ReduceInstruction<P>>::merge_line::<Out>(
478                argmax,
479                (accumulator.elements, accumulator.args.unwrap()),
480                shape_axis_reduce,
481            ),
482            ReduceFn::ArgMin(argmin) => <ArgMin as ReduceInstruction<P>>::merge_line::<Out>(
483                argmin,
484                (accumulator.elements, accumulator.args.unwrap()),
485                shape_axis_reduce,
486            ),
487            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::merge_line::<Out>(
488                max,
489                accumulator.elements,
490                shape_axis_reduce,
491            ),
492            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::merge_line::<Out>(
493                min,
494                accumulator.elements,
495                shape_axis_reduce,
496            ),
497        }
498    }
499
500    fn to_output_perpendicular<Out: Numeric>(
501        this: &Self,
502        accumulator: Self::AccumulatorItem,
503        shape_axis_reduce: u32,
504    ) -> Line<Out> {
505        match this {
506            ReduceFn::Sum(sum) => <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
507                sum,
508                accumulator.elements,
509                shape_axis_reduce,
510            ),
511            ReduceFn::Prod(prod) => <Prod as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
512                prod,
513                accumulator.elements,
514                shape_axis_reduce,
515            ),
516            ReduceFn::Mean(mean) => <Mean as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
517                mean,
518                accumulator.elements,
519                shape_axis_reduce,
520            ),
521            ReduceFn::MaxAbs(maxabs) => {
522                <MaxAbs as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
523                    maxabs,
524                    accumulator.elements,
525                    shape_axis_reduce,
526                )
527            }
528            ReduceFn::ArgMax(args) => {
529                <ArgMax as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
530                    args,
531                    (accumulator.elements, accumulator.args.unwrap()),
532                    shape_axis_reduce,
533                )
534            }
535            ReduceFn::ArgMin(args) => {
536                <ArgMin as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
537                    args,
538                    (accumulator.elements, accumulator.args.unwrap()),
539                    shape_axis_reduce,
540                )
541            }
542            ReduceFn::Max(max) => <Max as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
543                max,
544                accumulator.elements,
545                shape_axis_reduce,
546            ),
547            ReduceFn::Min(min) => <Min as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
548                min,
549                accumulator.elements,
550                shape_axis_reduce,
551            ),
552        }
553    }
554}