cubecl_reduce/instructions/
argmax.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};
5
6/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
7#[derive(Debug)]
8pub struct ArgMax;
9
10#[cube]
11impl ArgMax {
12    /// Compare two pairs of items and coordinates and return a new pair
13    /// where each element in the lines is the maximal item with its coordinate.
14    /// In case of equality, the lowest coordinate is selected.
15    pub fn choose_argmax<N: Numeric>(
16        items0: Line<N>,
17        coordinates0: Line<u32>,
18        items1: Line<N>,
19        coordinates1: Line<u32>,
20    ) -> (Line<N>, Line<u32>) {
21        let to_keep = select_many(
22            items0.equal(items1),
23            coordinates0.less_than(coordinates1),
24            items0.greater_than(items1),
25        );
26        let items = select_many(to_keep, items0, items1);
27        let coordinates = select_many(to_keep, coordinates0, coordinates1);
28        (items, coordinates)
29    }
30}
31
32impl Reduce for ArgMax {
33    type Instruction<In: Numeric> = Self;
34}
35
36#[cube]
37impl<In: Numeric> ReduceInstruction<In> for ArgMax {
38    type AccumulatorItem = (Line<In>, Line<u32>);
39    type SharedAccumulator = ArgAccumulator<In>;
40
41    fn null_input(#[comptime] line_size: u32) -> Line<In> {
42        Line::empty(line_size).fill(In::min_value())
43    }
44
45    fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
46        (
47            Self::null_input(line_size),
48            Line::empty(line_size).fill(u32::MAX),
49        )
50    }
51
52    fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
53        destination.0 = source.0;
54        destination.1 = source.1;
55    }
56
57    fn reduce(
58        accumulator: &Self::AccumulatorItem,
59        item: Line<In>,
60        coordinate: Line<u32>,
61        #[comptime] use_planes: bool,
62    ) -> Self::AccumulatorItem {
63        let (candidate_item, candidate_coordinate) = if use_planes {
64            let candidate_item = plane_max(item);
65            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
66            (candidate_item, candidate_coordinate)
67        } else {
68            (item, coordinate)
69        };
70        Self::choose_argmax(
71            candidate_item,
72            candidate_coordinate,
73            accumulator.0,
74            accumulator.1,
75        )
76    }
77
78    fn fuse_accumulators(
79        lhs: Self::AccumulatorItem,
80        rhs: Self::AccumulatorItem,
81    ) -> Self::AccumulatorItem {
82        Self::choose_argmax(lhs.0, lhs.1, rhs.0, rhs.1)
83    }
84
85    fn merge_line<Out: Numeric>(
86        accumulator: Self::AccumulatorItem,
87        _shape_axis_reduce: u32,
88    ) -> Out {
89        let line_size = accumulator.0.size();
90        if comptime!(line_size > 1) {
91            let mut max = In::min_value();
92            let mut coordinate = u32::MAX.runtime();
93            #[unroll]
94            for k in 0..line_size {
95                let acc_element = accumulator.0[k];
96                let acc_coordinate = accumulator.1[k];
97                if acc_element == max && acc_coordinate < coordinate {
98                    coordinate = acc_coordinate;
99                } else if acc_element > max {
100                    max = acc_element;
101                    coordinate = acc_coordinate;
102                }
103            }
104            Out::cast_from(coordinate)
105        } else {
106            Out::cast_from(accumulator.1)
107        }
108    }
109
110    fn to_output_perpendicular<Out: Numeric>(
111        accumulator: Self::AccumulatorItem,
112        _shape_axis_reduce: u32,
113    ) -> Line<Out> {
114        Line::cast_from(accumulator.1)
115    }
116}