//! Launch logic for reduce kernels (`cubecl_reduce/launch.rs`).

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::ReadWrite;
4use cubecl_std::tensor::r#virtual::VirtualTensor;
5
6use crate::BoundChecksInner;
7use crate::args::ReduceArgs;
8use crate::args::TensorArgs;
9use crate::args::init_tensors;
10use crate::instructions::*;
11use crate::primitives::*;
12use crate::{LineMode, ReduceConfig, ReduceStrategy};
13
14/// Launch a reduce kernel. This function assumes that all parameters are already validated.
15/// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function
16/// with the appropriate assumptions.
17pub(crate) fn launch_reduce<Run: Runtime, In: Numeric, Out: Numeric, Rd: ReduceFamily>(
18    client: &ComputeClient<Run::Server, Run::Channel>,
19    input: TensorHandleRef<Run>,
20    output: TensorHandleRef<Run>,
21    axis: u32,
22    config: ReduceConfig,
23    strategy: ReduceStrategy,
24    inst: Rd::Config,
25) {
26    let settings = ReduceParams {
27        shared: strategy.shared.then(|| {
28            if strategy.use_planes {
29                config.cube_dim.y
30            } else {
31                config.cube_dim.num_elems()
32            }
33        }),
34        use_planes: strategy.use_planes,
35        line_size_input: config.line_size_input,
36        line_size_output: config.line_size_output,
37        line_mode: config.line_mode,
38        bound_checks: config.bound_checks,
39        bound_checks_inner: config.bound_checks_inner,
40    };
41    unsafe {
42        reduce_kernel::launch_unchecked::<In, Out, Rd, TensorArgs, Run>(
43            client,
44            config.cube_count,
45            config.cube_dim,
46            input.as_tensor_arg(config.line_size_input as u8),
47            output.as_tensor_arg(config.line_size_output as u8),
48            ScalarArg::new(axis),
49            settings,
50            inst,
51        );
52    }
53}
54
/// Comptime settings that control how [`reduce_kernel`] executes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReduceParams {
    pub shared: Option<u32>, // shared if Some(x) where x is the accumulator size.
    /// Whether plane (warp/subgroup) cooperation is used within each reduction.
    pub use_planes: bool,
    /// Vectorization (line) width of the input tensor.
    pub line_size_input: u32,
    /// Vectorization (line) width of the output tensor.
    pub line_size_output: u32,
    /// Orientation of input lines relative to the reduced axis
    /// (parallel vs. perpendicular — see `LineMode`).
    pub line_mode: LineMode,
    /// When true, units whose reduce index exceeds the reduce count terminate early.
    pub bound_checks: bool,
    /// Bound-check strategy forwarded to the inner reduction primitives.
    pub bound_checks_inner: BoundChecksInner,
}
65
/// Reduce kernel entrypoint.
///
/// Each reduction is identified by a `reduce_index` — a whole cube, a plane, or a
/// single unit depending on `params` (see [`get_reduce_index`]). The input slice
/// covered by that index is reduced along `axis_reduce` and the accumulated result
/// is written to `output` by an elected writer (see [`elected_writer`]).
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    let reduce_index = get_reduce_index(params);

    // Terminate units whose reduce index falls outside the output when the
    // launch grid overshoots the number of reductions.
    if comptime![params.bound_checks]
        && reduce_index >= get_reduce_count(output.len() * params.line_size_output, params)
    {
        terminate!();
    }

    let range = ReduceRange::new::<In, Out>(reduce_index, &input, &mut output, axis_reduce, params);

    let inst = &R::Instruction::<In>::from_config(config);
    // Dispatch on the (shared memory, planes) strategy; the tuple is resolved at
    // compile time, so only one arm is actually generated.
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        // Shared-memory strategy: partial results are accumulated into a shared
        // buffer of `accumulator_size` items, then folded with a tree reduction.
        (Some(accumulator_size), use_planes) => {
            let mut accumulator = reduce_slice_shared::<In, VirtualTensor<In>, R::Instruction<In>>(
                &input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            // All partial accumulators must be visible before the tree fold.
            sync_units();
            reduce_tree::<In, R::Instruction<In>>(inst, &mut accumulator, accumulator_size)
        }
        // Plane-only strategy: units cooperate within a plane, no shared memory.
        (None, true) => reduce_slice_plane::<In, VirtualTensor<In>, R::Instruction<In>>(
            &input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        // Fully independent strategy: each unit reduces its whole slice alone.
        (None, false) => reduce_slice::<In, VirtualTensor<In>, R::Instruction<In>>(
            &input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    // Only the elected unit of each reduction writes the final value
    // (every unit, when reductions are fully independent).
    if elected_writer(params) {
        write_to_output::<In, Out, R::Instruction<In>>(
            &mut output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}
129
/// Index of the reduction this unit participates in, depending on the strategy:
/// - shared memory: the whole cube cooperates on one reduction (`CUBE_POS`),
/// - planes: each plane (row `UNIT_POS_Y` of a cube) owns one reduction,
/// - otherwise: each unit owns its own reduction (`ABSOLUTE_POS`).
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        CUBE_POS
    } else if params.use_planes {
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        ABSOLUTE_POS
    }
}
140
/// Number of independent reductions to perform, given the output size in scalar
/// elements (callers pass `output.len() * line_size_output`).
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        // One reduction per output element.
        LineMode::Parallel => output_size,
        // In perpendicular mode each reduction yields a full input line, i.e.
        // `line_size_input` output elements at once (see `write_to_output`).
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}
148
/// Whether this unit is responsible for writing the final result of its reduction.
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        // Shared-memory strategy: a single unit per cube writes.
        UNIT_POS == 0
    } else if settings.use_planes {
        // Plane strategy: the first unit of each plane writes.
        UNIT_POS_X == 0
    } else {
        // Independent reductions: every unit writes its own result.
        // `.runtime()` turns the comptime literal into a runtime value.
        true.runtime()
    }
}
159
/// Write the final accumulator of the reduction `reduce_index` into `output`.
#[cube]
fn write_to_output<In: Numeric, Out: Numeric, R: ReduceInstruction<In>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        // Parallel mode: the accumulator line folds into a single scalar result
        // stored at `reduce_index`.
        LineMode::Parallel => {
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                // Line sizes match: store the whole line directly.
                output.write(reduce_index, out);
            } else {
                // Output lines are narrower than input lines: split the result
                // into `line_size_input / line_size_output` consecutive output
                // lines. Assumes the output line size divides the input line
                // size evenly — TODO confirm this is validated upstream.
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    // Copy the i-th chunk of the wide line into a narrow line.
                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}