cubecl_reduce/instructions/
maxabs.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8// TODO Add to test framework.
9/// Return the item with the maximum absolute value.
10#[derive(Debug, CubeType, Clone)]
11pub struct MaxAbs;
12
13impl ReduceFamily for MaxAbs {
14    type Instruction<P: ReducePrecision> = Self;
15    type Config = ();
16}
17
18#[cube]
19impl<P: ReducePrecision> ReduceInstruction<P> for MaxAbs {
20    type AccumulatorItem = Line<P::EA>;
21    type SharedAccumulator = SharedMemory<Line<P::EA>>;
22    type Config = ();
23
24    fn requirements(_this: &Self) -> ReduceRequirements {
25        ReduceRequirements { coordinates: false }
26    }
27
28    fn from_config(_config: Self::Config) -> Self {
29        MaxAbs {}
30    }
31
32    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
33        Line::empty(line_size).fill(P::EI::from_int(0))
34    }
35
36    fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
37        Line::empty(line_size).fill(P::EA::from_int(0))
38    }
39
40    fn assign_accumulator(
41        _this: &Self,
42        destination: &mut Self::AccumulatorItem,
43        source: &Self::AccumulatorItem,
44    ) {
45        *destination = *source;
46    }
47
48    fn reduce(
49        _this: &Self,
50        accumulator: &Self::AccumulatorItem,
51        item: Line<P::EI>,
52        _coordinate: ReduceCoordinate,
53        #[comptime] use_planes: bool,
54    ) -> Self::AccumulatorItem {
55        if use_planes {
56            let candidate_item = Line::cast_from(plane_max(Line::abs(item)));
57            select_many(
58                accumulator.greater_than(candidate_item),
59                *accumulator,
60                candidate_item,
61            )
62        } else {
63            let item_abs = Line::cast_from(Line::abs(item));
64            select_many(accumulator.greater_than(item_abs), *accumulator, item_abs)
65        }
66    }
67
68    fn fuse_accumulators(
69        _this: &Self,
70        lhs: Self::AccumulatorItem,
71        rhs: Self::AccumulatorItem,
72    ) -> Self::AccumulatorItem {
73        select_many(lhs.greater_than(rhs), lhs, rhs)
74    }
75
76    fn merge_line<Out: Numeric>(
77        _this: &Self,
78        accumulator: Self::AccumulatorItem,
79        _shape_axis_reduce: u32,
80    ) -> Out {
81        let mut max = P::EA::from_int(0);
82        #[unroll]
83        for k in 0..accumulator.size() {
84            let candidate = accumulator[k];
85            max = select(candidate > max, candidate, max);
86        }
87        Out::cast_from(max)
88    }
89
90    fn to_output_perpendicular<Out: Numeric>(
91        _this: &Self,
92        accumulator: Self::AccumulatorItem,
93        _shape_axis_reduce: u32,
94    ) -> Line<Out> {
95        Line::cast_from(accumulator)
96    }
97}