cubecl_reduce/instructions/base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4pub trait ReduceFamily: Send + Sync + 'static + std::fmt::Debug {
5    type Instruction<In: Numeric>: ReduceInstruction<In, Config = Self::Config>;
6    type Config: CubeComptime + Send + Sync;
7}
8
/// Requirements reported by a [`ReduceInstruction`] through its `requirements` method.
#[derive(CubeType)]
pub struct ReduceRequirements {
    /// Whether the instruction needs the coordinates of the elements it reduces.
    /// This is a comptime value, so it is known when the kernel is expanded.
    #[cube(comptime)]
    pub coordinates: bool,
}
14
15/// An instruction for a reduce algorithm that works with [`Line`].
16///
17/// See a provided implementation, such as [`Sum`](super::Sum) or [`ArgMax`](super::ArgMax) for an example how to implement
18/// this trait for a custom instruction.
19///
20/// A reduction works at three levels. First, it takes input data of type `In` and reduce them
21/// with their coordinate into an `AccumulatorItem`. Then, multiple `AccumulatorItem` are possibly fused
22/// together into a single accumulator that is converted to the expected output type.
23#[cube]
24pub trait ReduceInstruction<In: Numeric>:
25    Send + Sync + 'static + std::fmt::Debug + CubeType
26{
27    type Config: CubeComptime + Send + Sync;
28
29    /// Requirements of the reduce.
30    fn requirements(this: &Self) -> ReduceRequirements;
31
32    /// The intermediate state into which we accumulate new input elements.
33    /// This is most likely a `Line<T>` or a struct or tuple of lines.
34    type AccumulatorItem: CubeType;
35
36    /// When multiple agents are collaborating to reduce a single slice,
37    /// we need a share accumulator to store multiple `AccumulatorItem`.
38    /// This is most likely a `SharedMemory<Line<T>>` or a struct or tuple of lined shared memories.
39    type SharedAccumulator: SharedAccumulator<In, Item = Self::AccumulatorItem>;
40
41    fn from_config(#[comptime] config: Self::Config) -> Self;
42    /// A input such that `Self::reduce(accumulator, Self::null_input(), coordinate, use_planes)`
43    /// is guaranteed to return `accumulator` unchanged for any choice of `coordinate`.
44    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In>;
45
46    /// A accumulator such that `Self::fuse_accumulators(accumulator, Self::null_accumulator()` always returns
47    /// is guaranteed to return `accumulator` unchanged.
48    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem;
49
50    /// Assign the value of `source` into `destination`.
51    /// In spirit, this is equivalent to `destination = source;`,
52    /// but this syntax is not currently supported by CubeCL.
53    fn assign_accumulator(
54        this: &Self,
55        destination: &mut Self::AccumulatorItem,
56        source: &Self::AccumulatorItem,
57    );
58
59    /// If `use_planes` is `true`, reduce all the `item` and `coordinate` within the `accumulator`.
60    /// Else, reduce the given `item` and `coordinate` into the accumulator.
61    fn reduce(
62        this: &Self,
63        accumulator: &Self::AccumulatorItem,
64        item: Line<In>,
65        coordinate: ReduceCoordinate,
66        #[comptime] use_planes: bool,
67    ) -> Self::AccumulatorItem;
68
69    /// Reduce two accumulators into a single accumulator.
70    fn fuse_accumulators(
71        this: &Self,
72        lhs: Self::AccumulatorItem,
73        rhs: Self::AccumulatorItem,
74    ) -> Self::AccumulatorItem;
75
76    /// Reduce all elements of the accumulator into a single output element of type `Out`.
77    fn merge_line<Out: Numeric>(
78        this: &Self,
79        accumulator: Self::AccumulatorItem,
80        shape_axis_reduce: u32,
81    ) -> Out;
82
83    /// Convert each element of the accumulator into the expected output element of type `Out`.
84    fn to_output_perpendicular<Out: Numeric>(
85        this: &Self,
86        accumulator: Self::AccumulatorItem,
87        shape_axis_reduce: u32,
88    ) -> Line<Out>;
89}
90
/// Coordinate information handed to [`ReduceInstruction::reduce`] alongside an input line.
#[derive(CubeType)]
pub enum ReduceCoordinate {
    /// Coordinates associated with the input line (presumably one per element of the line;
    /// NOTE(review): confirm against the callers).
    Required(Line<u32>),
    /// No coordinates are provided, for instructions that do not need them.
    NotRequired,
}
96
97/// A simple trait that abstract over a single or multiple shared memory.
98#[cube]
99pub trait SharedAccumulator<In: Numeric>: CubeType + Send + Sync + 'static {
100    type Item: CubeType;
101
102    fn allocate(
103        #[comptime] length: u32,
104        #[comptime] line_size: u32,
105        #[comptime] _coordinate: bool,
106    ) -> Self;
107
108    fn read(accumulator: &Self, index: u32) -> Self::Item;
109
110    fn write(accumulator: &mut Self, index: u32, item: Self::Item);
111}
112
// Single shared memory accumulator: each slot holds one `Line<In>`.
#[cube]
impl<In: Numeric> SharedAccumulator<In> for SharedMemory<Line<In>> {
    type Item = Line<In>;

    // Allocates `length` lines of `line_size` elements; `_coordinate` is ignored
    // because this accumulator stores no coordinates.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] _coordinate: bool,
    ) -> Self {
        SharedMemory::new_lined(length, line_size)
    }

    // Returns the line stored at `index`.
    fn read(accumulator: &Self, index: u32) -> Self::Item {
        accumulator[index]
    }

    // Overwrites the line stored at `index` with `item`.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
        accumulator[index] = item;
    }
}
133
/// A pair of shared memories used for [`ArgMax`](super::ArgMax) and [`ArgMin`](super::ArgMin).
///
/// Slot `i` of `elements` and slot `i` of `args` together form one accumulator item.
#[derive(CubeType)]
pub struct ArgAccumulator<N: Numeric> {
    /// The accumulated values.
    pub elements: SharedMemory<Line<N>>,
    /// The `u32` argument paired with each accumulated value (presumably the
    /// coordinate of the selected element; confirm in `ArgMax`/`ArgMin`).
    pub args: SharedMemory<Line<u32>>,
}
140
// Two parallel shared memories; an item is the `(element, arg)` tuple at one slot.
#[cube]
impl<In: Numeric> SharedAccumulator<In> for ArgAccumulator<In> {
    type Item = (Line<In>, Line<u32>);

    // Allocates both memories with the same `length` and `line_size`.
    // `_coordinate` is ignored: `args` is always allocated.
    fn allocate(
        #[comptime] length: u32,
        #[comptime] line_size: u32,
        #[comptime] _coordinate: bool,
    ) -> Self {
        ArgAccumulator::<In> {
            elements: SharedMemory::new_lined(length, line_size),
            args: SharedMemory::new_lined(length, line_size),
        }
    }

    // Reads the value and its argument at the same `index`.
    fn read(accumulator: &Self, index: u32) -> Self::Item {
        (accumulator.elements[index], accumulator.args[index])
    }

    // Writes the value (`item.0`) and its argument (`item.1`) at the same `index`.
    fn write(accumulator: &mut Self, index: u32, item: Self::Item) {
        accumulator.elements[index] = item.0;
        accumulator.args[index] = item.1;
    }
}
165
166#[cube]
167pub fn reduce_inplace<In: Numeric, R: ReduceInstruction<In>>(
168    inst: &R,
169    accumulator: &mut R::AccumulatorItem,
170    item: Line<In>,
171    coordinate: ReduceCoordinate,
172    #[comptime] use_planes: bool,
173) {
174    let reduction = &R::reduce(inst, accumulator, item, coordinate, use_planes);
175    R::assign_accumulator(inst, accumulator, reduction);
176}
177
178#[cube]
179pub fn reduce_shared_inplace<In: Numeric, R: ReduceInstruction<In>>(
180    inst: &R,
181    accumulator: &mut R::SharedAccumulator,
182    index: u32,
183    item: Line<In>,
184    coordinate: ReduceCoordinate,
185    #[comptime] use_planes: bool,
186) {
187    let acc_item = R::SharedAccumulator::read(accumulator, index);
188    let reduction = R::reduce(inst, &acc_item, item, coordinate, use_planes);
189    R::SharedAccumulator::write(accumulator, index, reduction);
190}
191
192#[cube]
193pub fn fuse_accumulator_inplace<In: Numeric, R: ReduceInstruction<In>>(
194    inst: &R,
195    accumulator: &mut R::SharedAccumulator,
196    destination: u32,
197    origin: u32,
198) {
199    let fused = R::fuse_accumulators(
200        inst,
201        R::SharedAccumulator::read(accumulator, destination),
202        R::SharedAccumulator::read(accumulator, origin),
203    );
204    R::SharedAccumulator::write(accumulator, destination, fused);
205}