cubecl_reduce/instructions/
argmin.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};
5
/// Reduction that yields the coordinate (index) of the minimum item.
///
/// When several items share the minimum value, the lowest coordinate is returned.
#[derive(Debug)]
pub struct ArgMin;
9
10impl Reduce for ArgMin {
11 type Instruction<In: Numeric> = Self;
12}
13
14#[cube]
15impl ArgMin {
16 pub fn choose_argmin<N: Numeric>(
20 items0: Line<N>,
21 coordinates0: Line<u32>,
22 items1: Line<N>,
23 coordinates1: Line<u32>,
24 ) -> (Line<N>, Line<u32>) {
25 let to_keep = select_many(
26 items0.equal(items1),
27 coordinates0.less_than(coordinates1),
28 items0.less_than(items1),
29 );
30 let items = select_many(to_keep, items0, items1);
31 let coordinates = select_many(to_keep, coordinates0, coordinates1);
32 (items, coordinates)
33 }
34}
35
36#[cube]
37impl<In: Numeric> ReduceInstruction<In> for ArgMin {
38 type AccumulatorItem = (Line<In>, Line<u32>);
39 type SharedAccumulator = ArgAccumulator<In>;
40
41 fn null_input(#[comptime] line_size: u32) -> Line<In> {
42 Line::empty(line_size).fill(In::max_value())
43 }
44
45 fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
46 (
47 Self::null_input(line_size),
48 Line::empty(line_size).fill(u32::MAX),
49 )
50 }
51
52 fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
53 destination.0 = source.0;
54 destination.1 = source.1;
55 }
56
57 fn reduce(
58 accumulator: &Self::AccumulatorItem,
59 item: Line<In>,
60 coordinate: Line<u32>,
61 #[comptime] use_planes: bool,
62 ) -> Self::AccumulatorItem {
63 let (candidate_item, candidate_coordinate) = if use_planes {
64 let candidate_item = plane_min(item);
65 let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
66 (candidate_item, candidate_coordinate)
67 } else {
68 (item, coordinate)
69 };
70 Self::choose_argmin(
71 accumulator.0,
72 accumulator.1,
73 candidate_item,
74 candidate_coordinate,
75 )
76 }
77
78 fn fuse_accumulators(
79 lhs: Self::AccumulatorItem,
80 rhs: Self::AccumulatorItem,
81 ) -> Self::AccumulatorItem {
82 Self::choose_argmin(lhs.0, lhs.1, rhs.0, rhs.1)
83 }
84
85 fn merge_line<Out: Numeric>(
86 accumulator: Self::AccumulatorItem,
87 _shape_axis_reduce: u32,
88 ) -> Out {
89 let line_size = accumulator.0.size();
90 if comptime!(line_size > 1) {
91 let mut min = In::max_value();
92 let mut coordinate = u32::MAX.runtime();
93
94 #[unroll]
95 for k in 0..line_size {
96 let acc_element = accumulator.0[k];
97 let acc_coordinate = accumulator.1[k];
98 if acc_element == min && acc_coordinate < coordinate {
100 coordinate = acc_coordinate;
101 } else if acc_element < min {
102 min = acc_element;
103 coordinate = acc_coordinate;
104 }
105 }
106 Out::cast_from(coordinate)
107 } else {
108 Out::cast_from(accumulator.1)
109 }
110 }
111
112 fn to_output_perpendicular<Out: Numeric>(
113 accumulator: Self::AccumulatorItem,
114 _shape_axis_reduce: u32,
115 ) -> Line<Out> {
116 Line::cast_from(accumulator.1)
117 }
118}