// cubecl_reduce/primitives.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use crate::BoundChecksInner;
6use crate::LineMode;
7use crate::ReduceParams;
8use crate::instructions::*;
9use crate::precision::ReducePrecision;
10
/// A simple range to specify how to iterate a slice when performing a reduction.
#[derive(CubeType)]
pub struct ReduceRange {
    /// First linear index (in lines, not elements) of the slice to read.
    pub index_start: u32,
    /// Linear index increment between two successive coordinates along the reduced axis.
    pub index_step: u32,
    /// First coordinate along the reduced axis processed by the loop.
    pub coordinate_start: u32,
    /// Exclusive upper bound on the coordinate along the reduced axis.
    pub coordinate_end: u32,
    /// Coordinate increment per loop iteration.
    pub coordinate_step: u32,
}
20
#[cube]
impl ReduceRange {
    /// Build the iteration range assigned to `reduce_index`, dispatching at
    /// compile time on the line mode of the kernel.
    pub(crate) fn new<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        match comptime!(params.line_mode) {
            LineMode::Parallel => {
                Self::new_parallel::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
            LineMode::Perpendicular => {
                Self::new_perpendicular::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
        }
    }

    /// Range for the parallel line mode, where lines run along the reduced axis.
    fn new_parallel<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // Map the output coordinates of this reduction back to a linear element
        // offset into the input using the input strides.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into a line offset.
        index_start /= params.line_size_input;

        let coordinate_end = shape_axis;

        // Coordinates consumed per iteration: each active unit reads one full
        // line, so the step scales with both the unit count and the line size.
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM * params.line_size_input
        } else if params.use_planes {
            CUBE_DIM_X * params.line_size_input
        } else {
            params.line_size_input.runtime()
        };

        ReduceRange {
            index_start,
            index_step: 1,
            coordinate_start: 0,
            coordinate_end,
            coordinate_step,
        }
    }

    /// Range for the perpendicular line mode, where each element of a line
    /// belongs to a different reduction.
    fn new_perpendicular<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // One line covers `line_size_input` adjacent reductions, hence the
        // scaling of `reduce_index` when recovering the output coordinates.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index * params.line_size_input, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into a line offset.
        index_start /= params.line_size_input;

        // Moving one coordinate along the reduced axis jumps a full stride,
        // expressed in lines.
        let index_step = input.stride(axis_reduce) / params.line_size_input;

        let coordinate_end = shape_axis;

        // Coordinates consumed per iteration: one per active unit, independent
        // of the line size since lines are perpendicular to the reduced axis.
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM
        } else if params.use_planes {
            CUBE_DIM_X
        } else {
            1_u32.runtime()
        };

        ReduceRange {
            index_start,
            index_step,
            coordinate_start: 0,
            coordinate_step,
            coordinate_end,
        }
    }
}
112
113/// Use an individual unit to reduce the `items` with the specified range.
114/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
115/// until `items[range.end]` (exclusive).
116///
117/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
118///
119/// Since each individual unit performs a reduction, this function is meant to be called
120/// with either a different `items` for each unit, a different `range` or both based on ABSOLUTE_UNIT_POS.
121#[cube]
122pub fn reduce_slice<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
123    items: &I,
124    range: ReduceRange,
125    inst: &R,
126    #[comptime] line_size: u32,
127    #[comptime] line_mode: LineMode,
128) -> R::AccumulatorItem {
129    let mut accumulator = R::null_accumulator(inst, line_size);
130
131    let mut index = range.index_start;
132    for coordinate in range_stepped(
133        range.coordinate_start,
134        range.coordinate_end,
135        range.coordinate_step,
136    ) {
137        let requirements = R::requirements(inst);
138        let coordinates = if comptime![requirements.coordinates] {
139            ReduceCoordinate::new_Required(fill_coordinate_line(coordinate, line_size, line_mode))
140        } else {
141            ReduceCoordinate::new_NotRequired()
142        };
143        reduce_inplace::<P, R>(
144            inst,
145            &mut accumulator,
146            items.read(index),
147            coordinates,
148            false,
149        );
150        index += range.index_step;
151    }
152
153    accumulator
154}
155
156/// Use an individual plane  to reduce the `items` with the specified range.
157/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
158/// until `items[range.end]` (exclusive).
159///
160/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
161///
162/// This assumes that `UNIT_POS_X` provides the index of unit with a plane and that `CUBE_DIM_X` is the plane dimension.
163/// That is, the cube_dim is `CubeDim::new_2d(plane_dim, plane_count)`.
164///
165/// Since each individual plane performs a reduction, this function is meant to be called
166/// with either a different `items` for each plane, a different `range` or both based on
167/// the absolute plane position (`CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y`).
168#[cube]
169pub fn reduce_slice_plane<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
170    items: &I,
171    inst: &R,
172    range: ReduceRange,
173    #[comptime] line_size: u32,
174    #[comptime] line_mode: LineMode,
175    #[comptime] bound_checks: BoundChecksInner,
176) -> R::AccumulatorItem {
177    let plane_dim = CUBE_DIM_X;
178
179    let mut accumulator = R::null_accumulator(inst, line_size);
180
181    let mut first_index = range.index_start;
182    for first_coordinate in range_stepped(
183        range.coordinate_start,
184        range.coordinate_end,
185        range.coordinate_step,
186    ) {
187        let unit_coordinate_offset = match line_mode {
188            LineMode::Parallel => UNIT_POS_X * line_size,
189            LineMode::Perpendicular => UNIT_POS_X,
190        };
191        let unit_coordinate = first_coordinate + unit_coordinate_offset;
192
193        let requirements = R::requirements(inst);
194        let coordinates = if comptime![requirements.coordinates] {
195            ReduceCoordinate::new_Required(fill_coordinate_line(
196                unit_coordinate,
197                line_size,
198                line_mode,
199            ))
200        } else {
201            ReduceCoordinate::new_NotRequired()
202        };
203
204        let index = first_index + UNIT_POS_X * range.index_step;
205        let item = match bound_checks {
206            BoundChecksInner::None => items.read(index),
207            BoundChecksInner::Mask => {
208                let mask = unit_coordinate < range.coordinate_end;
209                let index = index * u32::cast_from(mask);
210                select(mask, items.read(index), R::null_input(inst, line_size))
211            }
212            BoundChecksInner::Branch => {
213                if unit_coordinate < range.coordinate_end {
214                    items.read(index)
215                } else {
216                    R::null_input(inst, line_size)
217                }
218            }
219        };
220
221        reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, true);
222
223        first_index += plane_dim * range.index_step;
224    }
225    accumulator
226}
227
/// Use an individual cube to reduce the `items` with the specified range.
/// That is, this will reduces `items[range.start]`, `items[range.start + range.step]`
/// until `items[range.end]` (exclusive). Inside a cube, the reduction will use plane operations
/// if `use_planes` is set to `true`.
///
/// This reduces using the given `line_mode` but doesn't reduce the accumulator itself.
///
/// When `use_planes` is `true`, this assumes that `UNIT_POS_Y` provides the relative position
/// of a plane within its cube.
///
/// Since each individual cube performs a reduction, this function is meant to be called
/// with either a different `items` for each cube, a different `range` or both based on `CUBE_POS`.
#[cube]
pub fn reduce_slice_shared<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
    items: &I,
    inst: &R,
    range: ReduceRange,
    #[comptime] accumulator_size: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
    #[comptime] use_planes: bool,
    #[comptime] bound_checks: BoundChecksInner,
) -> R::SharedAccumulator {
    // The index used to read and write into the accumulator.
    let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };

    let requirements = R::requirements(inst);
    let mut accumulator =
        R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);

    // Seed every slot with the null accumulator so slots that never receive a
    // valid item stay neutral for the final tree reduction.
    R::SharedAccumulator::write(
        &mut accumulator,
        accumulator_index,
        R::null_accumulator(inst, line_size),
    );

    let mut first_index = range.index_start;
    for first_coordinate in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        // In parallel mode each unit covers a full line of coordinates; in
        // perpendicular mode it covers a single coordinate.
        let unit_coordinate_offset = match line_mode {
            LineMode::Parallel => UNIT_POS * line_size,
            LineMode::Perpendicular => UNIT_POS,
        };
        let unit_coordinate = first_coordinate + unit_coordinate_offset;

        let index = first_index + UNIT_POS * range.index_step;

        let item = match bound_checks {
            BoundChecksInner::None => items.read(index),
            BoundChecksInner::Mask => {
                // Out-of-bound units read index 0 (masked to a valid address)
                // and the null input, so they don't affect the result.
                let mask = unit_coordinate < range.coordinate_end;
                let index = index * u32::cast_from(mask);
                select(mask, items.read(index), R::null_input(inst, line_size))
            }
            BoundChecksInner::Branch => {
                if unit_coordinate < range.coordinate_end {
                    items.read(index)
                } else {
                    R::null_input(inst, line_size)
                }
            }
        };

        let coordinates = if comptime! {requirements.coordinates} {
            let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
            // Out-of-bound lanes get u32::MAX as their coordinate — presumably
            // so they never win an argmin/argmax comparison; confirm against
            // the instruction's coordinate semantics.
            let coordinate = select(
                unit_coordinate < range.coordinate_end,
                coordinate,
                Line::empty(line_size).fill(u32::MAX),
            );

            ReduceCoordinate::new_Required(coordinate)
        } else {
            ReduceCoordinate::new_NotRequired()
        };

        reduce_shared_inplace::<P, R>(
            inst,
            &mut accumulator,
            accumulator_index,
            item,
            coordinates,
            use_planes,
        );
        first_index += range.index_step * CUBE_DIM;
    }
    accumulator
}
319
/// If line mode is parallel, fill a line with `x, x+1, ... x + line_size - 1` where `x = first`.
/// If line mode is perpendicular, fill a line with `x, x, ... x` where `x = first`.
#[cube]
fn fill_coordinate_line(
    first: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> Line<u32> {
    match comptime!(line_mode) {
        LineMode::Parallel => {
            // Consecutive elements of a parallel line map to consecutive
            // coordinates along the reduced axis.
            let mut coordinates = Line::empty(line_size);
            #[unroll]
            for j in 0..line_size {
                coordinates[j] = first + j;
            }
            coordinates
        }
        LineMode::Perpendicular => Line::empty(line_size).fill(first),
    }
}
340
/// Use all units within a cube to fuse the first `size` elements of `accumulator` inplace like this with some padding if `size` is not a power of 2.
///
///
/// ```ignored
///
///     0   1   2   3   4   5   6   7
///     |   |   |   |   |   |   |   |
///     +---+   +---+   +---+   +---+
///     |       |       |       |
///     +-------+       +-------+
///     |               |
///     +---------------+
///     |
///     *
///
/// ```
///
/// The outcome is stored in the first element of the accumulator and also returned by this function for convenience.
///
/// Since each individual cube performs a reduction, this function is meant to be called
/// with a different `accumulator` for each cube based on `CUBE_POS`.
///
/// There is no out-of-bound check, so it is the responsibility of the caller to ensure that `size` is at most the length
/// of the shared memory and that there are at least `size` units within each cube.
#[cube]
pub fn reduce_tree<P: ReducePrecision, Inst: ReduceInstruction<P>>(
    inst: &Inst,
    accumulator: &mut Inst::SharedAccumulator,
    #[comptime] size: u32,
) -> Inst::AccumulatorItem {
    if comptime!(size.is_power_of_two()) {
        // Power-of-two case: every round, the lower half of the active units
        // each fuse one pair spaced `jump` apart.
        let mut num_active_units = size.runtime();
        let mut jump = 1;
        while num_active_units > 1 {
            num_active_units /= 2;
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_active_units {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            jump *= 2;
            // All units must see this round's fusions before the next round.
            sync_cube();
        }
    } else {
        // General case: round up the remaining count so an odd leftover item
        // is carried unfused into the next round.
        let mut num_remaining_items = size.runtime();
        let mut jump = 1;
        while num_remaining_items > 1 {
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_remaining_items / 2 {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            num_remaining_items = num_remaining_items.div_ceil(2);
            jump *= 2;
            sync_cube();
        }
    }
    sync_cube();
    // The fused result has accumulated into slot 0.
    Inst::SharedAccumulator::read(accumulator, 0)
}
401
402/// Warp-level sum reduction using shuffle operations.
403///
404/// All lanes get the sum of all 32 values in the warp using butterfly reduction:
405/// ```ignored
406/// Step 1 (offset=16): Lane 0 ← Lane 0 + Lane 16, Lane 1 ← Lane 1 + Lane 17, ...
407/// Step 2 (offset=8):  Lane 0 ← Lane 0 + Lane 8,  Lane 1 ← Lane 1 + Lane 9,  ...
408/// Step 3 (offset=4):  Lane 0 ← Lane 0 + Lane 4,  Lane 1 ← Lane 1 + Lane 5,  ...
409/// Step 4 (offset=2):  Lane 0 ← Lane 0 + Lane 2,  Lane 1 ← Lane 1 + Lane 3,  ...
410/// Step 5 (offset=1):  Lane 0 ← Lane 0 + Lane 1,  ...
411/// ```
412///
413/// # Performance
414/// - ~5 cycles per shuffle × 5 steps = ~25 cycles total
415/// - Compare to shared memory: ~110 (write) + ~110 (read) × log2(32) = ~1100+ cycles
416///
417/// # Example
418/// ```ignored
419/// #[cube]
420/// fn warp_sum_example(value: f32) -> f32 {
421///     reduce_sum_shuffle(value)  // All lanes get the sum
422/// }
423/// ```
424#[cube]
425pub fn reduce_sum_shuffle<F: Float>(value: F) -> F {
426    // Manually unrolled butterfly reduction
427    let v1 = value + plane_shuffle_xor(value, 16);
428    let v2 = v1 + plane_shuffle_xor(v1, 8);
429    let v3 = v2 + plane_shuffle_xor(v2, 4);
430    let v4 = v3 + plane_shuffle_xor(v3, 2);
431    v4 + plane_shuffle_xor(v4, 1)
432}
433
434/// Warp-level max reduction using shuffle operations.
435/// All lanes get the maximum of all 32 values in the warp.
436#[cube]
437pub fn reduce_max_shuffle<F: Float>(value: F) -> F {
438    let v1 = F::max(value, plane_shuffle_xor(value, 16));
439    let v2 = F::max(v1, plane_shuffle_xor(v1, 8));
440    let v3 = F::max(v2, plane_shuffle_xor(v2, 4));
441    let v4 = F::max(v3, plane_shuffle_xor(v3, 2));
442    F::max(v4, plane_shuffle_xor(v4, 1))
443}
444
445/// Warp-level min reduction using shuffle operations.
446/// All lanes get the minimum of all 32 values in the warp.
447#[cube]
448pub fn reduce_min_shuffle<F: Float>(value: F) -> F {
449    let v1 = F::min(value, plane_shuffle_xor(value, 16));
450    let v2 = F::min(v1, plane_shuffle_xor(v1, 8));
451    let v3 = F::min(v2, plane_shuffle_xor(v2, 4));
452    let v4 = F::min(v3, plane_shuffle_xor(v3, 2));
453    F::min(v4, plane_shuffle_xor(v4, 1))
454}
455
456/// Warp-level product reduction using shuffle operations.
457/// All lanes get the product of all 32 values in the warp.
458#[cube]
459pub fn reduce_prod_shuffle<F: Float>(value: F) -> F {
460    let v1 = value * plane_shuffle_xor(value, 16);
461    let v2 = v1 * plane_shuffle_xor(v1, 8);
462    let v3 = v2 * plane_shuffle_xor(v2, 4);
463    let v4 = v3 * plane_shuffle_xor(v3, 2);
464    v4 * plane_shuffle_xor(v4, 1)
465}