cubecl_reduce/instructions/
min.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::instructions::ReduceRequirements;
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8#[derive(Debug, CubeType, Clone)]
11pub struct Min;
12
13impl ReduceFamily for Min {
14 type Instruction<In: Numeric> = Self;
15 type Config = ();
16}
17
18#[cube]
19impl<In: Numeric> ReduceInstruction<In> for Min {
20 type AccumulatorItem = Line<In>;
21 type SharedAccumulator = SharedMemory<Line<In>>;
22 type Config = ();
23
24 fn requirements(_this: &Self) -> ReduceRequirements {
25 ReduceRequirements { coordinates: false }
26 }
27
28 fn from_config(_config: Self::Config) -> Self {
29 Min {}
30 }
31 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<In> {
32 Line::empty(line_size).fill(In::max_value())
33 }
34
35 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
36 Self::null_input(this, line_size)
37 }
38
39 fn assign_accumulator(
40 _this: &Self,
41 destination: &mut Self::AccumulatorItem,
42 source: &Self::AccumulatorItem,
43 ) {
44 *destination = *source;
45 }
46
47 fn reduce(
48 _this: &Self,
49 accumulator: &Self::AccumulatorItem,
50 item: Line<In>,
51 _coordinate: ReduceCoordinate,
52 #[comptime] use_planes: bool,
53 ) -> Self::AccumulatorItem {
54 if use_planes {
55 let candidate_item = plane_min(item);
56 select_many(
57 accumulator.less_than(candidate_item),
58 *accumulator,
59 candidate_item,
60 )
61 } else {
62 select_many(accumulator.less_than(item), *accumulator, item)
63 }
64 }
65
66 fn fuse_accumulators(
67 _this: &Self,
68 lhs: Self::AccumulatorItem,
69 rhs: Self::AccumulatorItem,
70 ) -> Self::AccumulatorItem {
71 select_many(lhs.less_than(rhs), lhs, rhs)
72 }
73
74 fn merge_line<Out: Numeric>(
75 _this: &Self,
76 accumulator: Self::AccumulatorItem,
77 _shape_axis_reduce: u32,
78 ) -> Out {
79 let mut min = In::max_value();
80 #[unroll]
81 for k in 0..accumulator.size() {
82 let candidate = accumulator[k];
83 min = select(candidate < min, candidate, min);
84 }
85 Out::cast_from(min)
86 }
87
88 fn to_output_perpendicular<Out: Numeric>(
89 _this: &Self,
90 accumulator: Self::AccumulatorItem,
91 _shape_axis_reduce: u32,
92 ) -> Line<Out> {
93 Line::cast_from(accumulator)
94 }
95}