burn_cubecl_fusion/optim/reduce/
optimization.rs

1use super::args::{
2    FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,
3};
4#[cfg(feature = "autotune")]
5use super::tune::fused_reduce_autotune;
6use crate::{
7    CubeFusionHandle, FallbackOperation,
8    engine::{
9        codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout},
10        launch::{
11            FuseTraceLauncher,
12            runner::{TraceRunner, Vectorization},
13        },
14        trace::{FuseTrace, TraceError, TuneOutput},
15    },
16    optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},
17};
18use burn_fusion::stream::Context;
19use burn_ir::ReduceDimOpIr;
20use burn_std::DType;
21use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};
22use cubek::reduce::{
23    LineMode, ReduceDtypes, ReduceError,
24    components::instructions::ReduceOperationConfig,
25    init_tensors,
26    launch::{RoutineStrategy, reduce_kernel_virtual},
27    routines::{
28        ReduceBlueprint, ReduceLaunchSettings, ReduceLineSettings, ReduceProblem, Routine,
29        cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
30    },
31};
32use serde::{Deserialize, Serialize};
33use std::sync::Arc;
34
35#[cfg(not(feature = "autotune"))]
36use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};
37
/// Fused reduce optimization.
///
/// Wraps the shared optimization state in an [`Arc`] so it can be handed to the
/// autotune argument (`ReduceOptimizationTuneArg`) without cloning the traces.
pub struct ReduceOptimization<R: Runtime> {
    info: Arc<ReduceOptimizationInfo<R>>,
}
41
/// Shared state behind a [`ReduceOptimization`].
pub(crate) struct ReduceOptimizationInfo<R: Runtime> {
    /// Trace of the fully fused kernel (element-wise reads + reduce + writes).
    pub(crate) trace: FuseTrace,
    // Fallback trace running only the element-wise operations fused before the reduce.
    trace_read_fallback: FuseTrace,
    // Fallback trace running only the element-wise operations fused after the reduce.
    trace_write_fallback: FuseTrace,
    pub(crate) client: ComputeClient<R>,
    pub(crate) device: R::Device,
    /// Total number of fused operations (see `num_ops_fused`).
    pub(crate) len: usize,
    /// Number of operations fused as read (before the reduce).
    pub(crate) len_read: usize,
    /// Descriptor of the reduce operation being fused.
    pub(crate) reduce: FusedReduce,
}
52
/// Argument handed to autotune: the shared optimization info plus the fallback
/// operation to run when the fused kernel cannot be used.
pub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {
    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,
    pub(crate) fallback: Box<dyn FallbackOperation<R>>,
}
57
/// Reduce instruction supported by the fused kernel.
///
/// Each variant maps one-to-one onto a `cubek` [`ReduceOperationConfig`]
/// (see `launch_reduce_mixed_precision`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ReduceInstruction {
    ArgMax,
    ArgMin,
    Mean,
    Prod,
    Sum,
    Max,
    Min,
    MaxAbs,
}
69
/// Callback executing a non-fused reduce on the given fusion context.
pub trait ReduceFallbackFn<R: Runtime>: Send + Sync {
    /// Run the fallback operation within `context`.
    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
}
73
/// Serializable state of a [`ReduceOptimization`], without the client/device
/// handles (those are re-created in `from_state`).
#[derive(Serialize, Deserialize)]
pub struct ReduceOptimizationState {
    trace: FuseTrace,
    trace_read_fallback: FuseTrace,
    trace_write_fallback: FuseTrace,
    pub(crate) reduce: FusedReduce,
    // Total number of fused operations.
    len: usize,
    // Number of operations fused as read (before the reduce).
    len_read: usize,
}
83
84impl core::fmt::Debug for ReduceOptimizationState {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.write_fmt(format_args!(
87            "{{ len_read: {}, len_total: {} }}",
88            self.len_read, self.len
89        ))
90    }
91}
92
/// Descriptor of the reduce operation embedded in the fused trace.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusedReduce {
    /// Fuse-graph argument read as the reduce input.
    pub(crate) input: FuseArg,
    /// Fuse-graph argument written as the reduce output.
    pub(crate) output: FuseArg,
    /// Precision used for accumulation inside the reduce kernel.
    pub(crate) acc: FuseType,
    /// Axis along which the reduction is performed.
    pub(crate) axis: usize,
    /// IR of the original reduce-dim operation (carries input/output dtypes).
    pub(crate) op: ReduceDimOpIr,
    // NOTE(review): `use_planes` and `shared` are not read anywhere in this
    // file — presumably consumed by the builder/tuner; confirm before removing.
    pub(crate) use_planes: bool,
    pub(crate) shared: bool,
    /// Which reduce instruction to execute.
    pub(crate) inst: ReduceInstruction,
}
104
/// Launch adapter pairing a [`FusedReduce`] descriptor with the routine
/// strategy chosen for this launch (by autotune or the non-autotune default).
#[derive(new)]
pub struct FusedReduceLaunch<'a> {
    reduce: &'a FusedReduce,
    strategy: RoutineStrategy,
}
110
/// Errors that can occur when launching the fused reduce.
#[derive(Debug)]
pub enum FusedReduceError {
    /// Error bubbled up from the underlying `cubek` reduce routines.
    Reduce(ReduceError),
    // NOTE(review): `Box<&'static str>` double-indirects a static string;
    // plain `&'static str` (or `Box<str>`) was likely intended — changing it
    // would break existing matchers, so it is left as-is.
    InvalidSelection(Box<&'static str>),
    InvalidInput,
}
117
118impl From<ReduceError> for FusedReduceError {
119    fn from(value: ReduceError) -> Self {
120        Self::Reduce(value)
121    }
122}
123
impl<R: Runtime> ReduceOptimizationTuneArg<R> {
    /// Launch the fully fused trace (reads + reduce + writes) with the given
    /// routine strategy.
    ///
    /// Returns a trace error when the fused kernel cannot be launched; callers
    /// are expected to fall back to [`Self::execute_fallback`] in that case.
    pub fn execute_fused<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        strategy: RoutineStrategy,
    ) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {
        let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);
        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);
        launcher.launch::<BT>(&self.info.client, &self.info.device, context)
    }

    /// Execute the non-fused path in three steps: the fused element-wise
    /// reads, the fallback reduce itself, then the fused element-wise writes.
    pub fn execute_fallback<BT: CubeElement>(
        &self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
    ) -> TuneOutput<R> {
        // Step 1: element-wise operations fused before the reduce.
        let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);

        #[allow(unused_mut)] // It is used when `autotune-checks` is activated.
        let mut output_read = launcher
            .launch::<BT>(&self.info.client, &self.info.device, context)
            .unwrap();

        // Step 2: the reduce itself, via the registered fallback operation.
        self.fallback.run(context);

        // When autotune checks are enabled, also record the reduce output
        // handle so it can be compared against the fused kernel's result.
        #[cfg(feature = "autotune-checks")]
        if let TuneOutput::Checked { handles } = &mut output_read {
            let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();
            let handle_out = context
                .handles
                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);

            handles.insert(
                self.info.reduce.op.out.id,
                (out_desc.shape.dims.clone(), handle_out.clone()),
            );
        }

        // Step 3: element-wise operations fused after the reduce.
        let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);

        let output_write = launcher
            .launch::<BT>(&self.info.client, &self.info.device, context)
            .unwrap();

        output_read.merge(output_write)
    }
}
170
#[allow(clippy::too_many_arguments)]
impl<R: Runtime> ReduceOptimization<R> {
    /// Create a new reduce optimization from its traces and fused reduce
    /// descriptor, wrapping the state in a shared [`Arc`].
    pub fn new(
        trace: FuseTrace,
        trace_read_fallback: FuseTrace,
        trace_write_fallback: FuseTrace,
        client: ComputeClient<R>,
        device: R::Device,
        len: usize,
        len_read: usize,
        reduce: FusedReduce,
    ) -> Self {
        let info = ReduceOptimizationInfo {
            trace,
            trace_read_fallback,
            trace_write_fallback,
            client,
            device,
            len,
            len_read,
            reduce,
        };

        Self {
            info: Arc::new(info),
        }
    }
    /// Execute the optimization.
    ///
    /// With the `autotune` feature, the best strategy is picked by autotuning;
    /// otherwise a unit routine is attempted, falling back on failure.
    pub fn execute<BT: CubeElement>(
        &mut self,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
    ) {
        // The index of the fallback reduce is the number of ops fused as read.
        let fallback = fallback(self.info.len_read);
        let arg = ReduceOptimizationTuneArg {
            info: self.info.clone(),
            fallback,
        };

        #[cfg(feature = "autotune")]
        fused_reduce_autotune::<R, BT>(arg, context);

        #[cfg(not(feature = "autotune"))]
        if arg
            .execute_fused::<BT>(
                context,
                RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
            )
            .is_err()
        {
            arg.execute_fallback::<BT>(context);
        }
    }

    /// Returns the number of output buffers added by fusion.
    pub fn num_output_buffers(&self) -> usize {
        self.info.trace_read_fallback.resources.outputs.len()
    }

    /// Serialize into a device-independent state (drops client/device handles).
    pub fn to_state(&self) -> ReduceOptimizationState {
        ReduceOptimizationState {
            trace: self.info.trace.clone(),
            trace_read_fallback: self.info.trace_read_fallback.clone(),
            trace_write_fallback: self.info.trace_write_fallback.clone(),
            reduce: self.info.reduce.clone(),
            len: self.info.len,
            len_read: self.info.len_read,
        }
    }

    /// Rebuild the optimization on `device` from a serialized state.
    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {
        let client = R::client(device);

        let info = ReduceOptimizationInfo {
            trace: state.trace,
            trace_read_fallback: state.trace_read_fallback,
            trace_write_fallback: state.trace_write_fallback,
            reduce: state.reduce,
            len: state.len,
            len_read: state.len_read,
            client,
            device: device.clone(),
        };

        Self {
            info: Arc::new(info),
        }
    }

    /// Returns the total number of operations fused in this optimization.
    pub fn num_ops_fused(&self) -> usize {
        self.info.len
    }
}
265
// TODO: Implement better vectorization here.
// Empty impl: relies entirely on the trait's default behavior for now.
impl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}
268
impl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {
    type Error = FusedReduceError;

    /// Derive the reduce problem from the fuse block configs, prepare the
    /// selected routine, and launch the fused reduce kernel.
    fn run<'a>(
        &'a self,
        client: &'a ComputeClient<R>,
        inputs: GlobalArgsLaunch<'a, R>,
        outputs: GlobalArgsLaunch<'a, R>,
        configs: &'a [FuseBlockConfig],
    ) -> Result<(), FusedReduceError> {
        // NOTE(review): assumes at least two fuse blocks (read, write); the
        // indexing panics otherwise — confirm callers uphold this invariant.
        let [config_read, config_write] = [&configs[0], &configs[1]];
        // Reference shape comes from the outputs when the reference layout is
        // a concrete output arg, otherwise from the inputs.
        let shape = match &config_read.ref_layout {
            RefLayout::Concrete(FuseArg::Output(..)) => {
                outputs.shape_ref(&config_read.ref_layout, config_read.rank)
            }
            _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),
        };
        // Number of independent reductions: product of all non-reduced dims
        // (the reduced axis contributes 1).
        let reduce_count: usize = shape
            .iter()
            .enumerate()
            .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })
            .product();

        // Reducing the innermost axis: lines run along the reduction
        // (parallel); any other axis: perpendicular.
        let line_mode = match self.reduce.axis == config_read.rank - 1 {
            true => LineMode::Parallel,
            false => LineMode::Perpendicular,
        };

        let settings = ReduceLineSettings {
            line_mode,
            line_size_input: config_read.width,
            line_size_output: config_write.width,
        };
        let problem = ReduceProblem {
            vector_size: shape[self.reduce.axis],
            vector_count: reduce_count,
            axis: self.reduce.axis,
            dtypes: ReduceDtypes {
                input: self.reduce.op.input.dtype.into(),
                output: self.reduce.op.out.dtype.into(),
                accumulation: self.reduce.acc.into_elem().into(),
            },
        };

        // Let the chosen routine family turn the problem + line settings into
        // a blueprint and concrete launch settings (cube count/dim).
        let (blueprint, settings) = match self.strategy.clone() {
            RoutineStrategy::Unit(strategy) => {
                let routine = UnitRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Plane(strategy) => {
                let routine = PlaneRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
            RoutineStrategy::Cube(strategy) => {
                let routine = CubeRoutine;
                routine.prepare(client, problem, settings, strategy)?
            }
        };

        let kwargs = ReduceKwArgs {
            client,
            inputs,
            outputs,
            axis: self.reduce.axis,
            config_fuse_read: config_read.clone(),
            config_fuse_write: config_write.clone(),
            input: self.reduce.input.clone(),
            output: self.reduce.output.clone(),
            blueprint,
            settings,
        };
        let result = launch_reduce_mixed_precision(
            kwargs,
            self.reduce.inst,
            self.reduce.op.input.dtype,
            self.reduce.op.out.dtype,
            DType::from(self.reduce.acc.into_elem()),
        );

        match result {
            Ok(_) => Ok(()),
            Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),
        }
    }
}
354
/// Bundle of everything needed to launch the fused reduce kernel, grouped to
/// keep the launch helpers' signatures manageable.
struct ReduceKwArgs<'a, 'b, Run: Runtime> {
    client: &'b ComputeClient<Run>,
    inputs: GlobalArgsLaunch<'a, Run>,
    outputs: GlobalArgsLaunch<'a, Run>,
    /// Axis along which the reduction is performed.
    axis: usize,
    /// Routine blueprint produced by `Routine::prepare`.
    blueprint: ReduceBlueprint,
    /// Launch settings (cube count/dim) produced by `Routine::prepare`.
    settings: ReduceLaunchSettings,
    config_fuse_read: FuseBlockConfig,
    config_fuse_write: FuseBlockConfig,
    input: FuseArg,
    output: FuseArg,
}
367
368fn launch_reduce_mixed_precision<Run: Runtime>(
369    kwargs: ReduceKwArgs<'_, '_, Run>,
370    instruction: ReduceInstruction,
371    dtype_input: DType,
372    dtype_output: DType,
373    dtype_acc: DType,
374) -> Result<(), LaunchError> {
375    let config = match instruction {
376        ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,
377        ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,
378        ReduceInstruction::Prod => ReduceOperationConfig::Prod,
379        ReduceInstruction::Mean => ReduceOperationConfig::Mean,
380        ReduceInstruction::Sum => ReduceOperationConfig::Sum,
381        ReduceInstruction::Max => ReduceOperationConfig::Max,
382        ReduceInstruction::Min => ReduceOperationConfig::Min,
383        ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,
384    };
385    launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)
386}
387
/// Launch the fused reduce kernel with the prepared blueprint and settings.
fn launch_reduce<Run: Runtime>(
    kwargs: ReduceKwArgs<'_, '_, Run>,
    inst: ReduceOperationConfig,
    dtype_input: DType,
    dtype_output: DType,
    dtype_acc: DType,
) -> Result<(), LaunchError> {
    // SAFETY: `launch_unchecked` skips launch-time validation. The blueprint
    // and cube count/dim come from `Routine::prepare` for this exact problem,
    // which is presumed to guarantee a valid launch configuration — TODO
    // confirm that contract in cubek's routine documentation.
    unsafe {
        reduce_kernel::launch_unchecked::<Run>(
            kwargs.client,
            kwargs.settings.cube_count,
            kwargs.settings.cube_dim,
            FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),
            FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),
            ScalarArg::new(kwargs.axis),
            kwargs.blueprint,
            inst,
            dtype_input.into(),
            dtype_output.into(),
            dtype_acc.into(),
        )
    }
}
411
/// Fused reduce kernel entry point.
///
/// Initializes the fused input/output views through the fusion argument layer,
/// then delegates the actual reduction to cubek's virtual reduce kernel.
/// `In`/`Out`/`Acc` are bound at launch time via the `#[define]` dtypes.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric>(
    input: &FusedReduceInput,
    output: &mut FusedReduceOutput,
    axis_reduce: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    let (input, mut output) = init_tensors::<FusedReduceArgs, In, Out>(input, output);

    reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
}