cubek_reduce/launch/
base.rs

1use crate::{
2    LineMode, ReduceError, ReducePrecision,
3    components::{
4        args::{ReduceArgs, TensorArgs, init_tensors},
5        global::{
6            cube::GlobalFullCubeReduce, plane::GlobalFullPlaneReduce, unit::GlobalFullUnitReduce,
7        },
8        instructions::*,
9    },
10    launch::{ReduceStrategy, RoutineStrategy, generate_line_size},
11    routines::{
12        GlobalReduceBlueprint, ReduceBlueprint, ReduceLineSettings, ReduceProblem, Routine,
13        cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
14    },
15};
16use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};
17
18#[derive(Clone, Copy, Debug)]
19pub struct ReduceDtypes {
20    pub input: StorageType,
21    pub output: StorageType,
22    pub accumulation: StorageType,
23}
24
25/// Launch a reduce kernel. This function assumes that all parameters are already validated.
26/// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function
27/// with the appropriate assumptions.
28#[allow(clippy::too_many_arguments)]
29pub(crate) fn launch_reduce<Run: Runtime>(
30    client: &ComputeClient<Run>,
31    input: TensorHandleRef<Run>,
32    output: TensorHandleRef<Run>,
33    axis: usize,
34    strategy: ReduceStrategy,
35    dtypes: ReduceDtypes,
36    inst: ReduceOperationConfig,
37) -> Result<(), ReduceError> {
38    let problem = ReduceProblem {
39        vector_size: input.shape[axis],
40        vector_count: output.shape.iter().copied().product(),
41        axis,
42        dtypes,
43    };
44    let line_mode = match input.strides[axis] {
45        1 => LineMode::Parallel,
46        _ => LineMode::Perpendicular,
47    };
48    let (line_size_input, line_size_output) = generate_line_size::<Run>(
49        client,
50        &input,
51        &output,
52        axis,
53        problem.dtypes.input,
54        line_mode,
55        &strategy.line_size,
56    );
57    let settings = ReduceLineSettings {
58        line_mode,
59        line_size_input,
60        line_size_output,
61    };
62
63    let (blueprint, settings) = match strategy.routine {
64        RoutineStrategy::Unit(strategy) => {
65            let routine = UnitRoutine;
66            routine.prepare(client, problem, settings, strategy)?
67        }
68        RoutineStrategy::Plane(strategy) => {
69            let routine = PlaneRoutine;
70            routine.prepare(client, problem, settings, strategy)?
71        }
72        RoutineStrategy::Cube(strategy) => {
73            let routine = CubeRoutine;
74            routine.prepare(client, problem, settings, strategy)?
75        }
76    };
77
78    unsafe {
79        reduce_kernel::launch_unchecked::<TensorArgs, Run>(
80            client,
81            settings.cube_count,
82            settings.cube_dim,
83            input.as_tensor_arg(settings.line.line_size_input),
84            output.as_tensor_arg(settings.line.line_size_output),
85            ScalarArg::new(axis),
86            blueprint,
87            inst,
88            dtypes.input,
89            dtypes.output,
90            dtypes.accumulation,
91        )
92        .map_err(ReduceError::Launch)
93    }
94}
95
96#[cube(launch_unchecked)]
97pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, RA: ReduceArgs>(
98    input: &RA::Input<In>,
99    output: &mut RA::Output<Out>,
100    axis_reduce: usize,
101    #[comptime] blueprint: ReduceBlueprint,
102    #[comptime] config: ReduceOperationConfig,
103    #[define(In)] _input_dtype: StorageType,
104    #[define(Out)] _output_dtype: StorageType,
105    #[define(Acc)] _acc_dtype: StorageType,
106) {
107    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
108    reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
109}
110
111#[cube]
112pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric>(
113    input: &VirtualTensor<In>,
114    output: &mut VirtualTensor<Out, ReadWrite>,
115    axis_reduce: usize,
116    #[comptime] blueprint: ReduceBlueprint,
117    #[comptime] config: ReduceOperationConfig,
118) {
119    reduce_kernel_inner::<(In, Acc), Out, ReduceOperation>(
120        input,
121        output,
122        axis_reduce,
123        blueprint,
124        config,
125    )
126}
127
128#[cube]
129fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
130    input: &VirtualTensor<P::EI>,
131    output: &mut VirtualTensor<Out, ReadWrite>,
132    axis_reduce: usize,
133    #[comptime] blueprint: ReduceBlueprint,
134    #[comptime] config: R::Config,
135) {
136    let inst = &R::Instruction::<P>::from_config(config);
137
138    match blueprint.global {
139        GlobalReduceBlueprint::Cube(cube) => {
140            GlobalFullCubeReduce::execute::<P, Out, R::Instruction<P>>(
141                input,
142                output,
143                axis_reduce,
144                inst,
145                blueprint.line_mode,
146                cube,
147            )
148        }
149        GlobalReduceBlueprint::Plane(plane) => {
150            GlobalFullPlaneReduce::execute::<P, Out, R::Instruction<P>>(
151                input,
152                output,
153                axis_reduce,
154                inst,
155                blueprint.line_mode,
156                plane,
157            )
158        }
159        GlobalReduceBlueprint::Unit(unit) => {
160            GlobalFullUnitReduce::execute::<P, Out, R::Instruction<P>>(
161                input,
162                output,
163                axis_reduce,
164                inst,
165                blueprint.line_mode,
166                unit,
167            )
168        }
169    };
170}