cubecl_reduce/instructions/
maxabs.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8#[derive(Debug, CubeType, Clone)]
11pub struct MaxAbs;
12
13impl ReduceFamily for MaxAbs {
14 type Instruction<P: ReducePrecision> = Self;
15 type Config = ();
16}
17
18#[cube]
19impl<P: ReducePrecision> ReduceInstruction<P> for MaxAbs {
20 type AccumulatorItem = Line<P::EA>;
21 type SharedAccumulator = SharedMemory<Line<P::EA>>;
22 type Config = ();
23
24 fn requirements(_this: &Self) -> ReduceRequirements {
25 ReduceRequirements { coordinates: false }
26 }
27
28 fn from_config(_config: Self::Config) -> Self {
29 MaxAbs {}
30 }
31
32 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
33 Line::empty(line_size).fill(P::EI::from_int(0))
34 }
35
36 fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
37 Line::empty(line_size).fill(P::EA::from_int(0))
38 }
39
40 fn assign_accumulator(
41 _this: &Self,
42 destination: &mut Self::AccumulatorItem,
43 source: &Self::AccumulatorItem,
44 ) {
45 *destination = *source;
46 }
47
48 fn reduce(
49 _this: &Self,
50 accumulator: &Self::AccumulatorItem,
51 item: Line<P::EI>,
52 _coordinate: ReduceCoordinate,
53 #[comptime] use_planes: bool,
54 ) -> Self::AccumulatorItem {
55 if use_planes {
56 let candidate_item = Line::cast_from(plane_max(Line::abs(item)));
57 select_many(
58 accumulator.greater_than(candidate_item),
59 *accumulator,
60 candidate_item,
61 )
62 } else {
63 let item_abs = Line::cast_from(Line::abs(item));
64 select_many(accumulator.greater_than(item_abs), *accumulator, item_abs)
65 }
66 }
67
68 fn fuse_accumulators(
69 _this: &Self,
70 lhs: Self::AccumulatorItem,
71 rhs: Self::AccumulatorItem,
72 ) -> Self::AccumulatorItem {
73 select_many(lhs.greater_than(rhs), lhs, rhs)
74 }
75
76 fn merge_line<Out: Numeric>(
77 _this: &Self,
78 accumulator: Self::AccumulatorItem,
79 _shape_axis_reduce: u32,
80 ) -> Out {
81 let mut max = P::EA::from_int(0);
82 #[unroll]
83 for k in 0..accumulator.size() {
84 let candidate = accumulator[k];
85 max = select(candidate > max, candidate, max);
86 }
87 Out::cast_from(max)
88 }
89
90 fn to_output_perpendicular<Out: Numeric>(
91 _this: &Self,
92 accumulator: Self::AccumulatorItem,
93 _shape_axis_reduce: u32,
94 ) -> Line<Out> {
95 Line::cast_from(accumulator)
96 }
97}