1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use crate::BoundChecksInner;
6use crate::args::ReduceArgs;
7use crate::args::TensorArgs;
8use crate::args::init_tensors;
9use crate::instructions::*;
10use crate::precision::ReducePrecision;
11use crate::primitives::*;
12use crate::{LineMode, ReduceConfig, ReduceStrategy};
13
/// Concrete storage types used by a reduction launch.
///
/// These are forwarded to the kernel's `#[define(...)]` parameters to select
/// the element types for the input tensor, the output tensor, and the
/// intermediate accumulator.
#[derive(Clone, Copy, Debug)]
pub struct ReduceDtypes {
    // Element storage type of the input tensor.
    pub input: StorageType,
    // Element storage type of the output tensor.
    pub output: StorageType,
    // Storage type used while accumulating (may be wider than input/output
    // for precision — presumably chosen by the caller; see `ReducePrecision`).
    pub accumulation: StorageType,
}
20
21#[allow(clippy::too_many_arguments)]
25pub(crate) fn launch_reduce<Run: Runtime, Rd: ReduceFamily>(
26 client: &ComputeClient<Run::Server>,
27 input: TensorHandleRef<Run>,
28 output: TensorHandleRef<Run>,
29 axis: u32,
30 config: ReduceConfig,
31 strategy: ReduceStrategy,
32 dtypes: ReduceDtypes,
33 inst: Rd::Config,
34) {
35 let settings = ReduceParams {
36 shared: strategy.shared.then(|| {
37 if strategy.use_planes {
38 config.cube_dim.y
39 } else {
40 config.cube_dim.num_elems()
41 }
42 }),
43 use_planes: strategy.use_planes,
44 line_size_input: config.line_size_input,
45 line_size_output: config.line_size_output,
46 line_mode: config.line_mode,
47 bound_checks: config.bound_checks,
48 bound_checks_inner: config.bound_checks_inner,
49 };
50 unsafe {
51 reduce_kernel::launch_unchecked::<Rd, TensorArgs, Run>(
52 client,
53 config.cube_count,
54 config.cube_dim,
55 input.as_tensor_arg(config.line_size_input as u8),
56 output.as_tensor_arg(config.line_size_output as u8),
57 ScalarArg::new(axis),
58 settings,
59 inst,
60 dtypes.input,
61 dtypes.output,
62 dtypes.accumulation,
63 );
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
68pub struct ReduceParams {
69 pub shared: Option<u32>, pub use_planes: bool,
71 pub line_size_input: u32,
72 pub line_size_output: u32,
73 pub line_mode: LineMode,
74 pub bound_checks: bool,
75 pub bound_checks_inner: BoundChecksInner,
76}
77
/// Kernel entry point: binds the raw launch arguments to virtual tensors and
/// forwards to the generic reduction body.
///
/// The `#[define]` storage-type arguments select the concrete numeric types
/// for `In` (input elements), `Out` (output elements) and `Acc`
/// (accumulation) at launch time.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    // Wrap the argument-dependent tensor representations into VirtualTensor
    // so the body below is independent of `RA`.
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc, R>(&input, &mut output, axis_reduce, params, config);
}
92
/// Reduction body operating on virtual tensors: computes this unit/plane/
/// cube's reduce index, optionally terminates out-of-range workers, then
/// runs the strategy-specific inner reduction.
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    // Which reduction this worker is responsible for (depends on strategy).
    let reduce_index = get_reduce_index(params);

    #[allow(clippy::collapsible_if)]
    if comptime![params.bound_checks] {
        // `output.len()` counts lines, so scale by the output line size to
        // get the scalar element count before converting to a reduce count.
        if reduce_index >= get_reduce_count(output.len() * params.line_size_output, params) {
            terminate!();
        }
    }

    // `(In, Acc)` selects the ReducePrecision pairing of input element type
    // and accumulation type.
    reduce_kernel_inner::<(In, Acc), Out, R>(
        input,
        output,
        axis_reduce,
        reduce_index,
        params,
        config,
    )
}
119
/// Runs one reduction: accumulates over `axis_reduce` with the comptime-
/// selected strategy, then has an elected writer store the result.
///
/// Strategy dispatch (resolved at compile time from `params`):
/// - shared memory (with or without planes): cooperative slice reduction
///   into a shared accumulator, then a tree reduction across its slots;
/// - planes only: plane-cooperative slice reduction;
/// - neither: each unit reduces its slice independently.
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    reduce_index: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    // Index range of input elements this reduction covers.
    let range = ReduceRange::new::<P, Out>(reduce_index, input, output, axis_reduce, params);

    let inst = &R::Instruction::<P>::from_config(config);
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        (Some(accumulator_size), use_planes) => {
            // Cooperative accumulation into shared memory...
            let mut accumulator = reduce_slice_shared::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
                input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            // ...all writes must land before the tree reduction reads them.
            sync_cube();
            reduce_tree::<P, R::Instruction<P>>(inst, &mut accumulator, accumulator_size)
        }
        (None, true) => reduce_slice_plane::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        (None, false) => reduce_slice::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    // Only one worker per reduction writes the final value (see
    // `elected_writer` for which one that is per strategy).
    if elected_writer(params) {
        write_to_output::<P, Out, R::Instruction<P>>(
            output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}
175
/// Index of the reduction this worker participates in, matching the
/// granularity of the comptime strategy:
/// one reduction per cube (shared), per plane (planes), or per unit.
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        // Shared strategy: the whole cube cooperates on one reduction.
        CUBE_POS
    } else if params.use_planes {
        // Plane strategy: one reduction per plane (CUBE_DIM_Y planes per
        // cube, UNIT_POS_Y identifying the plane within the cube).
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        // Unit strategy: every unit owns its own reduction.
        ABSOLUTE_POS
    }
}
186
/// Total number of reductions given the scalar output size.
///
/// In parallel mode each output element is one reduction. In perpendicular
/// mode one reduction produces a whole input-line's worth of outputs, hence
/// the division by `line_size_input` — NOTE(review): input (not output) line
/// size is intentional here, since perpendicular accumulation is done per
/// input line; confirm against `ReduceRange`.
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}
194
/// Whether this unit is the one that writes the reduction result.
///
/// Shared strategy: unit 0 of the cube. Plane strategy: lane 0 of each
/// plane. Unit strategy: every unit writes its own result (`runtime()`
/// forces a runtime `true` rather than a comptime constant).
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        UNIT_POS == 0
    } else if settings.use_planes {
        UNIT_POS_X == 0
    } else {
        true.runtime()
    }
}
205
/// Writes a finished accumulator to the output tensor.
///
/// Parallel mode: the accumulator line is merged into a single scalar and
/// written at `reduce_index`. Perpendicular mode: the accumulator converts to
/// an output line; if the output line size differs from the input line size,
/// the line is split into `line_size_input / line_size_output` chunks written
/// at consecutive indices.
#[cube]
fn write_to_output<P: ReducePrecision, Out: Numeric, R: ReduceInstruction<P>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        LineMode::Parallel => {
            // Collapse the line of partial results into one scalar (e.g. the
            // instruction's final merge; may use shape_axis_reduce for means).
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                // Same vectorization on both sides: write the line directly.
                output.write(reduce_index, out);
            } else {
                // Output lines are narrower than the accumulated input line:
                // emit `num_iters` sub-lines. Assumes line_size_input is a
                // multiple of line_size_output — TODO confirm with config
                // generation.
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    // Each reduce index owns `num_iters` consecutive output
                    // lines.
                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}