cubecl_reduce/instructions/
argmax.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{
7 ArgAccumulator, ReduceCoordinate, ReduceCoordinateExpand, ReduceFamily, ReduceInstruction,
8 lowest_coordinate_matching,
9};
10
11#[derive(Debug, CubeType, Clone)]
13pub struct ArgMax {}
14
15#[cube]
16impl ArgMax {
17 pub fn choose_argmax<N: Numeric>(
21 items0: Line<N>,
22 coordinates0: Line<u32>,
23 items1: Line<N>,
24 coordinates1: Line<u32>,
25 ) -> (Line<N>, Line<u32>) {
26 let to_keep = select_many(
27 items0.equal(items1),
28 coordinates0.less_than(coordinates1),
29 items0.greater_than(items1),
30 );
31 let items = select_many(to_keep, items0, items1);
32 let coordinates = select_many(to_keep, coordinates0, coordinates1);
33 (items, coordinates)
34 }
35}
36
37impl ReduceFamily for ArgMax {
38 type Instruction<P: ReducePrecision> = Self;
39 type Config = ();
40}
41
42#[cube]
43impl<P: ReducePrecision> ReduceInstruction<P> for ArgMax {
44 type AccumulatorItem = (Line<P::EA>, Line<u32>);
45 type SharedAccumulator = ArgAccumulator<P::EA>;
46 type Config = ();
47
48 fn requirements(_this: &Self) -> ReduceRequirements {
49 ReduceRequirements { coordinates: true }
50 }
51
52 fn from_config(_config: Self::Config) -> Self {
53 ArgMax {}
54 }
55
56 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
57 Line::empty(line_size).fill(P::EI::min_value())
58 }
59
60 fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
61 (
62 Line::empty(line_size).fill(P::EA::min_value()),
63 Line::empty(line_size).fill(u32::MAX),
64 )
65 }
66
67 fn assign_accumulator(
68 _this: &Self,
69 destination: &mut Self::AccumulatorItem,
70 source: &Self::AccumulatorItem,
71 ) {
72 destination.0 = source.0;
73 destination.1 = source.1;
74 }
75
76 fn reduce(
77 _this: &Self,
78 accumulator: &Self::AccumulatorItem,
79 item: Line<P::EI>,
80 coordinate: ReduceCoordinate,
81 #[comptime] use_planes: bool,
82 ) -> Self::AccumulatorItem {
83 let coordinate = match coordinate {
84 ReduceCoordinate::Required(val) => val,
85 ReduceCoordinate::NotRequired => {
86 comptime! {panic!("Coordinates are required for ArgMin")};
87 #[allow(unreachable_code)]
88 Line::new(0)
89 }
90 };
91
92 let (candidate_item, candidate_coordinate) = if use_planes {
93 let candidate_item = plane_max(item);
94 let candidate_coordinate = lowest_coordinate_matching(candidate_item, item, coordinate);
95 (candidate_item, candidate_coordinate)
96 } else {
97 (item, coordinate)
98 };
99
100 Self::choose_argmax(
101 Line::cast_from(candidate_item),
102 candidate_coordinate,
103 accumulator.0,
104 accumulator.1,
105 )
106 }
107
108 fn fuse_accumulators(
109 _this: &Self,
110 lhs: Self::AccumulatorItem,
111 rhs: Self::AccumulatorItem,
112 ) -> Self::AccumulatorItem {
113 Self::choose_argmax(lhs.0, lhs.1, rhs.0, rhs.1)
114 }
115
116 fn merge_line<Out: Numeric>(
117 _this: &Self,
118 accumulator: Self::AccumulatorItem,
119 _shape_axis_reduce: u32,
120 ) -> Out {
121 let line_size = accumulator.0.size();
122 if comptime!(line_size > 1) {
123 let mut max = P::EA::min_value();
124 let mut coordinate = u32::MAX.runtime();
125 #[unroll]
126 for k in 0..line_size {
127 let acc_element = accumulator.0[k];
128 let acc_coordinate = accumulator.1[k];
129 if acc_element == max && acc_coordinate < coordinate {
130 coordinate = acc_coordinate;
131 } else if acc_element > max {
132 max = acc_element;
133 coordinate = acc_coordinate;
134 }
135 }
136 Out::cast_from(coordinate)
137 } else {
138 Out::cast_from(accumulator.1)
139 }
140 }
141
142 fn to_output_perpendicular<Out: Numeric>(
143 _this: &Self,
144 accumulator: Self::AccumulatorItem,
145 _shape_axis_reduce: u32,
146 ) -> Line<Out> {
147 Line::cast_from(accumulator.1)
148 }
149}