1use super::args::{
2 FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
3};
4#[cfg(feature = "autotune")]
5use super::tune::fused_reduce_autotune;
6use crate::{
7 CubeFusionHandle, FallbackOperation,
8 engine::{
9 codegen::ir::{
10 FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout,
11 multi_block_variables_init,
12 },
13 launch::{
14 FuseTraceLauncher,
15 runner::{TraceRunner, Vectorization},
16 },
17 trace::{FuseTrace, TraceError, TuneOutput},
18 },
19 optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
20};
21use burn_fusion::stream::Context;
22use burn_ir::ReduceDimOpIr;
23use burn_std::DType;
24use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
25use cubek::reduce::{
26 ReduceDtypes, ReduceError, VectorizationMode,
27 components::instructions::ReduceOperationConfig,
28 init_tensors,
29 launch::{RoutineStrategy, reduce_kernel_virtual},
30 output_vectorization_axis,
31 routines::{
32 ReduceBlueprint, ReduceLaunchSettings, ReduceProblem, ReduceVectorSettings, Routine,
33 cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
34 },
35};
36use serde::{Deserialize, Serialize};
37use std::sync::Arc;
38
39#[cfg(not(feature = "autotune"))]
40use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};
41
/// Fused reduce optimization.
///
/// The immutable launch state is shared behind an [`Arc`] so that tuning
/// arguments (see `ReduceOptimizationTuneArg`) can clone it cheaply.
pub struct ReduceOptimization<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
}
45
/// Everything needed to execute the fused reduce, either through the fully
/// fused kernel (`trace`) or through the fallback path (read trace, plain
/// reduce, write trace — see `ReduceOptimizationTuneArg::execute_fallback`).
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
    /// Trace of the fully fused reduce kernel.
    pub(crate) trace: FuseTrace,
    /// Element-wise trace launched before the fallback reduce.
    trace_read_fallback: FuseTrace,
    /// Element-wise trace launched after the fallback reduce.
    trace_write_fallback: FuseTrace,
    pub(crate) client: ComputeClient<R>,
    pub(crate) device: R::Device,
    /// Total number of fused operations.
    pub(crate) len: usize,
    /// Number of fused operations in the read (pre-reduce) part.
    pub(crate) len_read: usize,
    /// Description of the reduce operation itself.
    pub(crate) reduce: FusedReduce,
    // NOTE(review): stored and round-tripped through state, but not read in
    // this file — confirm where `settings` is consumed.
    settings: ReduceSettings,
}
57
58impl<R: Runtime> ReduceOptimizationInfo<R> {
59 pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
60 let client = R::client(device);
61
62 Self {
63 trace: state.trace,
64 trace_read_fallback: state.trace_read_fallback,
65 trace_write_fallback: state.trace_write_fallback,
66 client,
67 device: device.clone(),
68 len: state.len,
69 len_read: state.len_read,
70 reduce: state.reduce,
71 settings: state.settings,
72 }
73 }
74 pub fn to_state(&self) -> ReduceOptimizationState {
75 ReduceOptimizationState {
76 trace: self.trace.clone(),
77 trace_read_fallback: self.trace_read_fallback.clone(),
78 trace_write_fallback: self.trace_write_fallback.clone(),
79 len: self.len,
80 len_read: self.len_read,
81 reduce: self.reduce.clone(),
82 settings: self.settings,
83 }
84 }
85}
86
/// Policy controlling when the fused reduce may be selected.
///
/// NOTE(review): variant semantics inferred from names only — this file never
/// matches on the value; confirm against the call sites that read `settings`.
#[derive(Serialize, Deserialize, Copy, Clone)]
pub enum ReduceSettings {
    Always,
    OnlyParallel,
    Never,
}
97
/// Argument handed to the autotuner: the shared optimization info plus the
/// fallback operation to run when the fused kernel loses or fails.
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
    /// Fallback replaying the pre-reduce operations (see `execute_fallback`).
    pub(crate) fallback: Arc<Box<dyn FallbackOperation<R>>>,
}
102
103impl<R: Runtime> Clone for ReduceOptimizationTuneArg<R> {
104 fn clone(&self) -> Self {
105 Self {
106 info: self.info.clone(),
107 fallback: self.fallback.clone(),
108 }
109 }
110}
111
/// Which reduction the fused kernel performs; mapped one-to-one onto
/// `cubek`'s `ReduceOperationConfig` in `launch_reduce_mixed_precision`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ReduceInstruction {
    ArgMax,
    ArgMin,
    Mean,
    Prod,
    Sum,
    Max,
    Min,
    MaxAbs,
}
123
/// A fallback reduce executed inside a fusion [`Context`].
///
/// NOTE(review): not referenced elsewhere in this file (the execution path
/// uses `FallbackOperation` instead) — confirm external implementors.
pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
    fn run(&self, context: &mut Context<CubeFusionHandle<R>>);
}
127
/// Serializable snapshot of a `ReduceOptimizationInfo`, minus the
/// device-bound `client`/`device` handles, which are re-created on load.
#[derive(Serialize, Deserialize)]
pub struct ReduceOptimizationState {
    pub(crate) trace: FuseTrace,
    pub(crate) trace_read_fallback: FuseTrace,
    pub(crate) trace_write_fallback: FuseTrace,
    pub(crate) reduce: FusedReduce,
    /// Total number of fused operations.
    pub(crate) len: usize,
    /// Number of fused operations in the read (pre-reduce) part.
    pub(crate) len_read: usize,
    pub(crate) settings: ReduceSettings,
}
138
139impl core::fmt::Debug for ReduceOptimizationState {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.write_fmt(format_args!(
142 "{{ len_read: {}, len_total: {} }}",
143 self.len_read, self.len
144 ))
145 }
146}
147
/// Description of the reduce operation embedded in the fused trace.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusedReduce {
    /// Fuse-graph argument read as the reduce input.
    pub(crate) input: FuseArg,
    /// Fuse-graph argument written as the reduce output.
    pub(crate) output: FuseArg,
    /// Accumulation precision (may differ from the input/output dtypes).
    pub(crate) acc: FuseType,
    /// Axis being reduced.
    pub(crate) axis: usize,
    /// Original reduce-dim IR op; carries the input/output tensor dtypes.
    pub(crate) op: ReduceDimOpIr,
    // NOTE(review): `use_planes`/`shared` are not read in this file —
    // presumably consumed during trace construction; confirm.
    pub(crate) use_planes: bool,
    pub(crate) shared: bool,
    /// Which reduction (sum, mean, argmax, ...) is performed.
    pub(crate) inst: ReduceInstruction,
}
159
/// Launch adapter pairing a [`FusedReduce`] description with the routine
/// strategy chosen (e.g. by autotune). Implements `TraceRunner` below.
#[derive(new)]
pub struct FusedReduceLaunch<'a> {
    reduce: &'a FusedReduce,
    strategy: RoutineStrategy,
}
165
/// Errors surfaced while preparing or launching the fused reduce.
#[derive(Debug)]
pub enum FusedReduceError {
    /// Error bubbled up from the underlying `cubek` reduce.
    Reduce(ReduceError),
    // NOTE(review): `Box<&'static str>` boxes a reference, which saves no
    // space over a plain `&'static str`; kept as-is because callers may
    // construct/match this exact shape.
    InvalidSelection(Box<&'static str>),
    InvalidInput,
}
172
173impl From<ReduceError> for FusedReduceError {
174 fn from(value: ReduceError) -> Self {
175 Self::Reduce(value)
176 }
177}
178
impl<R: Runtime> ReduceOptimizationTuneArg<R> {
    /// Launch the fully fused reduce trace with the given routine strategy.
    pub fn execute_fused(
        &self,
        context: &mut Context<CubeFusionHandle<R>>,
        strategy: RoutineStrategy,
    ) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
        let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
        launcher.launch(&self.info.client, &self.info.device, context)
    }

    /// Run the non-fused path: element-wise read trace, the fallback reduce,
    /// then the element-wise write trace; merges both trace outputs.
    pub fn execute_fallback(&self, context: &mut Context<CubeFusionHandle<R>>) -> TuneOutput<R> {
        let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);

        // `mut` is only needed when `autotune-checks` is enabled below.
        #[allow(unused_mut)] let mut output_read = launcher
            .launch(&self.info.client, &self.info.device, context)
            .unwrap();

        // The actual (non-fused) reduce operation.
        self.fallback.run(context);

        // With autotune checks enabled, also record the fallback's reduce
        // output handle so the tuner can compare it against the fused kernel.
        #[cfg(feature = "autotune-checks")]
        if let TuneOutput::Checked { handles } = &mut output_read {
            let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
            let handle_out = context
                .handles
                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);

            handles.insert(
                self.info.reduce.op.out.id,
                (out_desc.shape.clone(), handle_out.clone()),
            );
        }

        let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);

        let output_write = launcher
            .launch(&self.info.client, &self.info.device, context)
            .unwrap();

        output_read.merge(output_write)
    }
}
222
223#[allow(clippy::too_many_arguments)]
224impl<R: Runtime> ReduceOptimization<R> {
225 pub fn new(
226 trace: FuseTrace,
227 trace_read_fallback: FuseTrace,
228 trace_write_fallback: FuseTrace,
229 client: ComputeClient<R>,
230 device: R::Device,
231 len: usize,
232 len_read: usize,
233 reduce: FusedReduce,
234 settings: ReduceSettings,
235 ) -> Self {
236 let info = ReduceOptimizationInfo {
237 trace,
238 trace_read_fallback,
239 trace_write_fallback,
240 client,
241 device,
242 len,
243 len_read,
244 reduce,
245 settings,
246 };
247
248 Self {
249 info: Arc::new(info),
250 }
251 }
252 pub fn execute(
254 &mut self,
255 context: &mut Context<CubeFusionHandle<R>>,
256 fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
257 ) {
258 let fallback = fallback(self.info.len_read);
260 let arg = ReduceOptimizationTuneArg {
261 info: self.info.clone(),
262 fallback: Arc::new(fallback),
263 };
264
265 #[cfg(feature = "autotune")]
266 fused_reduce_autotune::<R>(arg, context);
267
268 #[cfg(not(feature = "autotune"))]
269 if arg
270 .execute_fused(
271 context,
272 RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
273 )
274 .is_err()
275 {
276 arg.execute_fallback(context);
277 }
278 }
279
280 pub fn num_output_buffers(&self) -> usize {
281 self.info.trace_read_fallback.resources.outputs.len()
282 }
283
284 pub fn to_state(&self) -> ReduceOptimizationState {
285 ReduceOptimizationState {
286 trace: self.info.trace.clone(),
287 trace_read_fallback: self.info.trace_read_fallback.clone(),
288 trace_write_fallback: self.info.trace_write_fallback.clone(),
289 reduce: self.info.reduce.clone(),
290 len: self.info.len,
291 len_read: self.info.len_read,
292 settings: self.info.settings,
293 }
294 }
295
296 pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
297 let client = R::client(device);
298
299 let info = ReduceOptimizationInfo {
300 trace: state.trace,
301 trace_read_fallback: state.trace_read_fallback,
302 trace_write_fallback: state.trace_write_fallback,
303 reduce: state.reduce,
304 len: state.len,
305 len_read: state.len_read,
306 client,
307 device: device.clone(),
308 settings: state.settings,
309 };
310
311 Self {
312 info: Arc::new(info),
313 }
314 }
315
316 pub fn num_ops_fused(&self) -> usize {
318 self.info.len
319 }
320}
321
// Empty impl: relies entirely on `Vectorization`'s default methods; the
// axis-aware vector settings are computed explicitly in `TraceRunner::run`.
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}
324
325impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
326 type Error = FusedReduceError;
327
328 fn run<'a>(
329 &'a self,
330 client: &'a ComputeClient<R>,
331 inputs: GlobalArgsLaunch<R>,
332 outputs: GlobalArgsLaunch<R>,
333 configs: &'a [FuseBlockConfig],
334 ) -> Result<(), FusedReduceError> {
335 let [config_read, config_write] = [&configs[0], &configs[1]];
336 let shape = match &config_read.ref_layout {
337 RefLayout::Concrete(FuseArg::Output(..)) => {
338 outputs.shape_ref(&config_read.ref_layout, config_read.rank)
339 }
340 _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
341 };
342 let reduce_count: usize = shape
343 .iter()
344 .enumerate()
345 .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
346 .product();
347
348 let vectorization_mode = match self.reduce.axis == config_read.rank - 1 {
349 true => VectorizationMode::Parallel,
350 false => VectorizationMode::Perpendicular,
351 };
352 let address_type = inputs
353 .required_address_type()
354 .max(outputs.required_address_type());
355
356 let settings = ReduceVectorSettings {
357 vectorization_mode,
358 vector_size_input: config_read.width,
359 vector_size_output: config_write.width,
360 };
361 let problem = ReduceProblem {
362 reduce_len: shape[self.reduce.axis],
363 reduce_count,
364 axis: self.reduce.axis,
365 dtypes: ReduceDtypes {
366 input: self.reduce.op.input.dtype.into(),
367 output: self.reduce.op.out.dtype.into(),
368 accumulation: self.reduce.acc.into_elem().into(),
369 },
370 address_type,
371 };
372
373 let (blueprint, settings) = match self.strategy.clone() {
374 RoutineStrategy::Unit(strategy) => {
375 let routine = UnitRoutine;
376 routine.prepare(client, problem, settings, strategy)?
377 }
378 RoutineStrategy::Plane(strategy) => {
379 let routine = PlaneRoutine;
380 routine.prepare(client, problem, settings, strategy)?
381 }
382 RoutineStrategy::Cube(strategy) => {
383 let routine = CubeRoutine;
384 routine.prepare(client, problem, settings, strategy)?
385 }
386 };
387
388 let out_vec_axis = output_vectorization_axis(
389 &inputs.strides_ref(&config_read.ref_layout, config_read.rank),
390 self.reduce.axis,
391 vectorization_mode,
392 );
393
394 let kwargs = ReduceKwArgs {
395 client,
396 inputs,
397 outputs,
398 reduce_axis: self.reduce.axis,
399 out_vec_axis,
400 config_fuse_read: config_read.clone(),
401 config_fuse_write: config_write.clone(),
402 input: self.reduce.input.clone(),
403 output: self.reduce.output.clone(),
404 blueprint,
405 settings,
406 };
407 let result = launch_reduce_mixed_precision(
408 kwargs,
409 self.reduce.inst,
410 self.reduce.op.input.dtype,
411 self.reduce.op.out.dtype,
412 DType::from(self.reduce.acc.into_elem()),
413 );
414
415 match result {
416 Ok(_) => Ok(()),
417 Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),
418 }
419 }
420}
421
/// Bundled arguments for `launch_reduce`, avoiding a very long parameter
/// list when threading state from `TraceRunner::run` to the kernel launch.
struct ReduceKwArgs<'b, Run: Runtime> {
    client: &'b ComputeClient<Run>,
    inputs: GlobalArgsLaunch<Run>,
    outputs: GlobalArgsLaunch<Run>,
    /// Axis being reduced.
    reduce_axis: usize,
    /// Axis along which the output is vectorized.
    out_vec_axis: usize,
    /// Comptime blueprint produced by the routine's `prepare`.
    blueprint: ReduceBlueprint,
    /// Cube count/dim and address type for the launch.
    settings: ReduceLaunchSettings,
    config_fuse_read: FuseBlockConfig,
    config_fuse_write: FuseBlockConfig,
    /// Fuse-graph argument read as the reduce input.
    input: FuseArg,
    /// Fuse-graph argument written as the reduce output.
    output: FuseArg,
}
435
436fn launch_reduce_mixed_precision<Run: Runtime>(
437 kwargs: ReduceKwArgs<'_, Run>,
438 instruction: ReduceInstruction,
439 dtype_input: DType,
440 dtype_output: DType,
441 dtype_acc: DType,
442) -> Result<(), LaunchError> {
443 let config = match instruction {
444 ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
445 ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
446 ReduceInstruction::Prod => ReduceOperationConfig::Prod,
447 ReduceInstruction::Mean => ReduceOperationConfig::Mean,
448 ReduceInstruction::Sum => ReduceOperationConfig::Sum,
449 ReduceInstruction::Max => ReduceOperationConfig::Max,
450 ReduceInstruction::Min => ReduceOperationConfig::Min,
451 ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
452 };
453 launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
454}
455
/// Launch the fused reduce kernel with runtime-selected storage types.
///
/// The `dtype_*` arguments feed the kernel's `#[define]` parameters, picking
/// the generic instantiation (`In`/`Out`/`Acc`) at launch time.
///
/// NOTE(review): always returns `Ok(())`; the `Result` mirrors the signature
/// of `launch_reduce_mixed_precision` — confirm whether the launch can fail.
fn launch_reduce<Run: Runtime>(
    kwargs: ReduceKwArgs<'_, Run>,
    inst: ReduceOperationConfig,
    dtype_input: DType,
    dtype_output: DType,
    dtype_acc: DType,
) -> Result<(), LaunchError> {
    // SAFETY: `launch_unchecked` skips cubecl's launch-time validation; the
    // cube count/dim, address type and vectorization widths all come from the
    // routine's `prepare` step and the fuse block configs, which are presumed
    // consistent with the tensor arguments — TODO(review): confirm.
    unsafe {
        reduce_kernel_fused::launch_unchecked::<Run>(
            kwargs.client,
            kwargs.settings.cube_count,
            kwargs.settings.cube_dim,
            kwargs.settings.address_type,
            kwargs.config_fuse_read.width,
            kwargs.config_fuse_write.width,
            FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
            FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
            kwargs.reduce_axis,
            kwargs.out_vec_axis,
            kwargs.blueprint,
            inst,
            dtype_input.into(),
            dtype_output.into(),
            dtype_acc.into(),
        )
    };

    Ok(())
}
485
/// Fused reduce kernel entry point.
///
/// `In`/`Out`/`Acc` are the input, output and accumulation numeric types,
/// bound at launch time through the `#[define]` storage-type parameters;
/// `SizeIn`/`SizeOut` are the corresponding vectorization widths.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn reduce_kernel_fused<In: Numeric, SizeIn: Size, Out: Numeric, SizeOut: Size, Acc: Numeric>(
    input: &FusedReduceInput,
    output: &mut FusedReduceOutput,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    // Initialize multi-block fusion variables for both the read and write
    // configs; note both calls write into the output's shared variable store.
    multi_block_variables_init(&input.config, &mut output.global.variables);
    multi_block_variables_init(&output.config, &mut output.global.variables);

    // Wrap the fused global args as virtual tensors the generic `cubek`
    // reduce kernel can read from / write through.
    let (input, mut output) =
        init_tensors::<FusedReduceArgs, In, SizeIn, Out, SizeOut>(input, output);

    reduce_kernel_virtual::<In, SizeIn, Out, SizeOut, Acc>(
        &input,
        &mut output,
        reduce_axis,
        out_vec_axis,
        blueprint,
        config,
    );
}