cubecl_reduce/instructions/
argmin.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};
5
/// Compute the coordinate of the minimum item, returning the smallest coordinate in case of
/// equality.
#[derive(Debug)]
pub struct ArgMin;
9
10impl Reduce for ArgMin {
11    type Instruction<In: Numeric> = Self;
12}
13
14#[cube]
15impl ArgMin {
16    /// Compare two pairs of items and coordinates and return a new pair
17    /// where each element in the lines is the minimal item with its coordinate.
18    /// In case of equality, the lowest coordinate is selected.
19    pub fn choose_argmin<N: Numeric>(
20        items0: Line<N>,
21        coordinates0: Line<u32>,
22        items1: Line<N>,
23        coordinates1: Line<u32>,
24    ) -> (Line<N>, Line<u32>) {
25        let to_keep = select_many(
26            items0.equal(items1),
27            coordinates0.less_than(coordinates1),
28            items0.less_than(items1),
29        );
30        let items = select_many(to_keep, items0, items1);
31        let coordinates = select_many(to_keep, coordinates0, coordinates1);
32        (items, coordinates)
33    }
34}
35
36#[cube]
37impl<In: Numeric> ReduceInstruction<In> for ArgMin {
38    type AccumulatorItem = (Line<In>, Line<u32>);
39    type SharedAccumulator = ArgAccumulator<In>;
40
41    fn null_input(#[comptime] line_size: u32) -> Line<In> {
42        Line::empty(line_size).fill(In::max_value())
43    }
44
45    fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
46        (
47            Self::null_input(line_size),
48            Line::empty(line_size).fill(u32::MAX),
49        )
50    }
51
52    fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
53        destination.0 = source.0;
54        destination.1 = source.1;
55    }
56
57    fn reduce(
58        accumulator: &Self::AccumulatorItem,
59        item: Line<In>,
60        coordinate: Line<u32>,
61        #[comptime] use_planes: bool,
62    ) -> Self::AccumulatorItem {
63        let (candidate_item, candidate_coordinate) = if use_planes {
64            let candidate_item = plane_min(item);
65            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
66            (candidate_item, candidate_coordinate)
67        } else {
68            (item, coordinate)
69        };
70        Self::choose_argmin(
71            accumulator.0,
72            accumulator.1,
73            candidate_item,
74            candidate_coordinate,
75        )
76    }
77
78    fn fuse_accumulators(
79        lhs: Self::AccumulatorItem,
80        rhs: Self::AccumulatorItem,
81    ) -> Self::AccumulatorItem {
82        Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
83    }
84
85    fn merge_line<Out: Numeric>(
86        accumulator: Self::AccumulatorItem,
87        _shape_axis_reduce: u32,
88    ) -> Out {
89        let line_size = accumulator.0.size();
90        if comptime!(line_size > 1) {
91            let mut min = In::max_value();
92            let mut coordinate = u32::MAX.runtime();
93
94            #[unroll]
95            for k in 0..line_size {
96                let acc_element = accumulator.0[k];
97                let acc_coordinate = accumulator.1[k];
98                // TODO replace with select
99                if acc_element == min && acc_coordinate < coordinate {
100                    coordinate = acc_coordinate;
101                } else if acc_element < min {
102                    min = acc_element;
103                    coordinate = acc_coordinate;
104                }
105            }
106            Out::cast_from(coordinate)
107        } else {
108            Out::cast_from(accumulator.1)
109        }
110    }
111
112    fn to_output_perpendicular<Out: Numeric>(
113        accumulator: Self::AccumulatorItem,
114        _shape_axis_reduce: u32,
115    ) -> Line<Out> {
116        Line::cast_from(accumulator.1)
117    }
118}