1use super::args::{
2 FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
3};
4#[cfg(feature = "autotune")]
5use super::tune::fused_reduce_autotune;
6use crate::{
7 CubeFusionHandle, FallbackOperation,
8 engine::{
9 codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout},
10 launch::{
11 FuseTraceLauncher,
12 runner::{TraceRunner, Vectorization},
13 },
14 trace::{FuseTrace, TraceError, TuneOutput},
15 },
16 optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
17};
18use burn_fusion::stream::Context;
19use burn_ir::ReduceDimOpIr;
20use burn_std::DType;
21use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
22use cubek::reduce::{
23 LineMode, ReduceDtypes, ReduceError,
24 components::instructions::ReduceOperationConfig,
25 init_tensors,
26 launch::{RoutineStrategy, reduce_kernel_virtual},
27 routines::{
28 ReduceBlueprint, ReduceLaunchSettings, ReduceLineSettings, ReduceProblem, Routine,
29 cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
30 },
31};
32use serde::{Deserialize, Serialize};
33use std::sync::Arc;
34
35#[cfg(not(feature = "autotune"))]
36use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};
37
/// Fused reduce optimization.
///
/// Wraps the shared optimization state in an [`Arc`] so it can be handed to
/// the autotune argument (`ReduceOptimizationTuneArg`) without cloning the
/// traces.
pub struct ReduceOptimization<R: Runtime> {
    info: Arc<ReduceOptimizationInfo<R>>,
}
41
/// Shared state backing a [`ReduceOptimization`].
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
    /// Trace of the fully fused (elemwise + reduce) kernel.
    pub(crate) trace: FuseTrace,
    // Trace launched *before* the fallback reduce in `execute_fallback`.
    trace_read_fallback: FuseTrace,
    // Trace launched *after* the fallback reduce in `execute_fallback`.
    trace_write_fallback: FuseTrace,
    pub(crate) client: ComputeClient<R>,
    pub(crate) device: R::Device,
    /// Total number of fused operations (exposed via `num_ops_fused`).
    pub(crate) len: usize,
    /// Count forwarded to the fallback builder in `ReduceOptimization::execute`
    /// — presumably the number of operations fused before the reduce; confirm
    /// against the caller.
    pub(crate) len_read: usize,
    /// Descriptor of the fused reduce operation itself.
    pub(crate) reduce: FusedReduce,
}
52
/// Argument handed to the autotuner: the shared optimization state plus the
/// concrete fallback operation to run when the fused kernel is not used.
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
    pub(crate) fallback: Box<dyn FallbackOperation<R>>,
}
57
/// The reduction instruction performed by a fused reduce.
///
/// Each variant maps one-to-one onto a [`ReduceOperationConfig`] in
/// `launch_reduce_mixed_precision`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ReduceInstruction {
    /// Index of the maximum element along the reduced axis.
    ArgMax,
    /// Index of the minimum element along the reduced axis.
    ArgMin,
    /// Arithmetic mean along the reduced axis.
    Mean,
    /// Product of the elements along the reduced axis.
    Prod,
    /// Sum of the elements along the reduced axis.
    Sum,
    /// Maximum element along the reduced axis.
    Max,
    /// Minimum element along the reduced axis.
    Min,
    /// Maximum absolute value along the reduced axis.
    MaxAbs,
}
69
/// Hook for running a non-fused reduce implementation inside a fusion
/// execution context.
// NOTE(review): not referenced elsewhere in this file — the fallback path here
// uses `FallbackOperation`; presumably implemented/consumed by other modules.
pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
    /// Execute the fallback against the given fusion context.
    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
}
73
/// Serializable snapshot of a [`ReduceOptimization`].
///
/// Produced by `ReduceOptimization::to_state` and consumed by
/// `ReduceOptimization::from_state`; the client and device are re-resolved on
/// restore and therefore not stored here.
#[derive(Serialize, Deserialize)]
pub struct ReduceOptimizationState {
    trace: FuseTrace,
    trace_read_fallback: FuseTrace,
    trace_write_fallback: FuseTrace,
    pub(crate) reduce: FusedReduce,
    len: usize,
    len_read: usize,
}
83
84impl core::fmt::Debug for ReduceOptimizationState {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.write_fmt(format_args!(
87 "{{ len_read: {}, len_total: {} }}",
88 self.len_read, self.len
89 ))
90 }
91}
92
/// Description of a reduce operation folded into a fusion trace.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusedReduce {
    // Fuse-graph argument read by the reduction.
    pub(crate) input: FuseArg,
    // Fuse-graph argument the reduction writes.
    pub(crate) output: FuseArg,
    // Accumulation precision used by the kernel (see `ReduceDtypes` in
    // `TraceRunner::run`).
    pub(crate) acc: FuseType,
    // Axis along which the reduction runs.
    pub(crate) axis: usize,
    // Tensor-level IR of the reduce-dim operation.
    pub(crate) op: ReduceDimOpIr,
    // NOTE(review): `use_planes` and `shared` are not read in this file —
    // presumably consumed by selection/autotune code; confirm before touching.
    pub(crate) use_planes: bool,
    pub(crate) shared: bool,
    // Which reduction instruction to apply (sum, mean, argmax, ...).
    pub(crate) inst: ReduceInstruction,
}
104
/// Trace runner pairing a [`FusedReduce`] descriptor with the routine
/// strategy chosen for this particular launch.
///
/// `#[derive(new)]` provides the `FusedReduceLaunch::new` constructor used in
/// `ReduceOptimizationTuneArg::execute_fused`.
#[derive(new)]
pub struct FusedReduceLaunch<'a> {
    reduce: &'a FusedReduce,
    strategy: RoutineStrategy,
}
110
/// Errors that can abort a fused reduce launch.
#[derive(Debug)]
pub enum FusedReduceError {
    /// Error bubbled up from the underlying `cubek` reduce routines.
    Reduce(ReduceError),
    /// A routine/strategy selection was rejected, with a static reason.
    // NOTE(review): `Box<&'static str>` boxes a reference — a plain
    // `&'static str` would be smaller, but changing the payload type would
    // break constructors elsewhere, so it is left as-is.
    InvalidSelection(Box<&'static str>),
    /// The fused reduce was given inputs it cannot handle.
    InvalidInput,
}
117
118impl From<ReduceError> for FusedReduceError {
119 fn from(value: ReduceError) -> Self {
120 Self::Reduce(value)
121 }
122}
123
impl<R: Runtime> ReduceOptimizationTuneArg<R> {
    /// Run the fully fused kernel with the given routine strategy.
    ///
    /// # Errors
    /// Returns a [`TraceError`] when the trace or the reduce launch fails, so
    /// the caller can fall back to [`Self::execute_fallback`].
    pub fn execute_fused<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        strategy: RoutineStrategy,
    ) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
        let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
        launcher.launch::<BT>(&self.info.client, &self.info.device, context)
    }

    /// Run the non-fused path: the read-fallback trace, then the fallback
    /// reduce, then the write-fallback trace; the two tune outputs are merged.
    ///
    /// # Panics
    /// Panics if either fallback trace fails to launch (`unwrap`) — there is
    /// no further fallback at that point.
    pub fn execute_fallback<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
    ) -> TuneOutput<R> {
        // 1) Launch the trace that runs before the fallback reduce.
        let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);

        // `mut` is only needed when the `autotune-checks` block below is compiled in.
        #[allow(unused_mut)] let mut output_read = launcher
            .launch::<BT>(&self.info.client, &self.info.device, context)
            .unwrap();

        // 2) Run the non-fused reduce itself.
        self.fallback.run(context);

        // When autotune checks are enabled, record the reduce output handle so
        // it can be compared against the fused kernel's result.
        #[cfg(feature = "autotune-checks")]
        if let TuneOutput::Checked { handles } = &mut output_read {
            let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
            let handle_out = context
                .handles
                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);

            handles.insert(
                self.info.reduce.op.out.id,
                (out_desc.shape.dims.clone(), handle_out.clone()),
            );
        }

        // 3) Launch the trace that runs after the fallback reduce.
        let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);

        let output_write = launcher
            .launch::<BT>(&self.info.client, &self.info.device, context)
            .unwrap();

        output_read.merge(output_write)
    }
}
170
#[allow(clippy::too_many_arguments)]
impl<R: Runtime> ReduceOptimization<R> {
    /// Build a new optimization from the fused trace, the two fallback traces
    /// and the fused reduce descriptor.
    ///
    /// All state is stored behind an [`Arc`] so it can be shared with the
    /// autotune argument without cloning the traces.
    pub fn new(
        trace: FuseTrace,
        trace_read_fallback: FuseTrace,
        trace_write_fallback: FuseTrace,
        client: ComputeClient<R>,
        device: R::Device,
        len: usize,
        len_read: usize,
        reduce: FusedReduce,
    ) -> Self {
        let info = ReduceOptimizationInfo {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            client,
            device,
            len,
            len_read,
            reduce,
        };

        Self {
            info: Arc::new(info),
        }
    }
    /// Execute the optimization inside `context`.
    ///
    /// `fallback` receives `len_read` and builds the non-fused reduce
    /// operation used when the fused kernel cannot run (or loses at autotune).
    ///
    /// With the `autotune` feature enabled, strategy selection is delegated to
    /// `fused_reduce_autotune`; otherwise a unit-routine strategy is attempted
    /// and the fallback path runs if the fused launch fails.
    pub fn execute<BT: CubeElement>(
        &mut self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
    ) {
        let fallback = fallback(self.info.len_read);

        let arg = ReduceOptimizationTuneArg {
            info: self.info.clone(),
            fallback,
        };

        #[cfg(feature = "autotune")]
        fused_reduce_autotune::<R, BT>(arg, context);

        #[cfg(not(feature = "autotune"))]
        if arg
            .execute_fused::<BT>(
                context,
                RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
            )
            .is_err()
        {
            arg.execute_fallback::<BT>(context);
        }
    }

    /// Number of output buffers declared by the read-fallback trace.
    pub fn num_output_buffers(&self) -> usize {
        self.info.trace_read_fallback.resources.outputs.len()
    }

    /// Snapshot the optimization into a serializable state (client and device
    /// are intentionally dropped; see [`Self::from_state`]).
    pub fn to_state(&self) -> ReduceOptimizationState {
        ReduceOptimizationState {
            trace: self.info.trace.clone(),
            trace_read_fallback: self.info.trace_read_fallback.clone(),
            trace_write_fallback: self.info.trace_write_fallback.clone(),
            reduce: self.info.reduce.clone(),
            len: self.info.len,
            len_read: self.info.len_read,
        }
    }

    /// Rebuild an optimization from a deserialized state, fetching a fresh
    /// compute client for `device`.
    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
        let client = R::client(device);

        let info = ReduceOptimizationInfo {
            trace: state.trace,
            trace_read_fallback: state.trace_read_fallback,
            trace_write_fallback: state.trace_write_fallback,
            reduce: state.reduce,
            len: state.len,
            len_read: state.len_read,
            client,
            device: device.clone(),
        };

        Self {
            info: Arc::new(info),
        }
    }

    /// Total number of fused operations covered by this optimization.
    pub fn num_ops_fused(&self) -> usize {
        self.info.len
    }
}
265
// Relies entirely on the default `Vectorization` behavior — no overrides.
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}
268
impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
    type Error = FusedReduceError;

    /// Run the fused reduce: derive the reduce problem from the fuse block
    /// configs, prepare a routine blueprint for the selected strategy, then
    /// launch the kernel.
    ///
    /// NOTE(review): indexes `configs[0]`/`configs[1]`, so at least two fuse
    /// blocks (read then write) are expected — panics otherwise.
    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<'a, R>,
        outputs: GlobalArgsLaunch<'a, R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), FusedReduceError> {
        // First block describes the read side, second the write side.
        let [config_read, config_write] = [&configs[0], &configs[1]];
        // Reference shape comes from the outputs when the ref layout is a
        // concrete output argument, otherwise from the inputs.
        let shape = match &config_read.ref_layout {
            RefLayout::Concrete(FuseArg::Output(..)) => {
                outputs.shape_ref(&config_read.ref_layout, config_read.rank)
            }
            _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
        };
        // Number of independent reductions: product of all non-reduced dims.
        let reduce_count: usize = shape
            .iter()
            .enumerate()
            .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
            .product();

        // Reducing the last axis uses parallel line mode; any other axis
        // reduces perpendicular to the lines.
        let line_mode = match self.reduce.axis == config_read.rank - 1 {
            true => LineMode::Parallel,
            false => LineMode::Perpendicular,
        };

        let settings = ReduceLineSettings {
            line_mode,
            line_size_input: config_read.width,
            line_size_output: config_write.width,
        };
        let problem = ReduceProblem {
            vector_size: shape[self.reduce.axis],
            vector_count: reduce_count,
            axis: self.reduce.axis,
            dtypes: ReduceDtypes {
                input: self.reduce.op.input.dtype.into(),
                output: self.reduce.op.out.dtype.into(),
                accumulation: self.reduce.acc.into_elem().into(),
            },
        };

        // Let the routine matching the selected strategy turn the problem and
        // line settings into a concrete blueprint plus launch settings.
        let (blueprint, settings) = match self.strategy.clone() {
            RoutineStrategy::Unit(strategy) => {
                let routine = UnitRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Plane(strategy) => {
                let routine = PlaneRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Cube(strategy) => {
                let routine = CubeRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
        };

        // Bundle everything needed by the launch helpers.
        let kwargs = ReduceKwArgs {
            client,
            inputs,
            outputs,
            axis: self.reduce.axis,
            config_fuse_read: config_read.clone(),
            config_fuse_write: config_write.clone(),
            input: self.reduce.input.clone(),
            output: self.reduce.output.clone(),
            blueprint,
            settings,
        };
        let result = launch_reduce_mixed_precision(
            kwargs,
            self.reduce.inst,
            self.reduce.op.input.dtype,
            self.reduce.op.out.dtype,
            DType::from(self.reduce.acc.into_elem()),
        );

        match result {
            Ok(_) => Ok(()),
            Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),
        }
    }
}
354
/// Bundled arguments for the final kernel launch, grouped to keep the
/// `launch_reduce*` signatures manageable.
struct ReduceKwArgs<'a, 'b, Run: Runtime> {
    client: &'b ComputeClient<Run>,
    inputs: GlobalArgsLaunch<'a, Run>,
    outputs: GlobalArgsLaunch<'a, Run>,
    // Axis being reduced.
    axis: usize,
    // Routine blueprint produced by `Routine::prepare`.
    blueprint: ReduceBlueprint,
    // Launch settings (cube count / cube dim) produced by `Routine::prepare`.
    settings: ReduceLaunchSettings,
    // Fuse block configs for the read and write sides of the trace.
    config_fuse_read: FuseBlockConfig,
    config_fuse_write: FuseBlockConfig,
    // Fuse-graph args identifying the reduce input and output tensors.
    input: FuseArg,
    output: FuseArg,
}
367
368fn launch_reduce_mixed_precision<Run: Runtime>(
369 kwargs: ReduceKwArgs<'_, '_, Run>,
370 instruction: ReduceInstruction,
371 dtype_input: DType,
372 dtype_output: DType,
373 dtype_acc: DType,
374) -> Result<(), LaunchError> {
375 let config = match instruction {
376 ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
377 ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
378 ReduceInstruction::Prod => ReduceOperationConfig::Prod,
379 ReduceInstruction::Mean => ReduceOperationConfig::Mean,
380 ReduceInstruction::Sum => ReduceOperationConfig::Sum,
381 ReduceInstruction::Max => ReduceOperationConfig::Max,
382 ReduceInstruction::Min => ReduceOperationConfig::Min,
383 ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
384 };
385 launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
386}
387
/// Launch the fused reduce kernel with the prepared blueprint, launch
/// settings, and the dtypes resolved at runtime via `#[define(...)]`.
fn launch_reduce<Run: Runtime>(
    kwargs: ReduceKwArgs<'_, '_, Run>,
    inst: ReduceOperationConfig,
    dtype_input: DType,
    dtype_output: DType,
    dtype_acc: DType,
) -> Result<(), LaunchError> {
    // SAFETY (review note): `launch_unchecked` skips launch-time validation;
    // this relies on `Routine::prepare` (see `TraceRunner::run`) having
    // produced a blueprint and cube count/dim consistent with the tensor
    // arguments — TODO confirm that contract in the cubek routines.
    unsafe {
        reduce_kernel::launch_unchecked::<Run>(
            kwargs.client,
            kwargs.settings.cube_count,
            kwargs.settings.cube_dim,
            FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
            FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
            ScalarArg::new(kwargs.axis),
            kwargs.blueprint,
            inst,
            dtype_input.into(),
            dtype_output.into(),
            dtype_acc.into(),
        )
    }
}
411
/// Fused reduce kernel (cubecl DSL).
///
/// Initializes the fused input/output tensors through the `FusedReduceArgs`
/// bindings, then delegates the actual reduction to `reduce_kernel_virtual`.
/// The `In`/`Out`/`Acc` numeric types are bound at launch time via the
/// `#[define(...)]` dtype parameters.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric>(
    input: &FusedReduceInput,
    output: &mut FusedReduceOutput,
    axis_reduce: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(input, output);

    reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
}