cubecl_reduce/instructions/
argmin.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6use super::{
7 ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
8 ReduceRequirements, lowest_coordinate_matching,
9};
10
11#[derive(Debug, CubeType, Clone)]
13pub struct ArgMin {}
14
15impl ReduceFamily for ArgMin {
16 type Instruction<P: ReducePrecision> = Self;
17 type Config = ();
18}
19
20#[cube]
21impl ArgMin {
22 pub fn choose_argmin<N: Numeric>(
26 items0: Line<N>,
27 coordinates0: Line<u32>,
28 items1: Line<N>,
29 coordinates1: Line<u32>,
30 ) -> (Line<N>, Line<u32>) {
31 let to_keep = select_many(
32 items0.equal(items1),
33 coordinates0.less_than(coordinates1),
34 items0.less_than(items1),
35 );
36 let items = select_many(to_keep, items0, items1);
37 let coordinates = select_many(to_keep, coordinates0, coordinates1);
38 (items, coordinates)
39 }
40}
41
42#[cube]
43impl<P: ReducePrecision> ReduceInstruction<P> for ArgMin {
44 type AccumulatorItem = (Line<P::EA>, Line<u32>);
45 type SharedAccumulator = ArgAccumulator<P::EA>;
46 type Config = ();
47
48 fn requirements(_this: &Self) -> ReduceRequirements {
49 ReduceRequirements { coordinates: true }
50 }
51 fn from_config(_config: Self::Config) -> Self {
52 ArgMin {}
53 }
54
55 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
56 Line::empty(line_size).fill(P::EI::max_value())
57 }
58
59 fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
60 (
61 Line::empty(line_size).fill(P::EA::max_value()),
62 Line::empty(line_size).fill(u32::MAX),
63 )
64 }
65
66 fn assign_accumulator(
67 _this: &Self,
68 destination: &mut Self::AccumulatorItem,
69 source: &Self::AccumulatorItem,
70 ) {
71 destination.0 = source.0;
72 destination.1 = source.1;
73 }
74
75 fn reduce(
76 _this: &Self,
77 accumulator: &Self::AccumulatorItem,
78 item: Line<P::EI>,
79 coordinate: ReduceCoordinate,
80 #[comptime] use_planes: bool,
81 ) -> Self::AccumulatorItem {
82 let coordinate = match coordinate {
83 ReduceCoordinate::Required(val) => val,
84 ReduceCoordinate::NotRequired => {
85 comptime! {panic!("Coordinates are required for ArgMin")};
86 #[allow(unreachable_code)]
87 Line::new(0)
88 }
89 };
90
91 let (candidate_item, candidate_coordinate) = if use_planes {
92 let candidate_item = plane_min(item);
93 let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
94 (candidate_item, candidate_coordinate)
95 } else {
96 (item, coordinate)
97 };
98
99 Self::choose_argmin(
100 accumulator.0,
101 accumulator.1,
102 Line::cast_from(candidate_item),
103 candidate_coordinate,
104 )
105 }
106
107 fn fuse_accumulators(
108 _this: &Self,
109 lhs: Self::AccumulatorItem,
110 rhs: Self::AccumulatorItem,
111 ) -> Self::AccumulatorItem {
112 Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
113 }
114
115 fn merge_line<Out: Numeric>(
116 _this: &Self,
117 accumulator: Self::AccumulatorItem,
118 _shape_axis_reduce: u32,
119 ) -> Out {
120 let line_size = accumulator.0.size();
121 if comptime!(line_size > 1) {
122 let mut min = P::EA::max_value();
123 let mut coordinate = u32::MAX.runtime();
124
125 #[unroll]
126 for k in 0..line_size {
127 let acc_element = accumulator.0[k];
128 let acc_coordinate = accumulator.1[k];
129 if acc_element == min && acc_coordinate < coordinate {
131 coordinate = acc_coordinate;
132 } else if acc_element < min {
133 min = acc_element;
134 coordinate = acc_coordinate;
135 }
136 }
137 Out::cast_from(coordinate)
138 } else {
139 Out::cast_from(accumulator.1)
140 }
141 }
142
143 fn to_output_perpendicular<Out: Numeric>(
144 _this: &Self,
145 accumulator: Self::AccumulatorItem,
146 _shape_axis_reduce: u32,
147 ) -> Line<Out> {
148 Line::cast_from(accumulator.1)
149 }
150}