cubecl_reduce/instructions/min.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8// TODO Add to test framework.
9/// Return the item with the maximum absolute value.
10#[derive(Debug, CubeType, Clone)]
11pub struct Min;
12
13impl ReduceFamily for Min {
14    type Instruction<P: ReducePrecision> = Self;
15    type Config = ();
16}
17
18#[cube]
19impl<P: ReducePrecision> ReduceInstruction<P> for Min {
20    type AccumulatorItem = Line<P::EA>;
21    type SharedAccumulator = SharedMemory<Line<P::EA>>;
22    type Config = ();
23
24    fn requirements(_this: &Self) -> ReduceRequirements {
25        ReduceRequirements { coordinates: false }
26    }
27
28    fn from_config(_config: Self::Config) -> Self {
29        Min {}
30    }
31    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
32        Line::empty(line_size).fill(P::EI::max_value())
33    }
34
35    fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
36        Line::empty(line_size).fill(P::EA::max_value())
37    }
38
39    fn assign_accumulator(
40        _this: &Self,
41        destination: &mut Self::AccumulatorItem,
42        source: &Self::AccumulatorItem,
43    ) {
44        *destination = *source;
45    }
46
47    fn reduce(
48        _this: &Self,
49        accumulator: &Self::AccumulatorItem,
50        item: Line<P::EI>,
51        _coordinate: ReduceCoordinate,
52        #[comptime] use_planes: bool,
53    ) -> Self::AccumulatorItem {
54        if use_planes {
55            let candidate_item = Line::cast_from(plane_min(item));
56            select_many(
57                accumulator.less_than(candidate_item),
58                *accumulator,
59                candidate_item,
60            )
61        } else {
62            let item = Line::cast_from(item);
63            select_many(accumulator.less_than(item), *accumulator, item)
64        }
65    }
66
67    fn fuse_accumulators(
68        _this: &Self,
69        lhs: Self::AccumulatorItem,
70        rhs: Self::AccumulatorItem,
71    ) -> Self::AccumulatorItem {
72        select_many(lhs.less_than(rhs), lhs, rhs)
73    }
74
75    fn merge_line<Out: Numeric>(
76        _this: &Self,
77        accumulator: Self::AccumulatorItem,
78        _shape_axis_reduce: u32,
79    ) -> Out {
80        let mut min = P::EA::max_value();
81        #[unroll]
82        for k in 0..accumulator.size() {
83            let candidate = accumulator[k];
84            min = select(candidate < min, candidate, min);
85        }
86        Out::cast_from(min)
87    }
88
89    fn to_output_perpendicular<Out: Numeric>(
90        _this: &Self,
91        accumulator: Self::AccumulatorItem,
92        _shape_axis_reduce: u32,
93    ) -> Line<Out> {
94        Line::cast_from(accumulator)
95    }
96}