// cubek_convolution/kernels/backward_weight/launch.rs
use crate::components::{ConvolutionProblem, Dimensionality};
use crate::kernels::backward_weight::selector::launch_kernel_concrete;
use crate::launch::ConvolutionArgs;
use crate::{backward_weight::args::ConcreteArgs, components::ConvSetupError};
use crate::{
    components::{ConvolutionOperation, global::args::RuntimeArgs},
    routines::Routine,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::{
    definition::{AvailableVectorSizes, MatmulElems},
    routines::BlueprintStrategy,
};
use cubek_std::{InputBinding, MatrixLayout};
15
16/// Backward-weight dispatch helper.
17///
18/// Called by `cubek_convolution::launch_ref` after the routine and
19/// blueprint-strategy have been resolved.
20#[allow(clippy::result_large_err, clippy::too_many_arguments)]
21pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
22    client: &ComputeClient<R>,
23    input: InputBinding<R>,
24    out_grad: InputBinding<R>,
25    weight_grad: TensorBinding<R>,
26    args: ConvolutionArgs<N_SPATIAL>,
27    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
28    dtypes: MatmulElems,
29) -> Result<(), ConvSetupError>
30where
31    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
32{
33    let ConvolutionArgs {
34        stride,
35        padding,
36        dilation,
37    } = args;
38
39    let dimensionality = match N_SPATIAL {
40        1 => Dimensionality::Dim1,
41        2 => Dimensionality::Dim2,
42        3 => Dimensionality::Dim3,
43        other => unimplemented!("Unsupported dimensionality {other}"),
44    };
45
46    launch_with_routine::<R, Rt>(
47        client,
48        input,
49        out_grad,
50        weight_grad,
51        (&stride, &padding, &dilation),
52        dimensionality,
53        blueprint_strategy,
54        dtypes,
55    )
56}
57
58#[allow(clippy::too_many_arguments)]
59fn launch_with_routine<R: Runtime, Rt: Routine>(
60    client: &ComputeClient<R>,
61    input: InputBinding<R>,
62    out_grad: InputBinding<R>,
63    weight_grad: TensorBinding<R>,
64    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
65    dimensionality: Dimensionality,
66    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
67    dtypes: MatmulElems,
68) -> Result<(), ConvSetupError>
69where
70    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
71{
72    let rank = input.data().shape.len();
73    let dim_c = rank - 1;
74
75    let n = input.shape()[0];
76    let c = input.shape()[dim_c];
77
78    let out_c = out_grad.shape()[dim_c];
79
80    let in_shape = &input.shape()[1..dim_c];
81    let kernel_shape = &weight_grad.shape[1..dim_c];
82    let out_shape = &out_grad.shape()[1..dim_c];
83
84    let op = ConvolutionOperation::BackwardWeight;
85
86    let input_data = Rt::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
87    let out_grad_data =
88        Rt::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;
89
90    let mut input = input.clone();
91    let mut out_grad = out_grad.clone();
92
93    *input.data_mut() = input_data;
94    *out_grad.data_mut() = out_grad_data;
95
96    let address_type = input
97        .required_address_type()
98        .max(out_grad.required_address_type())
99        .max(weight_grad.required_address_type(dtypes.acc_global.size()));
100
101    let problem = ConvolutionProblem {
102        m: out_c,
103        n: c * kernel_shape.iter().product::<usize>(),
104        k: n * out_shape.iter().product::<usize>(),
105        lhs_strides: input.data().strides.clone(),
106        rhs_strides: out_grad.data().strides.clone(),
107        lhs_layout: MatrixLayout::ColMajor,
108        rhs_layout: MatrixLayout::RowMajor,
109        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
110        stride: stride.iter().map(|it| *it as u32).collect(),
111        padding: padding.iter().map(|it| *it as i32).collect(),
112        dilation: dilation.iter().map(|it| *it as u32).collect(),
113
114        batches: n,
115        in_shape: in_shape.into(),
116        out_shape: out_shape.into(),
117        channels: c,
118        out_channels: out_c,
119
120        padded_channels: c,
121        operation: op,
122
123        dimensionality,
124        global_dtypes: dtypes.as_global_elems(),
125        address_type,
126    };
127
128    launch_kernel::<R, Rt>(
129        client,
130        input,
131        out_grad,
132        weight_grad,
133        problem,
134        blueprint_strategy,
135        dtypes,
136    )
137}
138
139#[allow(clippy::result_large_err, clippy::too_many_arguments)]
140pub fn launch_kernel<R: Runtime, Rt: Routine>(
141    client: &ComputeClient<R>,
142    input: InputBinding<R>,
143    out_grad: InputBinding<R>,
144    weight_grad: TensorBinding<R>,
145    problem: ConvolutionProblem,
146    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
147    dtypes: MatmulElems,
148) -> Result<(), ConvSetupError>
149where
150    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
151{
152    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
153    // So for the sake of selecting a vector size, the shape/strides are always row-major.
154    let vector_sizes = AvailableVectorSizes::from_type_sizes(
155        client,
156        input.data_elem_size(),
157        out_grad.data_elem_size(),
158        dtypes.acc_global.size(),
159    )
160    .filter_lhs_with_tensor(
161        &out_grad.data().strides,
162        &out_grad.data().shape,
163        MatrixLayout::RowMajor,
164    )
165    .filter_rhs_with_tensor(
166        &input.data().strides,
167        &input.data().shape,
168        MatrixLayout::RowMajor,
169    )
170    .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);
171
172    let vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;
173
174    launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
175        client,
176        input,
177        out_grad,
178        weight_grad,
179        problem,
180        vector_sizes,
181        blueprint_strategy,
182        &dtypes,
183    )
184}