Skip to main content

cubek_reduce/launch/
base.rs

1use crate::{
2    ReduceError, ReducePrecision, VectorizationMode,
3    components::{
4        args::{NumericVector, ReduceArgs, TensorArgs, init_tensors},
5        global::{
6            cube::GlobalFullCubeReduce, plane::GlobalFullPlaneReduce, unit::GlobalFullUnitReduce,
7        },
8        instructions::*,
9    },
10    launch::{ReduceStrategy, RoutineStrategy, generate_vector_size},
11    output_vectorization_axis,
12    routines::{
13        GlobalReduceBlueprint, ReduceBlueprint, ReduceProblem, ReduceVectorSettings, Routine,
14        cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
15    },
16};
17use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};
18
19#[derive(Clone, Copy, Debug)]
20pub struct ReduceDtypes {
21    pub input: StorageType,
22    pub output: StorageType,
23    pub accumulation: StorageType,
24}
25
26/// Launch a reduce kernel. This function assumes that all parameters are already validated.
27/// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function
28/// with the appropriate assumptions.
29#[allow(clippy::too_many_arguments)]
30pub(crate) fn launch_reduce<Run: Runtime>(
31    client: &ComputeClient<Run>,
32    input: TensorBinding<Run>,
33    output: TensorBinding<Run>,
34    reduce_axis: usize,
35    strategy: ReduceStrategy,
36    dtypes: ReduceDtypes,
37    inst: ReduceOperationConfig,
38) -> Result<(), ReduceError> {
39    let address_type = input
40        .required_address_type(dtypes.input.size())
41        .max(output.required_address_type(dtypes.output.size()));
42
43    // Number of distinct reductions = product of non-reduce input dims.
44    let reduce_len = input.shape[reduce_axis];
45    let input_elems: usize = input.shape.iter().copied().product();
46    let reduce_count = input_elems / reduce_len;
47
48    let problem = ReduceProblem {
49        reduce_len,
50        reduce_count,
51        axis: reduce_axis,
52        dtypes,
53        address_type,
54    };
55    let vectorization_mode = match input.strides[reduce_axis] {
56        1 => VectorizationMode::Parallel,
57        _ => VectorizationMode::Perpendicular,
58    };
59
60    let out_vec_axis = output_vectorization_axis(&input.strides, reduce_axis, vectorization_mode);
61
62    let (vector_size_input, vector_size_output) = generate_vector_size::<Run>(
63        client,
64        &input,
65        &output,
66        reduce_axis,
67        problem.dtypes.input,
68        vectorization_mode,
69        &strategy.vectorization,
70    );
71    let settings = ReduceVectorSettings {
72        vectorization_mode,
73        vector_size_input,
74        vector_size_output,
75    };
76
77    let (blueprint, settings) = match strategy.routine {
78        RoutineStrategy::Unit(strategy) => {
79            let routine = UnitRoutine;
80            routine.prepare(client, problem, settings, strategy)?
81        }
82        RoutineStrategy::Plane(strategy) => {
83            let routine = PlaneRoutine;
84            routine.prepare(client, problem, settings, strategy)?
85        }
86        RoutineStrategy::Cube(strategy) => {
87            let routine = CubeRoutine;
88            routine.prepare(client, problem, settings, strategy)?
89        }
90    };
91
92    unsafe {
93        reduce_kernel::launch_unchecked::<TensorArgs, Run>(
94            client,
95            settings.cube_count,
96            settings.cube_dim,
97            settings.address_type,
98            settings.vector.vector_size_input,
99            settings.vector.vector_size_output,
100            input.into_tensor_arg(),
101            output.into_tensor_arg(),
102            reduce_axis,
103            out_vec_axis,
104            blueprint,
105            inst,
106            dtypes.input,
107            dtypes.output,
108            dtypes.accumulation,
109        )
110    };
111
112    Ok(())
113}
114
115#[cube(launch_unchecked, address_type = "dynamic")]
116pub fn reduce_kernel<
117    In: Numeric,
118    InSize: Size,
119    Out: Numeric,
120    OutSize: Size,
121    Acc: Numeric,
122    RA: ReduceArgs,
123>(
124    input: &RA::Input<In, InSize>,
125    output: &mut RA::Output<Out, OutSize>,
126    reduce_axis: usize,
127    out_vec_axis: usize,
128    #[comptime] blueprint: ReduceBlueprint,
129    #[comptime] config: ReduceOperationConfig,
130    #[define(In)] _input_dtype: StorageType,
131    #[define(Out)] _output_dtype: StorageType,
132    #[define(Acc)] _acc_dtype: StorageType,
133) {
134    let (input, mut output) = init_tensors::<RA, In, InSize, Out, OutSize>(input, output);
135    reduce_kernel_virtual::<In, InSize, Out, OutSize, Acc>(
136        &input,
137        &mut output,
138        reduce_axis,
139        out_vec_axis,
140        blueprint,
141        config,
142    );
143}
144
145#[cube]
146pub fn reduce_kernel_virtual<
147    In: Numeric,
148    InSize: Size,
149    Out: Numeric,
150    OutSize: Size,
151    Acc: Numeric,
152>(
153    input: &VirtualTensor<In, InSize>,
154    output: &mut VirtualTensor<Out, OutSize, ReadWrite>,
155    reduce_axis: usize,
156    out_vec_axis: usize,
157    #[comptime] blueprint: ReduceBlueprint,
158    #[comptime] config: ReduceOperationConfig,
159) {
160    reduce_kernel_inner::<(In, InSize, Acc), (Out, OutSize), ReduceOperation>(
161        input,
162        output,
163        reduce_axis,
164        out_vec_axis,
165        blueprint,
166        config,
167    )
168}
169
170#[cube]
171fn reduce_kernel_inner<P: ReducePrecision, Out: NumericVector, R: ReduceFamily>(
172    input: &VirtualTensor<P::EI, P::SI>,
173    output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,
174    reduce_axis: usize,
175    out_vec_axis: usize,
176    #[comptime] blueprint: ReduceBlueprint,
177    #[comptime] config: R::Config,
178) {
179    let inst = &R::Instruction::<P>::from_config(config);
180
181    match blueprint.global {
182        GlobalReduceBlueprint::Cube(cube) => {
183            GlobalFullCubeReduce::execute::<P, Out, R::Instruction<P>>(
184                input,
185                output,
186                reduce_axis,
187                out_vec_axis,
188                inst,
189                blueprint.vectorization_mode,
190                cube,
191            )
192        }
193        GlobalReduceBlueprint::Plane(plane) => {
194            GlobalFullPlaneReduce::execute::<P, Out, R::Instruction<P>>(
195                input,
196                output,
197                reduce_axis,
198                out_vec_axis,
199                inst,
200                blueprint.vectorization_mode,
201                plane,
202            )
203        }
204        GlobalReduceBlueprint::Unit(unit) => {
205            GlobalFullUnitReduce::execute::<P, Out, R::Instruction<P>>(
206                input,
207                output,
208                reduce_axis,
209                out_vec_axis,
210                inst,
211                blueprint.vectorization_mode,
212                unit,
213            )
214        }
215    };
216}