// cubecl_reduce/primitives.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::ReadWrite;
4use cubecl_std::tensor::r#virtual::VirtualTensor;
5
6use crate::BoundChecksInner;
7use crate::LineMode;
8use crate::ReduceParams;
9use crate::instructions::*;
10use crate::precision::ReducePrecision;
11
12/// A simple range to specify how to iterate a slice when performing a reduction.
13#[derive(CubeType)]
14pub struct ReduceRange {
15    pub index_start: u32,
16    pub index_step: u32,
17    pub coordinate_start: u32,
18    pub coordinate_end: u32,
19    pub coordinate_step: u32,
20}
21
22#[cube]
23impl ReduceRange {
24    pub(crate) fn new<P: ReducePrecision, Out: Numeric>(
25        reduce_index: u32,
26        input: &VirtualTensor<P::EI>,
27        output: &mut VirtualTensor<Out, ReadWrite>,
28        axis_reduce: u32,
29        #[comptime] params: ReduceParams,
30    ) -> ReduceRange {
31        match comptime!(params.line_mode) {
32            LineMode::Parallel => {
33                Self::new_parallel::<P, Out>(reduce_index, input, output, axis_reduce, params)
34            }
35            LineMode::Perpendicular => {
36                Self::new_perpendicular::<P, Out>(reduce_index, input, output, axis_reduce, params)
37            }
38        }
39    }
40
41    fn new_parallel<P: ReducePrecision, Out: Numeric>(
42        reduce_index: u32,
43        input: &VirtualTensor<P::EI>,
44        output: &mut VirtualTensor<Out, ReadWrite>,
45        axis_reduce: u32,
46        #[comptime] params: ReduceParams,
47    ) -> ReduceRange {
48        let shape_axis = input.shape(axis_reduce);
49
50        let mut index_start = 0;
51        for axis in 0..input.rank() {
52            let coordinate = output.coordinate(reduce_index, axis);
53            index_start += coordinate * input.stride(axis);
54        }
55        index_start /= params.line_size_input;
56
57        let coordinate_end = shape_axis;
58
59        let coordinate_step = if params.shared.is_some() {
60            CUBE_DIM * params.line_size_input
61        } else if params.use_planes {
62            CUBE_DIM_X * params.line_size_input
63        } else {
64            params.line_size_input.runtime()
65        };
66
67        ReduceRange {
68            index_start,
69            index_step: 1,
70            coordinate_start: 0,
71            coordinate_end,
72            coordinate_step,
73        }
74    }
75
76    fn new_perpendicular<P: ReducePrecision, Out: Numeric>(
77        reduce_index: u32,
78        input: &VirtualTensor<P::EI>,
79        output: &mut VirtualTensor<Out, ReadWrite>,
80        axis_reduce: u32,
81        #[comptime] params: ReduceParams,
82    ) -> ReduceRange {
83        let shape_axis = input.shape(axis_reduce);
84
85        let mut index_start = 0;
86        for axis in 0..input.rank() {
87            let coordinate = output.coordinate(reduce_index * params.line_size_input, axis);
88            index_start += coordinate * input.stride(axis);
89        }
90        index_start /= params.line_size_input;
91
92        let index_step = input.stride(axis_reduce) / params.line_size_input;
93
94        let coordinate_end = shape_axis;
95
96        let coordinate_step = if params.shared.is_some() {
97            CUBE_DIM
98        } else if params.use_planes {
99            CUBE_DIM_X
100        } else {
101            1_u32.runtime()
102        };
103
104        ReduceRange {
105            index_start,
106            index_step,
107            coordinate_start: 0,
108            coordinate_step,
109            coordinate_end,
110        }
111    }
112}
113
114/// Use an individual unit to reduce the `items` with the specified range.
115/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
116/// until `items[range.end]` (exclusive).
117///
118/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
119///
120/// Since each individual unit performs a reduction, this function is meant to be called
121/// with either a different `items` for each unit, a different `range` or both based on ABSOLUTE_UNIT_POS.
122#[cube]
123pub fn reduce_slice<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
124    items: &I,
125    range: ReduceRange,
126    inst: &R,
127    #[comptime] line_size: u32,
128    #[comptime] line_mode: LineMode,
129) -> R::AccumulatorItem {
130    let mut accumulator = R::null_accumulator(inst, line_size);
131
132    let mut index = range.index_start;
133    for coordinate in range_stepped(
134        range.coordinate_start,
135        range.coordinate_end,
136        range.coordinate_step,
137    ) {
138        let requirements = R::requirements(inst);
139        let coordinates = if comptime![requirements.coordinates] {
140            ReduceCoordinate::new_Required(fill_coordinate_line(coordinate, line_size, line_mode))
141        } else {
142            ReduceCoordinate::new_NotRequired()
143        };
144        reduce_inplace::<P, R>(
145            inst,
146            &mut accumulator,
147            items.read(index),
148            coordinates,
149            false,
150        );
151        index += range.index_step;
152    }
153
154    accumulator
155}
156
157/// Use an individual plane  to reduce the `items` with the specified range.
158/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
159/// until `items[range.end]` (exclusive).
160///
161/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
162///
163/// This assumes that `UNIT_POS_X` provides the index of unit with a plane and that `CUBE_DIM_X` is the plane dimension.
164/// That is, the cube_dim is `CubeDim::new_2d(plane_dim, plane_count)`.
165///
166/// Since each individual plane performs a reduction, this function is meant to be called
167/// with either a different `items` for each plane, a different `range` or both based on
168/// the absolute plane position (`CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y`).
169#[cube]
170pub fn reduce_slice_plane<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
171    items: &I,
172    inst: &R,
173    range: ReduceRange,
174    #[comptime] line_size: u32,
175    #[comptime] line_mode: LineMode,
176    #[comptime] bound_checks: BoundChecksInner,
177) -> R::AccumulatorItem {
178    let plane_dim = CUBE_DIM_X;
179
180    let mut accumulator = R::null_accumulator(inst, line_size);
181
182    let mut first_index = range.index_start;
183    for first_coordinate in range_stepped(
184        range.coordinate_start,
185        range.coordinate_end,
186        range.coordinate_step,
187    ) {
188        let unit_coordinate_offset = match line_mode {
189            LineMode::Parallel => UNIT_POS_X * line_size,
190            LineMode::Perpendicular => UNIT_POS_X,
191        };
192        let unit_coordinate = first_coordinate + unit_coordinate_offset;
193
194        let requirements = R::requirements(inst);
195        let coordinates = if comptime![requirements.coordinates] {
196            ReduceCoordinate::new_Required(fill_coordinate_line(
197                unit_coordinate,
198                line_size,
199                line_mode,
200            ))
201        } else {
202            ReduceCoordinate::new_NotRequired()
203        };
204
205        let index = first_index + UNIT_POS_X * range.index_step;
206        let item = match bound_checks {
207            BoundChecksInner::None => items.read(index),
208            BoundChecksInner::Mask => {
209                let mask = unit_coordinate < range.coordinate_end;
210                let index = index * u32::cast_from(mask);
211                select(mask, items.read(index), R::null_input(inst, line_size))
212            }
213            BoundChecksInner::Branch => {
214                if unit_coordinate < range.coordinate_end {
215                    items.read(index)
216                } else {
217                    R::null_input(inst, line_size)
218                }
219            }
220        };
221
222        reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, true);
223
224        first_index += plane_dim * range.index_step;
225    }
226    accumulator
227}
228
229/// Use an individual cube to reduce the `items` with the specified range.
230/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
231/// until `items[range.end]` (exclusive). Inside a cube, the reduction will use plane operations
232/// if `use_planes` is set to `true`.
233///
234/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
235///
236/// When `use_planes` is `true`, this assumes that `UNIT_POS_Y` provides the relative position
237/// of a plane within its cube.
238///
239/// Since each individual cube performs a reduction, this function is meant to be called
240/// with either a different `items` for each cube, a different `range` or both based on `CUBE_POS`.
241#[cube]
242pub fn reduce_slice_shared<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
243    items: &I,
244    inst: &R,
245    range: ReduceRange,
246    #[comptime] accumulator_size: u32,
247    #[comptime] line_size: u32,
248    #[comptime] line_mode: LineMode,
249    #[comptime] use_planes: bool,
250    #[comptime] bound_checks: BoundChecksInner,
251) -> R::SharedAccumulator {
252    // The index used to read and write into the accumulator.
253    let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };
254
255    let requirements = R::requirements(inst);
256    let mut accumulator =
257        R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);
258
259    R::SharedAccumulator::write(
260        &mut accumulator,
261        accumulator_index,
262        R::null_accumulator(inst, line_size),
263    );
264
265    let mut first_index = range.index_start;
266    for first_coordinate in range_stepped(
267        range.coordinate_start,
268        range.coordinate_end,
269        range.coordinate_step,
270    ) {
271        let unit_coordinate_offset = match line_mode {
272            LineMode::Parallel => UNIT_POS * line_size,
273            LineMode::Perpendicular => UNIT_POS,
274        };
275        let unit_coordinate = first_coordinate + unit_coordinate_offset;
276
277        let index = first_index + UNIT_POS * range.index_step;
278
279        let item = match bound_checks {
280            BoundChecksInner::None => items.read(index),
281            BoundChecksInner::Mask => {
282                let mask = unit_coordinate < range.coordinate_end;
283                let index = index * u32::cast_from(mask);
284                select(mask, items.read(index), R::null_input(inst, line_size))
285            }
286            BoundChecksInner::Branch => {
287                if unit_coordinate < range.coordinate_end {
288                    items.read(index)
289                } else {
290                    R::null_input(inst, line_size)
291                }
292            }
293        };
294
295        let coordinates = if comptime! {requirements.coordinates} {
296            let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
297            let coordinate = select(
298                unit_coordinate < range.coordinate_end,
299                coordinate,
300                Line::empty(line_size).fill(u32::MAX),
301            );
302
303            ReduceCoordinate::new_Required(coordinate)
304        } else {
305            ReduceCoordinate::new_NotRequired()
306        };
307
308        reduce_shared_inplace::<P, R>(
309            inst,
310            &mut accumulator,
311            accumulator_index,
312            item,
313            coordinates,
314            use_planes,
315        );
316        first_index += range.index_step * CUBE_DIM;
317    }
318    accumulator
319}
320
321// If line mode is parallel, fill a line with `x, x+1, ... x+ line_size - 1` where `x = first`.
322// If line mode is perpendicular, fill a line with `x, x, ... x` where `x = first`.
323#[cube]
324fn fill_coordinate_line(
325    first: u32,
326    #[comptime] line_size: u32,
327    #[comptime] line_mode: LineMode,
328) -> Line<u32> {
329    match comptime!(line_mode) {
330        LineMode::Parallel => {
331            let mut coordinates = Line::empty(line_size);
332            #[unroll]
333            for j in 0..line_size {
334                coordinates[j] = first + j;
335            }
336            coordinates
337        }
338        LineMode::Perpendicular => Line::empty(line_size).fill(first),
339    }
340}
341
342/// Use all units within a cube to fuse the first `size` elements of `accumulator` inplace like this with some padding if `size` is not a power of 2.
343///
344///
345/// ```ignored
346///
347///     0   1   2   3   4   5   6   7
348///     |   |   |   |   |   |   |   |
349///     +---+   +---+   +---+   +---+
350///     |       |       |       |
351///     +-------+       +-------+
352///     |               |
353///     +---------------+
354///     |
355///     *
356///
357/// ```
358///
359/// The outcome is stored in the first element of the accumulator and also returned by this function for convenience.
360///
361/// Since each individual cube performs a reduction, this function is meant to be called
362/// with a different `accumulator` for each cube based on `CUBE_POS`.
363///
364/// There is no out-of-bound check, so it is the responsibility of the caller to ensure that `size` is at most the length
365/// of the shared memory and that there are at least `size` units within each cube.
366#[cube]
367pub fn reduce_tree<P: ReducePrecision, Inst: ReduceInstruction<P>>(
368    inst: &Inst,
369    accumulator: &mut Inst::SharedAccumulator,
370    #[comptime] size: u32,
371) -> Inst::AccumulatorItem {
372    if comptime!(size.is_power_of_two()) {
373        let mut num_active_units = size.runtime();
374        let mut jump = 1;
375        while num_active_units > 1 {
376            num_active_units /= 2;
377            let destination = jump * 2 * UNIT_POS;
378            let origin = jump * (2 * UNIT_POS + 1);
379            if UNIT_POS < num_active_units {
380                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
381            }
382            jump *= 2;
383            sync_cube();
384        }
385    } else {
386        let mut num_remaining_items = size.runtime();
387        let mut jump = 1;
388        while num_remaining_items > 1 {
389            let destination = jump * 2 * UNIT_POS;
390            let origin = jump * (2 * UNIT_POS + 1);
391            if UNIT_POS < num_remaining_items / 2 {
392                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
393            }
394            num_remaining_items = div_ceil(num_remaining_items, 2);
395            jump *= 2;
396            sync_cube();
397        }
398    }
399    sync_cube();
400    Inst::SharedAccumulator::read(accumulator, 0)
401}
402
403#[cube]
404#[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
405#[allow(clippy::manual_div_ceil)]
406fn div_ceil(a: u32, b: u32) -> u32 {
407    (a + b - 1) / b
408}