cubecl_reduce/instructions/argmin.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6use super::{
7    ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
8    ReduceRequirements, lowest_coordinate_matching,
9};
10
11/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
12#[derive(Debug, CubeType, Clone)]
13pub struct ArgMin {}
14
15impl ReduceFamily for ArgMin {
16    type Instruction<P: ReducePrecision> = Self;
17    type Config = ();
18}
19
20#[cube]
21impl ArgMin {
22    /// Compare two pairs of items and coordinates and return a new pair
23    /// where each element in the lines is the minimal item with its coordinate.
24    /// In case of equality, the lowest coordinate is selected.
25    pub fn choose_argmin<N: Numeric>(
26        items0: Line<N>,
27        coordinates0: Line<u32>,
28        items1: Line<N>,
29        coordinates1: Line<u32>,
30    ) -> (Line<N>, Line<u32>) {
31        let to_keep = select_many(
32            items0.equal(items1),
33            coordinates0.less_than(coordinates1),
34            items0.less_than(items1),
35        );
36        let items = select_many(to_keep, items0, items1);
37        let coordinates = select_many(to_keep, coordinates0, coordinates1);
38        (items, coordinates)
39    }
40}
41
// ArgMin as a reduce instruction: for each reduced slice it yields the
// coordinate of the smallest item, resolving ties towards the lowest coordinate.
#[cube]
impl<P: ReducePrecision> ReduceInstruction<P> for ArgMin {
    // Running state: the current minimum items and, lane by lane, the
    // coordinate at which each of them was found.
    type AccumulatorItem = (Line<P::EA>, Line<u32>);
    type SharedAccumulator = ArgAccumulator<P::EA>;
    // ArgMin has no tunable parameters.
    type Config = ();

    // ArgMin needs the coordinate of every item it sees.
    fn requirements(_this: &Self) -> ReduceRequirements {
        ReduceRequirements { coordinates: true }
    }
    // The unit config carries no information; just build the instruction.
    fn from_config(_config: Self::Config) -> Self {
        ArgMin {}
    }

    // Neutral input line: every lane holds the largest representable `P::EI`,
    // so it cannot be strictly smaller than any other item.
    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
        Line::empty(line_size).fill(P::EI::max_value())
    }

    // Initial accumulator: the largest representable `P::EA` in every lane,
    // paired with coordinate `u32::MAX`. That coordinate loses every tie in
    // `choose_argmin`, because ties are resolved towards the lower coordinate.
    fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
        (
            Line::empty(line_size).fill(P::EA::max_value()),
            Line::empty(line_size).fill(u32::MAX),
        )
    }

    // Copies both halves (items and coordinates) of `source` into `destination`.
    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        destination.0 = source.0;
        destination.1 = source.1;
    }

    // Folds one input line and its coordinates into the accumulator and
    // returns the updated accumulator.
    fn reduce(
        _this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<P::EI>,
        coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        // Coordinates are mandatory for ArgMin (see `requirements`).
        let coordinate = match coordinate {
            ReduceCoordinate::Required(val) => val,
            ReduceCoordinate::NotRequired => {
                // Rejected at comptime via `comptime!`. The value below only
                // exists to give this match arm the expected type.
                comptime! {panic!("Coordinates are required for ArgMin")};
                #[allow(unreachable_code)]
                Line::new(0)
            }
        };

        // With planes, first reduce across the plane: take the plane-wide
        // minimum, then pick a matching coordinate via `lowest_coordinate_matching`.
        // Presumably this is the lowest coordinate whose item equals that
        // minimum; confirm against the helper.
        let (candidate_item, candidate_coordinate) = if use_planes {
            let candidate_item = plane_min(item);
            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
            (candidate_item, candidate_coordinate)
        } else {
            (item, coordinate)
        };

        // Merge the candidate into the accumulator, casting items to `P::EA` first.
        Self::choose_argmin(
            accumulator.0,
            accumulator.1,
            Line::cast_from(candidate_item),
            candidate_coordinate,
        )
    }

    // Combines two partial accumulators lane by lane with the same
    // minimum / lowest-coordinate rule as `reduce`.
    fn fuse_accumulators(
        _this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
    }

    // Collapses the per-lane accumulator into a single output value: the
    // coordinate of the smallest item across all lanes (lowest coordinate on
    // ties), cast to `Out`. `_shape_axis_reduce` is unused.
    fn merge_line<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Out {
        let line_size = accumulator.0.size();
        // Branch chosen at comptime: only lines with several lanes need a scan.
        if comptime!(line_size > 1) {
            // Start from the same neutral pair as `null_accumulator`.
            // `coordinate` is reassigned inside the loop; presumably that is why
            // it is made an explicit runtime variable.
            let mut min = P::EA::max_value();
            let mut coordinate = u32::MAX.runtime();

            // Scan every lane. An equal item only updates the coordinate when
            // it is lower; a strictly smaller item replaces both.
            #[unroll]
            for k in 0..line_size {
                let acc_element = accumulator.0[k];
                let acc_coordinate = accumulator.1[k];
                // TODO replace with select
                if acc_element == min && acc_coordinate < coordinate {
                    coordinate = acc_coordinate;
                } else if acc_element < min {
                    min = acc_element;
                    coordinate = acc_coordinate;
                }
            }
            Out::cast_from(coordinate)
        } else {
            // Single lane: the stored coordinate is already the answer.
            Out::cast_from(accumulator.1)
        }
    }

    // Returns the whole coordinate line cast to `Out`, without merging lanes.
    // NOTE(review): presumably each lane is a separate output in the
    // perpendicular layout; confirm against the caller.
    fn to_output_perpendicular<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Line<Out> {
        Line::cast_from(accumulator.1)
    }
}