cubecl_reduce/instructions/
base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6pub trait ReduceFamily: Send + Sync + 'static + std::fmt::Debug {
7    type Instruction<P: ReducePrecision>: ReduceInstruction<P, Config = Self::Config>;
8    type Config: CubeComptime + Send + Sync;
9}
10
/// Requirements reported by a reduce instruction through `ReduceInstruction::requirements`.
#[derive(CubeType)]
pub struct ReduceRequirements {
    /// Comptime flag; presumably `true` when the instruction needs the coordinates of the
    /// reduced items (see `ReduceCoordinate`) — confirm against the instruction implementations.
    #[cube(comptime)]
    pub coordinates: bool,
}
16
17/// An instruction for a reduce algorithm that works with [`Line`].
18///
19/// See a provided implementation, such as [`Sum`](super::Sum) or [`ArgMax`](super::ArgMax) for an example how to implement
20/// this trait for a custom instruction.
21///
22/// A reduction works at three levels. First, it takes input data of type `In` and reduce them
23/// with their coordinate into an `AccumulatorItem`. Then, multiple `AccumulatorItem` are possibly fused
24/// together into a single accumulator that is converted to the expected output type.
25#[cube]
26pub trait ReduceInstruction<P: ReducePrecision>:
27    Send + Sync + 'static + std::fmt::Debug + CubeType
28{
29    type Config: CubeComptime + Send + Sync;
30
31    /// Requirements of the reduce.
32    fn requirements(this: &Self) -> ReduceRequirements;
33
34    /// The intermediate state into which we accumulate new input elements.
35    /// This is most likely a `Line<T>` or a struct or tuple of lines.
36    type AccumulatorItem: CubeType;
37
38    /// When multiple agents are collaborating to reduce a single slice,
39    /// we need a share accumulator to store multiple `AccumulatorItem`.
40    /// This is most likely a `SharedMemory<Line<T>>` or a struct or tuple of lined shared memories.
41    type SharedAccumulator: SharedAccumulator<Item = Self::AccumulatorItem>;
42
43    fn from_config(#[comptime] config: Self::Config) -> Self;
44    /// A input such that `Self::reduce(accumulator, Self::null_input(), coordinate, use_planes)`
45    /// is guaranteed to return `accumulator` unchanged for any choice of `coordinate`.
46    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI>;
47
48    /// A accumulator such that `Self::fuse_accumulators(accumulator, Self::null_accumulator()` always returns
49    /// is guaranteed to return `accumulator` unchanged.
50    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem;
51
52    /// Assign the value of `source` into `destination`.
53    /// In spirit, this is equivalent to `destination = source;`,
54    /// but this syntax is not currently supported by CubeCL.
55    fn assign_accumulator(
56        this: &Self,
57        destination: &mut Self::AccumulatorItem,
58        source: &Self::AccumulatorItem,
59    );
60
61    /// If `use_planes` is `true`, reduce all the `item` and `coordinate` within the `accumulator`.
62    /// Else, reduce the given `item` and `coordinate` into the accumulator.
63    fn reduce(
64        this: &Self,
65        accumulator: &Self::AccumulatorItem,
66        item: Line<P::EI>,
67        coordinate: ReduceCoordinate,
68        #[comptime] use_planes: bool,
69    ) -> Self::AccumulatorItem;
70
71    /// Reduce two accumulators into a single accumulator.
72    fn fuse_accumulators(
73        this: &Self,
74        lhs: Self::AccumulatorItem,
75        rhs: Self::AccumulatorItem,
76    ) -> Self::AccumulatorItem;
77
78    /// Reduce all elements of the accumulator into a single output element of type `Out`.
79    fn merge_line<Out: Numeric>(
80        this: &Self,
81        accumulator: Self::AccumulatorItem,
82        shape_axis_reduce: u32,
83    ) -> Out;
84
85    /// Convert each element of the accumulator into the expected output element of type `Out`.
86    fn to_output_perpendicular<Out: Numeric>(
87        this: &Self,
88        accumulator: Self::AccumulatorItem,
89        shape_axis_reduce: u32,
90    ) -> Line<Out>;
91}
92
/// The coordinate argument handed to `ReduceInstruction::reduce` together with each input item.
#[derive(CubeType)]
pub enum ReduceCoordinate {
    /// A line of `u32` coordinates is provided.
    Required(Line<u32>),
    /// No coordinate is provided.
    NotRequired,
}
98
/// A simple trait that abstracts over a single or multiple shared memories.
#[cube]
pub trait SharedAccumulator: CubeType + Send + Sync + 'static {
    /// The value read from, and written to, one index of the accumulator.
    type Item: CubeType;

    /// Create the accumulator from comptime `length` and `line_size`.
    ///
    /// The implementations in this file ignore `_coordinate`.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] _coordinate: bool,
    ) -> Self;

    /// Return the item stored at `index`.
    fn read(accumulator: &Self, index: u32) -> Self::Item;

    /// Store `item` at `index`.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item);
}
114
/// A single lined shared memory used directly as a shared accumulator of `Line<In>` items.
#[cube]
impl<In: Numeric> SharedAccumulator for SharedMemory<Line<In>> {
    type Item = Line<In>;

    /// Allocate the shared memory with `SharedMemory::new_lined(length, line_size)`.
    /// `_coordinate` is not used.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] _coordinate: bool,
    ) -> Self {
        SharedMemory::new_lined(length, line_size)
    }

    /// Return the line stored at `index`.
    fn read(accumulator: &Self, index: u32) -> Self::Item {
        accumulator[index]
    }

    /// Overwrite the line stored at `index` with `item`.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
        accumulator[index] = item;
    }
}
135
/// A pair of shared memories used for [`ArgMax`](super::ArgMax) and [`ArgMin`](super::ArgMin).
#[derive(CubeType)]
pub struct ArgAccumulator<N: Numeric> {
    /// Shared memory holding the `Line<N>` half of each accumulated item.
    pub elements: SharedMemory<Line<N>>,
    /// Shared memory holding the `Line<u32>` half of each accumulated item;
    /// presumably the argument (index) associated with each element — confirm with the
    /// `ArgMax`/`ArgMin` implementations.
    pub args: SharedMemory<Line<u32>>,
}
142
143#[cube]
144impl<In: Numeric> SharedAccumulator for ArgAccumulator<In> {
145    type Item = (Line<In>, Line<u32>);
146
147    fn allocate(
148        #[comptime] length: u32,
149        #[comptime] line_size: u32,
150        #[comptime] _coordinate: bool,
151    ) -> Self {
152        ArgAccumulator::<In> {
153            elements: SharedMemory::new_lined(length, line_size),
154            args: SharedMemory::new_lined(length, line_size),
155        }
156    }
157
158    fn read(accumulator: &Self, index: u32) -> Self::Item {
159        (accumulator.elements[index], accumulator.args[index])
160    }
161
162    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
163        accumulator.elements[index] = item.0;
164        accumulator.args[index] = item.1;
165    }
166}
167
168#[cube]
169pub fn reduce_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
170    inst: &R,
171    accumulator: &mut R::AccumulatorItem,
172    item: Line<P::EI>,
173    coordinate: ReduceCoordinate,
174    #[comptime] use_planes: bool,
175) {
176    let reduction = &R::reduce(inst, accumulator, item, coordinate, use_planes);
177    R::assign_accumulator(inst, accumulator, reduction);
178}
179
180#[cube]
181pub fn reduce_shared_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
182    inst: &R,
183    accumulator: &mut R::SharedAccumulator,
184    index: u32,
185    item: Line<P::EI>,
186    coordinate: ReduceCoordinate,
187    #[comptime] use_planes: bool,
188) {
189    let acc_item = R::SharedAccumulator::read(accumulator, index);
190    let reduction = R::reduce(inst, &acc_item, item, coordinate, use_planes);
191    R::SharedAccumulator::write(accumulator, index, reduction);
192}
193
194#[cube]
195pub fn fuse_accumulator_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
196    inst: &R,
197    accumulator: &mut R::SharedAccumulator,
198    destination: u32,
199    origin: u32,
200) {
201    let fused = R::fuse_accumulators(
202        inst,
203        R::SharedAccumulator::read(accumulator, destination),
204        R::SharedAccumulator::read(accumulator, origin),
205    );
206    R::SharedAccumulator::write(accumulator, destination, fused);
207}