// cubecl_reduce/launch.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use crate::BoundChecksInner;
6use crate::args::ReduceArgs;
7use crate::args::TensorArgs;
8use crate::args::init_tensors;
9use crate::instructions::*;
10use crate::precision::ReducePrecision;
11use crate::primitives::*;
12use crate::{LineMode, ReduceConfig, ReduceStrategy};
13
/// Storage types used at each stage of a reduction: reading the input,
/// accumulating partial results, and writing the output.
#[derive(Clone, Copy, Debug)]
pub struct ReduceDtypes {
    /// Storage type of the input tensor elements.
    pub input: StorageType,
    /// Storage type of the output tensor elements.
    pub output: StorageType,
    /// Storage type used for the intermediate accumulator.
    pub accumulation: StorageType,
}
20
21/// Launch a reduce kernel. This function assumes that all parameters are already validated.
22/// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function
23/// with the appropriate assumptions.
24#[allow(clippy::too_many_arguments)]
25pub(crate) fn launch_reduce<Run: Runtime, Rd: ReduceFamily>(
26    client: &ComputeClient<Run::Server>,
27    input: TensorHandleRef<Run>,
28    output: TensorHandleRef<Run>,
29    axis: u32,
30    config: ReduceConfig,
31    strategy: ReduceStrategy,
32    dtypes: ReduceDtypes,
33    inst: Rd::Config,
34) {
35    let settings = ReduceParams {
36        shared: strategy.shared.then(|| {
37            if strategy.use_planes {
38                config.cube_dim.y
39            } else {
40                config.cube_dim.num_elems()
41            }
42        }),
43        use_planes: strategy.use_planes,
44        line_size_input: config.line_size_input,
45        line_size_output: config.line_size_output,
46        line_mode: config.line_mode,
47        bound_checks: config.bound_checks,
48        bound_checks_inner: config.bound_checks_inner,
49    };
50    unsafe {
51        reduce_kernel::launch_unchecked::<Rd, TensorArgs, Run>(
52            client,
53            config.cube_count,
54            config.cube_dim,
55            input.as_tensor_arg(config.line_size_input as u8),
56            output.as_tensor_arg(config.line_size_output as u8),
57            ScalarArg::new(axis),
58            settings,
59            inst,
60            dtypes.input,
61            dtypes.output,
62            dtypes.accumulation,
63        );
64    }
65}
66
/// Comptime settings controlling how the reduce kernel accumulates and writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReduceParams {
    /// Use the shared-memory strategy if `Some(x)`, where `x` is the accumulator size.
    pub shared: Option<u32>,
    /// Whether accumulation uses plane operations.
    pub use_planes: bool,
    /// Line (vectorization) size used when reading the input.
    pub line_size_input: u32,
    /// Line (vectorization) size used when writing the output.
    pub line_size_output: u32,
    /// Whether lines run parallel or perpendicular to the reduced axis.
    pub line_mode: LineMode,
    /// Guard against out-of-range reduce indices at kernel start.
    pub bound_checks: bool,
    /// Bound-check mode used inside the accumulation primitives.
    pub bound_checks_inner: BoundChecksInner,
}
77
/// Kernel entry point: reduce `input` along `axis_reduce` into `output`.
///
/// Wraps the raw tensor arguments into virtual tensors via `init_tensors`
/// and forwards to [`reduce_kernel_virtual`].
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
    // `#[define]` binds each numeric generic to the concrete storage type
    // supplied at launch time (see `ReduceDtypes` in `launch_reduce`).
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc, R>(&input, &mut output, axis_reduce, params, config);
}
92
/// Reduce the virtual tensor `input` along `axis_reduce` into `output`.
///
/// Each reduction is carried out by the worker selected by
/// [`get_reduce_index`] (a cube, a plane, or a single unit depending on
/// `params`). When `bound_checks` is enabled, workers whose index falls
/// outside the number of reductions terminate early.
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let reduce_index = get_reduce_index(params);

    #[allow(clippy::collapsible_if)]
    if comptime![params.bound_checks] {
        // NOTE(review): `output.len()` is scaled by `line_size_output` here,
        // which assumes `len()` counts lines rather than scalar elements —
        // confirm against `VirtualTensor`.
        if reduce_index >= get_reduce_count(output.len() * params.line_size_output, params) {
            terminate!();
        }
    }

    reduce_kernel_inner::<(In, Acc), Out, R>(
        input,
        output,
        axis_reduce,
        reduce_index,
        params,
        config,
    )
}
119
/// Perform a single reduction for the worker identified by `reduce_index`.
///
/// The accumulation strategy is selected at compile time from `params`:
/// - `shared == Some(size)`: accumulate into a shared buffer of `size` items,
///   synchronize the cube, then combine the buffer with a tree reduction;
/// - `shared == None, use_planes == true`: accumulate with plane operations;
/// - `shared == None, use_planes == false`: plain per-unit accumulation.
///
/// Finally, the elected writer stores the accumulated result to `output`.
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    reduce_index: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    // Portion of the input covered by this particular reduction.
    let range = ReduceRange::new::<P, Out>(reduce_index, input, output, axis_reduce, params);

    let inst = &R::Instruction::<P>::from_config(config);
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        (Some(accumulator_size), use_planes) => {
            let mut accumulator = reduce_slice_shared::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
                input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            // All partial results must be visible before the tree reduction.
            sync_cube();
            reduce_tree::<P, R::Instruction<P>>(inst, &mut accumulator, accumulator_size)
        }
        (None, true) => reduce_slice_plane::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        (None, false) => reduce_slice::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    // Only the elected unit writes the final value (see `elected_writer`).
    if elected_writer(params) {
        write_to_output::<P, Out, R::Instruction<P>>(
            output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}
175
/// Index identifying which reduction this unit participates in.
///
/// - shared strategy: one reduction per cube (`CUBE_POS`);
/// - plane strategy: one reduction per plane (cube index scaled by the
///   number of planes, plus the plane position within the cube);
/// - otherwise: one reduction per unit (`ABSOLUTE_POS`).
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        CUBE_POS
    } else if params.use_planes {
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        ABSOLUTE_POS
    }
}
186
/// Number of distinct reductions, derived from the output element count.
///
/// In perpendicular mode each reduction covers `line_size_input` output
/// elements, so the count shrinks by that factor.
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}
194
/// Whether this unit is responsible for writing the final result.
///
/// - shared strategy: only the first unit of the cube writes;
/// - plane strategy: only the first unit of each plane writes;
/// - unit strategy: every unit writes its own result (`true` is lifted to a
///   runtime value so the branch yields a runtime boolean like the others).
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        UNIT_POS == 0
    } else if settings.use_planes {
        UNIT_POS_X == 0
    } else {
        true.runtime()
    }
}
205
/// Write one accumulated result to `output` at `reduce_index`.
///
/// - `Parallel` line mode: the accumulator line is merged into a single value,
///   cast to the output line type, and written once.
/// - `Perpendicular` line mode: the accumulator maps to `line_size_input`
///   output elements. When input and output line sizes match, the line is
///   written as-is; otherwise it is split into
///   `line_size_input / line_size_output` chunks written at consecutive
///   indices.
#[cube]
fn write_to_output<P: ReducePrecision, Out: Numeric, R: ReduceInstruction<P>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        LineMode::Parallel => {
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                output.write(reduce_index, out);
            } else {
                // Repack the wide input-sized line into several narrower
                // output-sized lines.
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    // Each chunk lands at a consecutive output index.
                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}