cubecl_reduce/instructions/
argmax.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{
7    ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
8    lowest_coordinate_matching,
9};
10
11/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
12#[derive(Debug, CubeType, Clone)]
13pub struct ArgMax {}
14
15#[cube]
16impl ArgMax {
17    /// Compare two pairs of items and coordinates and return a new pair
18    /// where each element in the lines is the maximal item with its coordinate.
19    /// In case of equality, the lowest coordinate is selected.
20    pub fn choose_argmax<N: Numeric>(
21        items0: Line<N>,
22        coordinates0: Line<u32>,
23        items1: Line<N>,
24        coordinates1: Line<u32>,
25    ) -> (Line<N>, Line<u32>) {
26        let to_keep = select_many(
27            items0.equal(items1),
28            coordinates0.less_than(coordinates1),
29            items0.greater_than(items1),
30        );
31        let items = select_many(to_keep, items0, items1);
32        let coordinates = select_many(to_keep, coordinates0, coordinates1);
33        (items, coordinates)
34    }
35}
36
37impl ReduceFamily for ArgMax {
38    type Instruction<P: ReducePrecision> = Self;
39    type Config = ();
40}
41
42#[cube]
43impl<P: ReducePrecision> ReduceInstruction<P> for ArgMax {
44    type AccumulatorItem = (Line<P::EA>, Line<u32>);
45    type SharedAccumulator = ArgAccumulator<P::EA>;
46    type Config = ();
47
48    fn requirements(_this: &Self) -> ReduceRequirements {
49        ReduceRequirements { coordinates: true }
50    }
51
52    fn from_config(_config: Self::Config) -> Self {
53        ArgMax {}
54    }
55
56    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
57        Line::empty(line_size).fill(P::EI::min_value())
58    }
59
60    fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
61        (
62            Line::empty(line_size).fill(P::EA::min_value()),
63            Line::empty(line_size).fill(u32::MAX),
64        )
65    }
66
67    fn assign_accumulator(
68        _this: &Self,
69        destination: &mut Self::AccumulatorItem,
70        source: &Self::AccumulatorItem,
71    ) {
72        destination.0 = source.0;
73        destination.1 = source.1;
74    }
75
76    fn reduce(
77        _this: &Self,
78        accumulator: &Self::AccumulatorItem,
79        item: Line<P::EI>,
80        coordinate: ReduceCoordinate,
81        #[comptime] use_planes: bool,
82    ) -> Self::AccumulatorItem {
83        let coordinate = match coordinate {
84            ReduceCoordinate::Required(val) => val,
85            ReduceCoordinate::NotRequired => {
86                comptime! {panic!("Coordinates are required for ArgMin")};
87                #[allow(unreachable_code)]
88                Line::new(0)
89            }
90        };
91
92        let (candidate_item, candidate_coordinate) = if use_planes {
93            let candidate_item = plane_max(item);
94            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
95            (candidate_item, candidate_coordinate)
96        } else {
97            (item, coordinate)
98        };
99
100        Self::choose_argmax(
101            Line::cast_from(candidate_item),
102            candidate_coordinate,
103            accumulator.0,
104            accumulator.1,
105        )
106    }
107
108    fn fuse_accumulators(
109        _this: &Self,
110        lhs: Self::AccumulatorItem,
111        rhs: Self::AccumulatorItem,
112    ) -> Self::AccumulatorItem {
113        Self::choose_argmax(lhs.0, lhs.1, rhs.0, rhs.1)
114    }
115
116    fn merge_line<Out: Numeric>(
117        _this: &Self,
118        accumulator: Self::AccumulatorItem,
119        _shape_axis_reduce: u32,
120    ) -> Out {
121        let line_size = accumulator.0.size();
122        if comptime!(line_size > 1) {
123            let mut max = P::EA::min_value();
124            let mut coordinate = u32::MAX.runtime();
125            #[unroll]
126            for k in 0..line_size {
127                let acc_element = accumulator.0[k];
128                let acc_coordinate = accumulator.1[k];
129                if acc_element == max && acc_coordinate < coordinate {
130                    coordinate = acc_coordinate;
131                } else if acc_element > max {
132                    max = acc_element;
133                    coordinate = acc_coordinate;
134                }
135            }
136            Out::cast_from(coordinate)
137        } else {
138            Out::cast_from(accumulator.1)
139        }
140    }
141
142    fn to_output_perpendicular<Out: Numeric>(
143        _this: &Self,
144        accumulator: Self::AccumulatorItem,
145        _shape_axis_reduce: u32,
146    ) -> Line<Out> {
147        Line::cast_from(accumulator.1)
148    }
149}