//! cubecl_reduce/launch.rs — kernel launch entry point for the reduce kernels.
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::ReadWrite;
use cubecl_std::tensor::r#virtual::VirtualTensor;

use crate::BoundChecksInner;
use crate::args::ReduceArgs;
use crate::args::TensorArgs;
use crate::args::init_tensors;
use crate::instructions::*;
use crate::precision::ReducePrecision;
use crate::primitives::*;
use crate::{LineMode, ReduceConfig, ReduceStrategy};
14
15/// Launch a reduce kernel. This function assumes that all parameters are already validated.
16/// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function
17/// with the appropriate assumptions.
18pub(crate) fn launch_reduce<Run: Runtime, P: ReducePrecision, Out: Numeric, Rd: ReduceFamily>(
19    client: &ComputeClient<Run::Server, Run::Channel>,
20    input: TensorHandleRef<Run>,
21    output: TensorHandleRef<Run>,
22    axis: u32,
23    config: ReduceConfig,
24    strategy: ReduceStrategy,
25    inst: Rd::Config,
26) {
27    let settings = ReduceParams {
28        shared: strategy.shared.then(|| {
29            if strategy.use_planes {
30                config.cube_dim.y
31            } else {
32                config.cube_dim.num_elems()
33            }
34        }),
35        use_planes: strategy.use_planes,
36        line_size_input: config.line_size_input,
37        line_size_output: config.line_size_output,
38        line_mode: config.line_mode,
39        bound_checks: config.bound_checks,
40        bound_checks_inner: config.bound_checks_inner,
41    };
42    unsafe {
43        reduce_kernel::launch_unchecked::<P::EI, Out, P::EA, Rd, TensorArgs, Run>(
44            client,
45            config.cube_count,
46            config.cube_dim,
47            input.as_tensor_arg(config.line_size_input as u8),
48            output.as_tensor_arg(config.line_size_output as u8),
49            ScalarArg::new(axis),
50            settings,
51            inst,
52        );
53    }
54}
55
/// Comptime parameters that specialize [`reduce_kernel`] for a given launch.
///
/// Derives `Hash`/`Eq` so distinct parameter sets compile to distinct kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReduceParams {
    /// Shared-memory strategy: `Some(x)` where `x` is the accumulator size, `None` to disable.
    pub shared: Option<u32>, // shared if Some(x) where x is the accumulator size.
    /// Use plane (subgroup) primitives for the reduction when `true`.
    pub use_planes: bool,
    /// Line (vectorization) size of the input tensor.
    pub line_size_input: u32,
    /// Line (vectorization) size of the output tensor.
    pub line_size_output: u32,
    /// Orientation of lines relative to the reduced axis (parallel vs perpendicular).
    pub line_mode: LineMode,
    /// Terminate units whose reduce index falls past the number of reductions.
    pub bound_checks: bool,
    /// Bound-check mode used by the inner reduction primitives.
    pub bound_checks_inner: BoundChecksInner,
}
66
/// Kernel entry point: adapts the raw argument types (`RA`) into virtual
/// tensors and delegates the actual work to [`reduce_kernel_virtual`].
///
/// `In`/`Out` are the input/output element types, `Acc` the accumulation
/// element type, `R` the reduce instruction family.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc, R>(&input, &mut output, axis_reduce, params, config);
}
78
/// Reduce `input` along `axis_reduce` into `output`, both given as virtual tensors.
///
/// Each unit/plane/cube (depending on the strategy, see [`get_reduce_index`])
/// owns one reduction. With `params.bound_checks` enabled, owners whose index
/// exceeds the total number of reductions terminate immediately.
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let reduce_index = get_reduce_index(params);

    // `output.len()` counts lines, so multiply by the output line size to get
    // the scalar element count before deriving the reduction count.
    // NOTE(review): assumes `VirtualTensor::len` is in lines — confirm upstream.
    if comptime![params.bound_checks]
        && reduce_index >= get_reduce_count(output.len() * params.line_size_output, params)
    {
        terminate!();
    }

    // `(In, Acc)` pairs the storage type with the accumulation type to form
    // the `ReducePrecision` expected by the inner kernel.
    reduce_kernel_inner::<(In, Acc), Out, R>(
        input,
        output,
        axis_reduce,
        reduce_index,
        params,
        config,
    )
}
104
/// Core of the reduction: select the primitive matching the comptime strategy,
/// reduce this owner's input range into an accumulator, then let the elected
/// writer emit the final value.
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    reduce_index: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    // Input range covered by this particular reduction.
    let range = ReduceRange::new::<P, Out>(reduce_index, input, output, axis_reduce, params);

    let inst = &R::Instruction::<P>::from_config(config);
    // The (shared, use_planes) pair is comptime, so exactly one branch is compiled.
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        // Shared-memory strategy (optionally plane-assisted): every unit reduces
        // into a shared accumulator, which is then collapsed with a tree pass.
        (Some(accumulator_size), use_planes) => {
            let mut accumulator = reduce_slice_shared::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
                input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            // All units must have written the shared accumulator before the tree pass reads it.
            sync_cube();
            reduce_tree::<P, R::Instruction<P>>(inst, &mut accumulator, accumulator_size)
        }
        // Plane strategy without shared memory: each plane reduces cooperatively.
        (None, true) => reduce_slice_plane::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        // Unit strategy: each unit reduces its own range independently.
        (None, false) => reduce_slice::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    // Only one writer per reduction stores the result (see `elected_writer`).
    if elected_writer(params) {
        write_to_output::<P, Out, R::Instruction<P>>(
            output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}
160
/// Index of the reduction owned by the current execution unit.
///
/// Ownership granularity follows the strategy: one reduction per cube for the
/// shared strategy, per plane (y-row of the cube) for the plane strategy, and
/// per unit otherwise.
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        CUBE_POS
    } else if params.use_planes {
        // Global plane index: planes per cube is CUBE_DIM_Y, plane-in-cube is UNIT_POS_Y.
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        ABSOLUTE_POS
    }
}
171
/// Total number of independent reductions for a given output size (in scalar
/// elements).
///
/// In perpendicular mode one reduction produces a whole input-sized line, so
/// the count shrinks by `line_size_input`.
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}
179
/// Whether the current unit is responsible for writing its reduction's result.
///
/// Shared strategy: the first unit of the cube. Plane strategy: the first unit
/// of each plane (x == 0). Unit strategy: every unit writes its own result.
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        UNIT_POS == 0
    } else if settings.use_planes {
        UNIT_POS_X == 0
    } else {
        // `.runtime()` lifts the comptime `true` into a runtime value so all
        // branches of the generated code have the same (runtime bool) type.
        true.runtime()
    }
}
190
/// Store the accumulated result for `reduce_index` into `output`.
///
/// Parallel mode folds the accumulator line into one scalar; perpendicular
/// mode writes the accumulator line out, splitting it when the output line
/// size is smaller than the input line size.
#[cube]
fn write_to_output<P: ReducePrecision, Out: Numeric, R: ReduceInstruction<P>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        // Parallel: collapse the whole line into a single output value.
        LineMode::Parallel => {
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        // Perpendicular: the accumulator maps element-wise onto output lines.
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                // Same line sizes: a single full-line write suffices.
                output.write(reduce_index, out);
            } else {
                // Output lines are narrower: split `out` into `num_iters`
                // chunks of `line_size_output` elements each.
                // NOTE(review): assumes line_size_input is a multiple of
                // line_size_output — presumably validated before launch.
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    // Copy one output-line worth of elements from the result line.
                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    // Consecutive chunks land at consecutive output line indices.
                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}