burn_cubecl_fusion/optim/reduce/optimization.rs

use super::args::{
    FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
};
#[cfg(feature = "autotune")]
use super::tune::fused_reduce_autotune;
use crate::{
    CubeFusionHandle, FallbackOperation,
    engine::{
        codegen::ir::{
            FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout,
            multi_block_variables_init,
        },
        launch::{
            FuseTraceLauncher,
            runner::{TraceRunner, Vectorization},
        },
        trace::{FuseTrace, TraceError, TuneOutput},
    },
    optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
};
use burn_fusion::stream::Context;
use burn_ir::ReduceDimOpIr;
use burn_std::DType;
use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
use cubek::reduce::{
    ReduceDtypes, ReduceError, VectorizationMode,
    components::instructions::ReduceOperationConfig,
    init_tensors,
    launch::{RoutineStrategy, reduce_kernel_virtual},
    output_vectorization_axis,
    routines::{
        ReduceBlueprint, ReduceLaunchSettings, ReduceProblem, ReduceVectorSettings, Routine,
        cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
    },
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[cfg(not(feature = "autotune"))]
use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};

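/// Fused reduce optimization: the fusion [`FuseTrace`], the reduce descriptor and
/// the fallback traces, shared behind an [`Arc`] so autotune arguments can clone
/// it cheaply.
///
/// A rough sketch of how a caller might drive it (hypothetical driver code;
/// `make_fallback` is not part of this file):
///
/// ```ignore
/// let mut opt = ReduceOptimization::<R>::from_state(&device, state);
/// opt.execute(&mut context, |index| make_fallback(index));
/// ```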
pub struct ReduceOptimization<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
}

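/// Shared state of the optimization: the fused trace, the read/write fallback
/// traces used when the fused kernel cannot run, the compute client and device,
/// and the reduce descriptor itself.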
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
    pub(crate) trace: FuseTrace,
    trace_read_fallback: FuseTrace,
    trace_write_fallback: FuseTrace,
    pub(crate) client: ComputeClient<R>,
    pub(crate) device: R::Device,
    pub(crate) len: usize,
    pub(crate) len_read: usize,
    pub(crate) reduce: FusedReduce,
    settings: ReduceSettings,
}

impl<R: Runtime> ReduceOptimizationInfo<R> {
    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
        let client = R::client(device);

        Self {
            trace: state.trace,
            trace_read_fallback: state.trace_read_fallback,
            trace_write_fallback: state.trace_write_fallback,
            client,
            device: device.clone(),
            len: state.len,
            len_read: state.len_read,
            reduce: state.reduce,
            settings: state.settings,
        }
    }
    pub fn to_state(&self) -> ReduceOptimizationState {
        ReduceOptimizationState {
            trace: self.trace.clone(),
            trace_read_fallback: self.trace_read_fallback.clone(),
            trace_write_fallback: self.trace_write_fallback.clone(),
            len: self.len,
            len_read: self.len_read,
            reduce: self.reduce.clone(),
            settings: self.settings,
        }
    }
}

#[derive(Serialize, Deserialize, Copy, Clone)]
pub enum ReduceSettings {
    Always,
    /// Fuse-on-write is only activated when the reduction isn't on the last dimension;
    /// otherwise vectorization is impossible, since only
    /// [VectorizationMode::Perpendicular] supports vectorization.
    ///
    /// We could still fuse some output operations, but it would probably lead to worse performance.
    OnlyParallel,
    Never,
}

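/// Argument handed to the autotuner: the shared optimization info plus the
/// fallback reduce operation selected for this execution.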
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
    pub(crate) fallback: Arc<Box<dyn FallbackOperation<R>>>,
}

impl<R: Runtime> Clone for ReduceOptimizationTuneArg<R> {
    fn clone(&self) -> Self {
        Self {
            info: self.info.clone(),
            fallback: self.fallback.clone(),
        }
    }
}

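/// Reduce instructions supported by the fused kernel; each maps to a
/// [`ReduceOperationConfig`] in `launch_reduce_mixed_precision` below.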
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ReduceInstruction {
    ArgMax,
    ArgMin,
    Mean,
    Prod,
    Sum,
    Max,
    Min,
    MaxAbs,
}

pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
    fn run(&self, context: &mut Context<CubeFusionHandle<R>>);
}

#[derive(Serialize, Deserialize)]
pub struct ReduceOptimizationState {
    pub(crate) trace: FuseTrace,
    pub(crate) trace_read_fallback: FuseTrace,
    pub(crate) trace_write_fallback: FuseTrace,
    pub(crate) reduce: FusedReduce,
    pub(crate) len: usize,
    pub(crate) len_read: usize,
    pub(crate) settings: ReduceSettings,
}

impl core::fmt::Debug for ReduceOptimizationState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "{{ len_read: {}, len_total: {} }}",
            self.len_read, self.len
        ))
    }
}

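/// Description of the reduce operation inside the fused trace: the fuse-level
/// input/output arguments, the accumulation type, the reduction axis, the
/// original [`ReduceDimOpIr`] and the [`ReduceInstruction`] to apply.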
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusedReduce {
    pub(crate) input: FuseArg,
    pub(crate) output: FuseArg,
    pub(crate) acc: FuseType,
    pub(crate) axis: usize,
    pub(crate) op: ReduceDimOpIr,
    pub(crate) use_planes: bool,
    pub(crate) shared: bool,
    pub(crate) inst: ReduceInstruction,
}

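/// Pairs a [`FusedReduce`] with a concrete [`RoutineStrategy`]; implements
/// [`TraceRunner`] so the fuse trace launcher can execute it.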
#[derive(new)]
pub struct FusedReduceLaunch<'a> {
    reduce: &'a FusedReduce,
    strategy: RoutineStrategy,
}

#[derive(Debug)]
pub enum FusedReduceError {
    Reduce(ReduceError),
    InvalidSelection(Box<&'static str>),
    InvalidInput,
}

impl From<ReduceError> for FusedReduceError {
    fn from(value: ReduceError) -> Self {
        Self::Reduce(value)
    }
}

impl<R: Runtime> ReduceOptimizationTuneArg<R> {
    pub fn execute_fused(
        &self,
        context: &mut Context<CubeFusionHandle<R>>,
        strategy: RoutineStrategy,
    ) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
        let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
        launcher.launch(&self.info.client, &self.info.device, context)
    }

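    /// Executes the non-fused path: launch the read-side fallback trace, run the
    /// standalone reduce fallback, then launch the write-side fallback trace,
    /// merging the tune outputs of both elementwise launches.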
    pub fn execute_fallback(&self, context: &mut Context<CubeFusionHandle<R>>) -> TuneOutput<R> {
        let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);

        #[allow(unused_mut)] // It is used when `autotune-checks` is activated.
        let mut output_read = launcher
            .launch(&self.info.client, &self.info.device, context)
            .unwrap();

        self.fallback.run(context);

        #[cfg(feature = "autotune-checks")]
        if let TuneOutput::Checked { handles } = &mut output_read {
            let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
            let handle_out = context
                .handles
                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);

            handles.insert(
                self.info.reduce.op.out.id,
                (out_desc.shape.clone(), handle_out.clone()),
            );
        }

        let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);

        let output_write = launcher
            .launch(&self.info.client, &self.info.device, context)
            .unwrap();

        output_read.merge(output_write)
    }
}

#[allow(clippy::too_many_arguments)]
impl<R: Runtime> ReduceOptimization<R> {
    pub fn new(
        trace: FuseTrace,
        trace_read_fallback: FuseTrace,
        trace_write_fallback: FuseTrace,
        client: ComputeClient<R>,
        device: R::Device,
        len: usize,
        len_read: usize,
        reduce: FusedReduce,
        settings: ReduceSettings,
    ) -> Self {
        let info = ReduceOptimizationInfo {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            client,
            device,
            len,
            len_read,
            reduce,
            settings,
        };

        Self {
            info: Arc::new(info),
        }
    }
    /// Execute the optimization.
    pub fn execute(
        &mut self,
        context: &mut Context<CubeFusionHandle<R>>,
        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
    ) {
        // The index of the fallback reduce is the number of ops fused as read.
        let fallback = fallback(self.info.len_read);
        let arg = ReduceOptimizationTuneArg {
            info: self.info.clone(),
            fallback: Arc::new(fallback),
        };

        #[cfg(feature = "autotune")]
        fused_reduce_autotune::<R>(arg, context);

        #[cfg(not(feature = "autotune"))]
        if arg
            .execute_fused(
                context,
                RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
            )
            .is_err()
        {
            arg.execute_fallback(context);
        }
    }

    /// Returns the number of output buffers added by fusion.
    pub fn num_output_buffers(&self) -> usize {
        self.info.trace_read_fallback.resources.outputs.len()
    }

    pub fn to_state(&self) -> ReduceOptimizationState {
        ReduceOptimizationState {
            trace: self.info.trace.clone(),
            trace_read_fallback: self.info.trace_read_fallback.clone(),
            trace_write_fallback: self.info.trace_write_fallback.clone(),
            reduce: self.info.reduce.clone(),
            len: self.info.len,
            len_read: self.info.len_read,
            settings: self.info.settings,
        }
    }

    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
        let client = R::client(device);

        let info = ReduceOptimizationInfo {
            trace: state.trace,
            trace_read_fallback: state.trace_read_fallback,
            trace_write_fallback: state.trace_write_fallback,
            reduce: state.reduce,
            len: state.len,
            len_read: state.len_read,
            client,
            device: device.clone(),
            settings: state.settings,
        };

        Self {
            info: Arc::new(info),
        }
    }

    /// Returns the number of operations fused.
    pub fn num_ops_fused(&self) -> usize {
        self.info.len
    }
}

// TODO: Implement better vectorization here.
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}

impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
    type Error = FusedReduceError;

    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<R>,
        outputs: GlobalArgsLaunch<R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), FusedReduceError> {
        let [config_read, config_write] = [&configs[0], &configs[1]];
        let shape = match &config_read.ref_layout {
            RefLayout::Concrete(FuseArg::Output(..)) => {
                outputs.shape_ref(&config_read.ref_layout, config_read.rank)
            }
            _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
        };
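        // Number of independent reductions: the product of every dimension
        // except the reduced axis.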
        let reduce_count: usize = shape
            .iter()
            .enumerate()
            .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
            .product();

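        // Reductions over the innermost axis use [VectorizationMode::Parallel];
        // any other axis uses [VectorizationMode::Perpendicular].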
        let vectorization_mode = match self.reduce.axis == config_read.rank - 1 {
            true => VectorizationMode::Parallel,
            false => VectorizationMode::Perpendicular,
        };
        let address_type = inputs
            .required_address_type()
            .max(outputs.required_address_type());

        let settings = ReduceVectorSettings {
            vectorization_mode,
            vector_size_input: config_read.width,
            vector_size_output: config_write.width,
        };
        let problem = ReduceProblem {
            reduce_len: shape[self.reduce.axis],
            reduce_count,
            axis: self.reduce.axis,
            dtypes: ReduceDtypes {
                input: self.reduce.op.input.dtype.into(),
                output: self.reduce.op.out.dtype.into(),
                accumulation: self.reduce.acc.into_elem().into(),
            },
            address_type,
        };

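        // Each routine level (unit, plane, cube) turns the problem and vector
        // settings into a launch blueprint; preparation errors are converted
        // into `FusedReduceError::Reduce`.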
        let (blueprint, settings) = match self.strategy.clone() {
            RoutineStrategy::Unit(strategy) => {
                let routine = UnitRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Plane(strategy) => {
                let routine = PlaneRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Cube(strategy) => {
                let routine = CubeRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
        };

        let out_vec_axis = output_vectorization_axis(
            &inputs.strides_ref(&config_read.ref_layout, config_read.rank),
            self.reduce.axis,
            vectorization_mode,
        );

        let kwargs = ReduceKwArgs {
            client,
            inputs,
            outputs,
            reduce_axis: self.reduce.axis,
            out_vec_axis,
            config_fuse_read: config_read.clone(),
            config_fuse_write: config_write.clone(),
            input: self.reduce.input.clone(),
            output: self.reduce.output.clone(),
            blueprint,
            settings,
        };
        let result = launch_reduce_mixed_precision(
            kwargs,
            self.reduce.inst,
            self.reduce.op.input.dtype,
            self.reduce.op.out.dtype,
            DType::from(self.reduce.acc.into_elem()),
        );

        match result {
            Ok(_) => Ok(()),
            Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),
        }
    }
}

struct ReduceKwArgs<'b, Run: Runtime> {
    client: &'b ComputeClient<Run>,
    inputs: GlobalArgsLaunch<Run>,
    outputs: GlobalArgsLaunch<Run>,
    reduce_axis: usize,
    out_vec_axis: usize,
    blueprint: ReduceBlueprint,
    settings: ReduceLaunchSettings,
    config_fuse_read: FuseBlockConfig,
    config_fuse_write: FuseBlockConfig,
    input: FuseArg,
    output: FuseArg,
}

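/// Maps the fusion-level [`ReduceInstruction`] to cubek's
/// [`ReduceOperationConfig`], then launches the fused kernel with the runtime
/// input, output and accumulation dtypes.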
fn launch_reduce_mixed_precision<Run: Runtime>(
    kwargs: ReduceKwArgs<'_, Run>,
    instruction: ReduceInstruction,
    dtype_input: DType,
    dtype_output: DType,
    dtype_acc: DType,
) -> Result<(), LaunchError> {
    let config = match instruction {
        ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
        ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
        ReduceInstruction::Prod => ReduceOperationConfig::Prod,
        ReduceInstruction::Mean => ReduceOperationConfig::Mean,
        ReduceInstruction::Sum => ReduceOperationConfig::Sum,
        ReduceInstruction::Max => ReduceOperationConfig::Max,
        ReduceInstruction::Min => ReduceOperationConfig::Min,
        ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
    };
    launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
}

fn launch_reduce<Run: Runtime>(
    kwargs: ReduceKwArgs<'_, Run>,
    inst: ReduceOperationConfig,
    dtype_input: DType,
    dtype_output: DType,
    dtype_acc: DType,
) -> Result<(), LaunchError> {
    unsafe {
        reduce_kernel_fused::launch_unchecked::<Run>(
            kwargs.client,
            kwargs.settings.cube_count,
            kwargs.settings.cube_dim,
            kwargs.settings.address_type,
            kwargs.config_fuse_read.width,
            kwargs.config_fuse_write.width,
            FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
            FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
            kwargs.reduce_axis,
            kwargs.out_vec_axis,
            kwargs.blueprint,
            inst,
            dtype_input.into(),
            dtype_output.into(),
            dtype_acc.into(),
        )
    };

    Ok(())
}

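/// Fused kernel entry point: initializes the multi-block fuse variables for the
/// read and write configs, binds the fused input/output through
/// [`FusedReduceArgs`], then delegates to cubek's `reduce_kernel_virtual`.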
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn reduce_kernel_fused<In: Numeric, SizeIn: Size, Out: Numeric, SizeOut: Size, Acc: Numeric>(
    input: &FusedReduceInput,
    output: &mut FusedReduceOutput,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    multi_block_variables_init(&input.config, &mut output.global.variables);
    multi_block_variables_init(&output.config, &mut output.global.variables);

    let (input, mut output) =
        init_tensors::<FusedReduceArgs, In, SizeIn, Out, SizeOut>(input, output);

    reduce_kernel_virtual::<In, SizeIn, Out, SizeOut, Acc>(
        &input,
        &mut output,
        reduce_axis,
        out_vec_axis,
        blueprint,
        config,
    );
}