
cubek_convolution/kernels/backward_data/launch.rs

use crate::{
    backward_data::args::ConcreteArgs,
    components::{ConvolutionOperation, global::args::RuntimeArgs},
    launch::ConvolutionArgs,
};
use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    routines::Routine,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::{
    definition::{AvailableVectorSizes, MatmulElems, MatmulSetupError},
    routines::BlueprintStrategy,
};
use cubek_std::{InputBinding, MatrixLayout};

/// Backward-data dispatch helper.
///
/// Called by `cubek_convolution::launch_ref` once the routine and blueprint
/// strategy have been resolved. Backward-data does not currently support the
/// TMA reading strategy; `launch_ref` rejects such requests with a setup
/// error (see [`unsupported_tma_error`]).
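///
/// # Example (sketch)
///
/// A minimal illustration of the call shape for a 2D problem. The `client`,
/// bindings, routine, and `blueprint_strategy` values are stand-ins for what
/// `launch_ref` resolves, and the array-valued `ConvolutionArgs` fields are
/// assumed, not taken from this file.
///
/// ```ignore
/// launch_internal::<R, 2, Rt>(
///     &client,
///     out_grad,
///     weights,
///     in_grad,
///     ConvolutionArgs { stride: [1, 1], padding: [0, 0], dilation: [1, 1] },
///     &blueprint_strategy,
///     dtypes,
/// )?;
/// ```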
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
{
    let ConvolutionArgs {
        stride,
        padding,
        dilation,
    } = args;

    let dimensionality = match N_SPATIAL {
        1 => Dimensionality::Dim1,
        2 => Dimensionality::Dim2,
        3 => Dimensionality::Dim3,
        other => unimplemented!("Unsupported dimensionality {other}"),
    };

    launch_with_routine::<R, Rt>(
        client,
        out_grad,
        weights,
        in_grad,
        (&stride, &padding, &dilation),
        dimensionality,
        blueprint_strategy,
        dtypes,
    )
}

#[allow(clippy::too_many_arguments)]
fn launch_with_routine<R: Runtime, Rt: Routine>(
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
{
    let rank = in_grad.shape.len();
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardData;

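    // Let the routine correct the operand layouts if needed; the (possibly
    // re-laid-out) data then replaces the original bindings' data.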
    let out_grad_data =
        Rt::correct_layout(client, out_grad.clone().into_data(), dtypes.lhs_global, op)?;
    let weights_data =
        Rt::correct_layout(client, weights.clone().into_data(), dtypes.rhs_global, op)?;

    // The shape slices above still borrow the original bindings, so shadow
    // with clones before swapping in the corrected data.
    let mut out_grad = out_grad.clone();
    let mut weights = weights.clone();

    *out_grad.data_mut() = out_grad_data;
    *weights.data_mut() = weights_data;

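    // Use the widest address type required by any of the three tensors; the
    // input gradient is measured in `acc_global`-sized elements.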
    let address_type = out_grad
        .required_address_type()
        .max(weights.required_address_type())
        .max(in_grad.required_address_type(dtypes.acc_global.size()));

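    // Implicit-GEMM view of backward data: rows (m) span batches times the
    // input-gradient spatial positions, columns (n) are the input channels,
    // and the reduction (k) runs over output channels times the kernel window.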
    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        lhs_strides: out_grad.data().strides.clone(),
        rhs_strides: weights.data().strides.clone(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.into(),
        out_shape: out_shape.into(),
        channels: c,
        out_channels: out_c,

        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
        address_type,
    };

    launch_kernel::<R, Rt>(
        client,
        out_grad,
        weights,
        in_grad,
        problem,
        blueprint_strategy,
        dtypes,
    )
}

#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Rt: Routine>(
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
{
    // Shapes/strides are treated as k-major, with the last dimension always
    // the contiguous one, so for vector-size selection every tensor can be
    // treated as row-major.
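    // E.g. a contiguous channels-last tensor (stride 1 on its final axis)
    // should keep wide vector-size candidates, while unfavourable strides
    // filter the candidates toward scalar access.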
    let vector_sizes = AvailableVectorSizes::from_type_sizes(
        client,
        out_grad.data_elem_size(),
        weights.data_elem_size(),
        dtypes.acc_global.size(),
    )
    .filter_lhs_with_tensor(
        &out_grad.data().strides,
        &out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        &weights.data().strides,
        &weights.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(&in_grad.strides, &in_grad.shape);

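    // The routine gets a final say over the surviving candidates before the
    // widest remaining vector size is picked.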
    let vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;

    launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
        client,
        out_grad,
        weights,
        in_grad,
        problem,
        vector_sizes,
        blueprint_strategy,
        &dtypes,
    )
}

/// Returned by the unified `launch_ref` when the requested routine is not
/// supported for backward-data. Currently only the TMA reading strategy is
/// rejected.
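///
/// ```ignore
/// // Sketch of the call site in `launch_ref`; the `ReadingStrategy` name is
/// // assumed, not taken from this file.
/// if matches!(reading_strategy, ReadingStrategy::Tma) {
///     return Err(unsupported_tma_error());
/// }
/// ```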
#[allow(dead_code)]
pub(crate) fn unsupported_tma_error() -> ConvSetupError {
    ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(Box::new(
        "Data backprop doesn't yet work with the current TMA tiling strategy",
    )))
}