cubecl_reduce/instructions/
argmin.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{
5    ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
6    ReduceRequirements, lowest_coordinate_matching,
7};
8
9/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
10#[derive(Debug, CubeType, Clone)]
11pub struct ArgMin {}
12
13impl ReduceFamily for ArgMin {
14    type Instruction<In: Numeric> = Self;
15    type Config = ();
16}
17
18#[cube]
19impl ArgMin {
20    /// Compare two pairs of items and coordinates and return a new pair
21    /// where each element in the lines is the minimal item with its coordinate.
22    /// In case of equality, the lowest coordinate is selected.
23    pub fn choose_argmin<N: Numeric>(
24        items0: Line<N>,
25        coordinates0: Line<u32>,
26        items1: Line<N>,
27        coordinates1: Line<u32>,
28    ) -> (Line<N>, Line<u32>) {
29        let to_keep = select_many(
30            items0.equal(items1),
31            coordinates0.less_than(coordinates1),
32            items0.less_than(items1),
33        );
34        let items = select_many(to_keep, items0, items1);
35        let coordinates = select_many(to_keep, coordinates0, coordinates1);
36        (items, coordinates)
37    }
38}
39
40#[cube]
41impl<In: Numeric> ReduceInstruction<In> for ArgMin {
42    type AccumulatorItem = (Line<In>, Line<u32>);
43    type SharedAccumulator = ArgAccumulator<In>;
44    type Config = ();
45
46    fn requirements(_this: &Self) -> ReduceRequirements {
47        ReduceRequirements { coordinates: true }
48    }
49    fn from_config(_config: Self::Config) -> Self {
50        ArgMin {}
51    }
52
53    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
54        Line::empty(line_size).fill(In::max_value())
55    }
56
57    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
58        (
59            Self::null_input(this, line_size),
60            Line::empty(line_size).fill(u32::MAX),
61        )
62    }
63
64    fn assign_accumulator(
65        _this: &Self,
66        destination: &mut Self::AccumulatorItem,
67        source: &Self::AccumulatorItem,
68    ) {
69        destination.0 = source.0;
70        destination.1 = source.1;
71    }
72
73    fn reduce(
74        _this: &Self,
75        accumulator: &Self::AccumulatorItem,
76        item: Line<In>,
77        coordinate: ReduceCoordinate,
78        #[comptime] use_planes: bool,
79    ) -> Self::AccumulatorItem {
80        let coordinate = match coordinate {
81            ReduceCoordinate::Required(val) => val,
82            ReduceCoordinate::NotRequired => {
83                comptime! {panic!("Coordinates are required for ArgMin")};
84                #[allow(unreachable_code)]
85                Line::new(0)
86            }
87        };
88
89        let (candidate_item, candidate_coordinate) = if use_planes {
90            let candidate_item = plane_min(item);
91            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
92            (candidate_item, candidate_coordinate)
93        } else {
94            (item, coordinate)
95        };
96
97        Self::choose_argmin(
98            accumulator.0,
99            accumulator.1,
100            candidate_item,
101            candidate_coordinate,
102        )
103    }
104
105    fn fuse_accumulators(
106        _this: &Self,
107        lhs: Self::AccumulatorItem,
108        rhs: Self::AccumulatorItem,
109    ) -> Self::AccumulatorItem {
110        Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
111    }
112
113    fn merge_line<Out: Numeric>(
114        _this: &Self,
115        accumulator: Self::AccumulatorItem,
116        _shape_axis_reduce: u32,
117    ) -> Out {
118        let line_size = accumulator.0.size();
119        if comptime!(line_size > 1) {
120            let mut min = In::max_value();
121            let mut coordinate = u32::MAX.runtime();
122
123            #[unroll]
124            for k in 0..line_size {
125                let acc_element = accumulator.0[k];
126                let acc_coordinate = accumulator.1[k];
127                // TODO replace with select
128                if acc_element == min && acc_coordinate < coordinate {
129                    coordinate = acc_coordinate;
130                } else if acc_element < min {
131                    min = acc_element;
132                    coordinate = acc_coordinate;
133                }
134            }
135            Out::cast_from(coordinate)
136        } else {
137            Out::cast_from(accumulator.1)
138        }
139    }
140
141    fn to_output_perpendicular<Out: Numeric>(
142        _this: &Self,
143        accumulator: Self::AccumulatorItem,
144        _shape_axis_reduce: u32,
145    ) -> Line<Out> {
146        Line::cast_from(accumulator.1)
147    }
148}