use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;

use crate::BoundChecksInner;
use crate::args::ReduceArgs;
use crate::args::TensorArgs;
use crate::args::init_tensors;
use crate::instructions::*;
use crate::precision::ReducePrecision;
use crate::primitives::*;
use crate::{LineMode, ReduceConfig, ReduceStrategy};

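/// Launch the reduction kernel over `axis` of `input`, writing the result into `output`.
///
/// The runtime [`ReduceStrategy`] and [`ReduceConfig`] are translated into comptime
/// [`ReduceParams`] here, so the generated kernel is specialized for the chosen strategy
/// (shared memory, plane operations, line sizes, and bound checks).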
pub(crate) fn launch_reduce<Run: Runtime, P: ReducePrecision, Out: Numeric, Rd: ReduceFamily>(
    client: &ComputeClient<Run::Server>,
    input: TensorHandleRef<Run>,
    output: TensorHandleRef<Run>,
    axis: u32,
    config: ReduceConfig,
    strategy: ReduceStrategy,
    inst: Rd::Config,
) {
    let settings = ReduceParams {
        shared: strategy.shared.then(|| {
            if strategy.use_planes {
                config.cube_dim.y
            } else {
                config.cube_dim.num_elems()
            }
        }),
        use_planes: strategy.use_planes,
        line_size_input: config.line_size_input,
        line_size_output: config.line_size_output,
        line_mode: config.line_mode,
        bound_checks: config.bound_checks,
        bound_checks_inner: config.bound_checks_inner,
    };
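    // Safety: `launch_unchecked` skips launch-time validation; the `ReduceConfig` built by this
    // crate is assumed to provide a cube count/dim and line sizes consistent with the handles.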
    unsafe {
        reduce_kernel::launch_unchecked::<P::EI, Out, P::EA, Rd, TensorArgs, Run>(
            client,
            config.cube_count,
            config.cube_dim,
            input.as_tensor_arg(config.line_size_input as u8),
            output.as_tensor_arg(config.line_size_output as u8),
            ScalarArg::new(axis),
            settings,
            inst,
        );
    }
}

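/// Comptime parameters controlling how [`reduce_kernel`] is specialized.
///
/// When the shared-memory strategy is selected, `shared` holds the shared accumulator size
/// (`cube_dim.y` with plane operations, otherwise the number of units in the cube); it is `None`
/// for the plane-only and unit-only strategies.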
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReduceParams {
    pub shared: Option<u32>,
    pub use_planes: bool,
    pub line_size_input: u32,
    pub line_size_output: u32,
    pub line_mode: LineMode,
    pub bound_checks: bool,
    pub bound_checks_inner: BoundChecksInner,
}

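/// Entry point of the reduce kernel: materializes the argument tensors as virtual tensors and
/// forwards to [`reduce_kernel_virtual`].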
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc, R>(&input, &mut output, axis_reduce, params, config);
}

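/// Reduce `input` along `axis_reduce` into `output` using virtual tensors.
///
/// Each reduction is owned by a cube, a plane, or a unit depending on `params`; when bound checks
/// are enabled, indices past the number of reductions terminate early.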
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let reduce_index = get_reduce_index(params);

    #[allow(clippy::collapsible_if)]
    if comptime![params.bound_checks] {
        if reduce_index >= get_reduce_count(output.len() * params.line_size_output, params) {
            terminate!();
        }
    }

    reduce_kernel_inner::<(In, Acc), Out, R>(
        input,
        output,
        axis_reduce,
        reduce_index,
        params,
        config,
    )
}

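/// Run one reduction: accumulate over the slice described by `ReduceRange` using the configured
/// strategy, then let the elected writer store the result.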
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    reduce_index: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let range = ReduceRange::new::<P, Out>(reduce_index, input, output, axis_reduce, params);

    let inst = &R::Instruction::<P>::from_config(config);
    let accumulator = match comptime!((params.shared, params.use_planes)) {
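        // Shared-memory strategy: units (or planes) accumulate into shared memory, then the
        // partial accumulators are combined with a tree reduction after a cube-wide sync.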
        (Some(accumulator_size), use_planes) => {
            let mut accumulator = reduce_slice_shared::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
                input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            sync_cube();
            reduce_tree::<P, R::Instruction<P>>(inst, &mut accumulator, accumulator_size)
        }
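        // Plane strategy without shared memory: each plane cooperatively reduces its slice.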
        (None, true) => reduce_slice_plane::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
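        // Unit strategy: each unit reduces its own slice independently.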
        (None, false) => reduce_slice::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    if elected_writer(params) {
        write_to_output::<P, Out, R::Instruction<P>>(
            output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}

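/// Index of the reduction handled by the current unit: `CUBE_POS` for the shared strategy, a
/// plane-level index (`CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y`) with plane operations, and
/// `ABSOLUTE_POS` otherwise.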
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        CUBE_POS
    } else if params.use_planes {
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        ABSOLUTE_POS
    }
}

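/// Number of reductions to perform, derived from the number of output elements.
///
/// With perpendicular lines, each reduction covers a full input line (`line_size_input` output
/// values), so e.g. 256 output elements with an input line size of 4 give 64 reductions.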
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}

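/// Whether the current unit is responsible for writing the result: the first unit of the cube for
/// the shared strategy, the first unit of each plane with plane operations, and every unit
/// otherwise.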
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        UNIT_POS == 0
    } else if settings.use_planes {
        UNIT_POS_X == 0
    } else {
        true.runtime()
    }
}

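/// Write the accumulated result of one reduction to `output`, merging or repacking lines
/// according to the comptime line mode and line sizes.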
#[cube]
fn write_to_output<P: ReducePrecision, Out: Numeric, R: ReduceInstruction<P>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        LineMode::Parallel => {
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                output.write(reduce_index, out);
            } else {
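                // The input line is wider than the output line: split the accumulated line into
                // `line_size_input / line_size_output` output lines written at consecutive
                // indices.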
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}