cubecl_reduce/instructions/
argmin.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{
5 ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
6 ReduceRequirements, lowest_coordinate_matching,
7};
8
9#[derive(Debug, CubeType, Clone)]
11pub struct ArgMin {}
12
13impl ReduceFamily for ArgMin {
14 type Instruction<In: Numeric> = Self;
15 type Config = ();
16}
17
18#[cube]
19impl ArgMin {
20 pub fn choose_argmin<N: Numeric>(
24 items0: Line<N>,
25 coordinates0: Line<u32>,
26 items1: Line<N>,
27 coordinates1: Line<u32>,
28 ) -> (Line<N>, Line<u32>) {
29 let to_keep = select_many(
30 items0.equal(items1),
31 coordinates0.less_than(coordinates1),
32 items0.less_than(items1),
33 );
34 let items = select_many(to_keep, items0, items1);
35 let coordinates = select_many(to_keep, coordinates0, coordinates1);
36 (items, coordinates)
37 }
38}
39
#[cube]
impl<In: Numeric> ReduceInstruction<In> for ArgMin {
    /// Running `(items, coordinates)` pair: per lane, the smallest item seen so far and the
    /// coordinate where it was found.
    type AccumulatorItem = (Line<In>, Line<u32>);
    type SharedAccumulator = ArgAccumulator<In>;
    type Config = ();

    /// ArgMin always requests coordinates, because a coordinate is what it outputs.
    fn requirements(_this: &Self) -> ReduceRequirements {
        ReduceRequirements { coordinates: true }
    }

    /// ArgMin has no configuration; the unit config is ignored.
    fn from_config(_config: Self::Config) -> Self {
        ArgMin {}
    }

    /// Neutral input: a line filled with `In::max_value()`, so every real item compares less
    /// than or equal to it.
    ///
    /// NOTE(review): if `max_value()` is the largest *finite* value for float types, an input
    /// containing only `+inf` never replaces this neutral element. The reported coordinate then
    /// stays `u32::MAX`. Confirm whether that case matters.
    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
        Line::empty(line_size).fill(In::max_value())
    }

    /// Neutral accumulator: neutral items paired with `u32::MAX` coordinates.
    ///
    /// Ties are broken towards the smaller coordinate. A real item equal to `In::max_value()`
    /// with any coordinate below `u32::MAX` therefore still replaces this placeholder.
    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
        (
            Self::null_input(this, line_size),
            Line::empty(line_size).fill(u32::MAX),
        )
    }

    /// Copies both the items and the coordinates of `source` into `destination`.
    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        destination.0 = source.0;
        destination.1 = source.1;
    }

    /// Folds one input line and its coordinates into `accumulator`, returning the new
    /// accumulator. `accumulator` itself is not modified.
    ///
    /// When `use_planes` is set, the item is first reduced across the plane with `plane_min`.
    /// `lowest_coordinate_matching` then picks the coordinate for that minimum (presumably the
    /// lowest coordinate among lanes holding it; confirm against its definition).
    ///
    /// # Panics
    /// At expansion time (via `comptime!`) if `coordinate` is `ReduceCoordinate::NotRequired`.
    fn reduce(
        _this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<In>,
        coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        let coordinate = match coordinate {
            ReduceCoordinate::Required(val) => val,
            ReduceCoordinate::NotRequired => {
                comptime! {panic!("Coordinates are required for ArgMin")};
                // Never reached; the dummy value only gives both match arms the same type.
                #[allow(unreachable_code)]
                Line::new(0)
            }
        };

        let (candidate_item, candidate_coordinate) = if use_planes {
            let candidate_item = plane_min(item);
            let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
            (candidate_item, candidate_coordinate)
        } else {
            (item, coordinate)
        };

        Self::choose_argmin(
            accumulator.0,
            accumulator.1,
            candidate_item,
            candidate_coordinate,
        )
    }

    /// Combines two partial accumulators with the same tie-breaking rule as `reduce`.
    fn fuse_accumulators(
        _this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
    }

    /// Collapses the lanes of a line accumulator into a single coordinate of type `Out`.
    ///
    /// For line sizes above one (decided at comptime), the lanes are scanned in order. The scan
    /// keeps the smallest item and, on ties, the smallest coordinate. A single-lane accumulator
    /// has its coordinate cast directly. `_shape_axis_reduce` is unused.
    fn merge_line<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Out {
        let line_size = accumulator.0.size();
        if comptime!(line_size > 1) {
            // Start from the same neutral (item, coordinate) pair as `null_accumulator`.
            let mut min = In::max_value();
            let mut coordinate = u32::MAX.runtime();

            #[unroll]
            for k in 0..line_size {
                let acc_element = accumulator.0[k];
                let acc_coordinate = accumulator.1[k];
                // Same tie-breaking as `choose_argmin`: equal items keep the lower coordinate.
                if acc_element == min && acc_coordinate < coordinate {
                    coordinate = acc_coordinate;
                } else if acc_element < min {
                    min = acc_element;
                    coordinate = acc_coordinate;
                }
            }
            Out::cast_from(coordinate)
        } else {
            Out::cast_from(accumulator.1)
        }
    }

    /// Returns the accumulated coordinates cast lane-by-lane to `Out`; the items are dropped.
    fn to_output_perpendicular<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Line<Out> {
        Line::cast_from(accumulator.1)
    }
}