Skip to main content

cubek_convolution/kernels/forward/
launch.rs

1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::routines::Routine;
3use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
4use crate::{
5    components::ConvolutionOperation, components::global::args::RuntimeArgs,
6    forward::args::ConcreteArgs, launch::ConvolutionArgs,
7};
8use cubecl::{Runtime, client::ComputeClient, prelude::*};
9use cubek_matmul::definition::{AvailableVectorSizes, MatmulElems};
10use cubek_matmul::routines::BlueprintStrategy;
11use cubek_std::{InputBinding, MatrixLayout};
12
13/// Forward-convolution dispatch helper.
14///
15/// Called by `cubek_convolution::launch_ref` after the routine and
16/// blueprint-strategy have been resolved. Not meant for direct external use.
17#[allow(clippy::result_large_err, clippy::too_many_arguments)]
18pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
19    client: &ComputeClient<R>,
20    input: InputBinding<R>,
21    weight: InputBinding<R>,
22    bias: Option<InputBinding<R>>,
23    out: TensorBinding<R>,
24    args: ConvolutionArgs<N_SPATIAL>,
25    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
26    dtypes: MatmulElems,
27) -> Result<(), ConvSetupError>
28where
29    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
30{
31    let ConvolutionArgs {
32        stride,
33        padding,
34        dilation,
35    } = args;
36
37    let dimensionality = match N_SPATIAL {
38        1 => Dimensionality::Dim1,
39        2 => Dimensionality::Dim2,
40        3 => Dimensionality::Dim3,
41        other => unimplemented!("Unsupported dimensionality {other}"),
42    };
43
44    launch_with_routine::<R, Rt>(
45        client,
46        input,
47        weight,
48        bias,
49        out,
50        (&stride, &padding, &dilation),
51        dimensionality,
52        blueprint_strategy,
53        dtypes,
54    )
55}
56
57#[allow(clippy::too_many_arguments)]
58fn launch_with_routine<R: Runtime, Rt: Routine>(
59    client: &ComputeClient<R>,
60    input: InputBinding<R>,
61    weight: InputBinding<R>,
62    bias: Option<InputBinding<R>>,
63    out: TensorBinding<R>,
64    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
65    dimensionality: Dimensionality,
66    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
67    dtypes: MatmulElems,
68) -> Result<(), ConvSetupError>
69where
70    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
71{
72    let rank = input.data().shape.len();
73    let dim_c = rank - 1;
74
75    let n = input.data().shape[0];
76    let c = input.data().shape[dim_c];
77
78    let out_c = weight.data().shape[0];
79
80    let in_shape = &input.data().shape[1..dim_c];
81    let kernel_shape = &weight.data().shape[1..dim_c];
82    let out_shape = &out.shape[1..dim_c];
83
84    let op = ConvolutionOperation::Forward;
85
86    let input_data = Rt::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
87    let weight_data =
88        Rt::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;
89
90    let mut input = input.clone();
91    let mut weight = weight.clone();
92
93    *input.data_mut() = input_data;
94    *weight.data_mut() = weight_data;
95
96    let address_type = input
97        .required_address_type()
98        .max(weight.required_address_type())
99        .max(
100            bias.clone()
101                .map(|bias| bias.required_address_type())
102                .unwrap_or_default(),
103        )
104        .max(out.required_address_type(dtypes.acc_global.size()));
105
106    let problem = ConvolutionProblem {
107        m: n * out_shape.iter().product::<usize>(),
108        n: out_c,
109        k: c * kernel_shape.iter().product::<usize>(),
110        lhs_strides: input.data().strides.clone(),
111        rhs_strides: weight.data().strides.clone(),
112        lhs_layout: MatrixLayout::RowMajor,
113        rhs_layout: MatrixLayout::ColMajor,
114        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
115        stride: stride.iter().map(|it| *it as u32).collect(),
116        padding: padding.iter().map(|it| *it as i32).collect(),
117        dilation: dilation.iter().map(|it| *it as u32).collect(),
118
119        batches: n,
120        in_shape: in_shape.into(),
121        out_shape: out_shape.into(),
122        channels: c,
123        out_channels: out_c,
124
125        padded_channels: c,
126        operation: op,
127
128        dimensionality,
129        global_dtypes: dtypes.as_global_elems(),
130        address_type,
131    };
132
133    launch_kernel::<R, Rt>(
134        client,
135        input,
136        weight,
137        bias,
138        out,
139        problem,
140        blueprint_strategy,
141        dtypes,
142    )
143}
144
145#[allow(clippy::result_large_err, clippy::too_many_arguments)]
146pub fn launch_kernel<R: Runtime, Rt: Routine>(
147    client: &ComputeClient<R>,
148    input: InputBinding<R>,
149    weight: InputBinding<R>,
150    bias: Option<InputBinding<R>>,
151    out: TensorBinding<R>,
152    problem: ConvolutionProblem,
153    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
154    dtypes: MatmulElems,
155) -> Result<(), ConvSetupError>
156where
157    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
158{
159    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
160    // So for the sake of selecting a vector size, the shape/strides are always row-major.
161    let vector_sizes = AvailableVectorSizes::from_type_sizes(
162        client,
163        input.data_elem_size(),
164        weight.data_elem_size(),
165        dtypes.acc_global.size(),
166    )
167    .filter_lhs_with_tensor(
168        &input.data().strides,
169        &input.data().shape,
170        MatrixLayout::RowMajor,
171    )
172    .filter_rhs_with_tensor(
173        &weight.data().strides,
174        &weight.data().shape,
175        MatrixLayout::RowMajor,
176    )
177    .filter_out_with_tensor(&out.strides, &out.shape);
178
179    let mut vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;
180
181    // The large vector size resulting from dequantizing ends up slower due to restrictions on
182    // algorithms. Use this as a quick and dirty fix.
183    if input.scale().is_some() {
184        vector_sizes.lhs = 1;
185    }
186    if weight.scale().is_some() {
187        vector_sizes.rhs = 1;
188    }
189
190    launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
191        client,
192        input,
193        weight,
194        bias,
195        out,
196        problem,
197        vector_sizes,
198        blueprint_strategy,
199        &dtypes,
200    )
201}