cubecl_reduce/instructions/
min.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::instructions::ReduceRequirements;
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8// TODO Add to test framework.
9/// Return the item with the maximum absolute value.
10#[derive(Debug, CubeType, Clone)]
11pub struct Min;
12
13impl ReduceFamily for Min {
14    type Instruction<In: Numeric> = Self;
15    type Config = ();
16}
17
#[cube]
impl<In: Numeric> ReduceInstruction<In> for Min {
    // The accumulator is a whole `Line<In>`: one running minimum per lane.
    type AccumulatorItem = Line<In>;
    // NOTE(review): presumably the scratch buffer used when fusing accumulators
    // across units — confirm against the reduce strategy that consumes it.
    type SharedAccumulator = SharedMemory<Line<In>>;
    type Config = ();

    /// `Min` never looks at element coordinates, so it does not request them.
    fn requirements(_this: &Self) -> ReduceRequirements {
        ReduceRequirements { coordinates: false }
    }

    /// Builds the instruction; `Min` carries no configuration.
    fn from_config(_config: Self::Config) -> Self {
        Min {}
    }

    /// Neutral input for a min-reduction: a line of `line_size` lanes, each set
    /// to `In::max_value()`, so any real element compares less than or equal to it.
    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
        Line::empty(line_size).fill(In::max_value())
    }

    /// Initial accumulator value; identical to `null_input`.
    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
        Self::null_input(this, line_size)
    }

    /// Overwrites `destination` with a copy of `source`.
    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        *destination = *source;
    }

    /// Folds `item` into `accumulator` and returns the lane-wise minimum.
    ///
    /// When `use_planes` is set, `item` is first passed through `plane_min`
    /// before being compared with the accumulator. The coordinate is unused.
    fn reduce(
        _this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<In>,
        _coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        if use_planes {
            // NOTE(review): presumably the minimum of `item` across the plane — confirm
            // `plane_min` semantics in cubecl.
            let candidate_item = plane_min(item);
            // Keep the accumulator lane where it is strictly smaller, otherwise take the candidate.
            select_many(
                accumulator.less_than(candidate_item),
                *accumulator,
                candidate_item,
            )
        } else {
            // Same lane-wise selection, comparing directly against this unit's item.
            select_many(accumulator.less_than(item), *accumulator, item)
        }
    }

    /// Combines two partial accumulators into their lane-wise minimum.
    ///
    /// A lane takes `rhs` whenever `lhs < rhs` is false, ties included.
    fn fuse_accumulators(
        _this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        select_many(lhs.less_than(rhs), lhs, rhs)
    }

    /// Collapses the accumulator line into one scalar, the minimum over all
    /// lanes, cast to `Out`. The reduced-axis length is unused.
    fn merge_line<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Out {
        // Start from the neutral element so any lane value can replace it.
        let mut min = In::max_value();
        #[unroll]
        for k in 0..accumulator.size() {
            let candidate = accumulator[k];
            min = select(candidate < min, candidate, min);
        }
        Out::cast_from(min)
    }

    /// Casts the accumulator line lane-by-lane to `Out` without merging lanes.
    /// The reduced-axis length is unused.
    fn to_output_perpendicular<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: u32,
    ) -> Line<Out> {
        Line::cast_from(accumulator)
    }
}