cubecl_reduce/instructions/
argmax.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};
5
/// Arg-max reduction instruction: finds the coordinate of the maximum item,
/// preferring the lowest coordinate when several items tie for the maximum.
#[derive(Debug)]
pub struct ArgMax;

10#[cube]
11impl ArgMax {
12 pub fn choose_argmax<N: Numeric>(
16 items0: Line<N>,
17 coordinates0: Line<u32>,
18 items1: Line<N>,
19 coordinates1: Line<u32>,
20 ) -> (Line<N>, Line<u32>) {
21 let to_keep = select_many(
22 items0.equal(items1),
23 coordinates0.less_than(coordinates1),
24 items0.greater_than(items1),
25 );
26 let items = select_many(to_keep, items0, items1);
27 let coordinates = select_many(to_keep, coordinates0, coordinates1);
28 (items, coordinates)
29 }
30}
31
32impl Reduce for ArgMax {
33 type Instruction<In: Numeric> = Self;
34}
35
36#[cube]
37impl<In: Numeric> ReduceInstruction<In> for ArgMax {
38 type AccumulatorItem = (Line<In>, Line<u32>);
39 type SharedAccumulator = ArgAccumulator<In>;
40
41 fn null_input(#[comptime] line_size: u32) -> Line<In> {
42 Line::empty(line_size).fill(In::min_value())
43 }
44
45 fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
46 (
47 Self::null_input(line_size),
48 Line::empty(line_size).fill(u32::MAX),
49 )
50 }
51
52 fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
53 destination.0 = source.0;
54 destination.1 = source.1;
55 }
56
57 fn reduce(
58 accumulator: &Self::AccumulatorItem,
59 item: Line<In>,
60 coordinate: Line<u32>,
61 #[comptime] use_planes: bool,
62 ) -> Self::AccumulatorItem {
63 let (candidate_item, candidate_coordinate) = if use_planes {
64 let candidate_item = plane_max(item);
65 let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
66 (candidate_item, candidate_coordinate)
67 } else {
68 (item, coordinate)
69 };
70 Self::choose_argmax(
71 candidate_item,
72 candidate_coordinate,
73 accumulator.0,
74 accumulator.1,
75 )
76 }
77
78 fn fuse_accumulators(
79 lhs: Self::AccumulatorItem,
80 rhs: Self::AccumulatorItem,
81 ) -> Self::AccumulatorItem {
82 Self::choose_argmax(lhs.0, lhs.1, rhs.0, rhs.1)
83 }
84
85 fn merge_line<Out: Numeric>(
86 accumulator: Self::AccumulatorItem,
87 _shape_axis_reduce: u32,
88 ) -> Out {
89 let line_size = accumulator.0.size();
90 if comptime!(line_size > 1) {
91 let mut max = In::min_value();
92 let mut coordinate = u32::MAX.runtime();
93 #[unroll]
94 for k in 0..line_size {
95 let acc_element = accumulator.0[k];
96 let acc_coordinate = accumulator.1[k];
97 if acc_element == max && acc_coordinate < coordinate {
98 coordinate = acc_coordinate;
99 } else if acc_element > max {
100 max = acc_element;
101 coordinate = acc_coordinate;
102 }
103 }
104 Out::cast_from(coordinate)
105 } else {
106 Out::cast_from(accumulator.1)
107 }
108 }
109
110 fn to_output_perpendicular<Out: Numeric>(
111 accumulator: Self::AccumulatorItem,
112 _shape_axis_reduce: u32,
113 ) -> Line<Out> {
114 Line::cast_from(accumulator.1)
115 }
116}