// cubek_convolution/kernels/backward_data/launch.rs
use crate::{
2 backward_data::args::ConcreteArgs,
3 components::{ConvolutionOperation, global::args::RuntimeArgs},
4 launch::ConvolutionArgs,
5};
6use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
7use crate::{
8 components::{ConvolutionProblem, Dimensionality},
9 routines::Routine,
10};
11use cubecl::{Runtime, client::ComputeClient, prelude::*};
12use cubek_matmul::{
13 definition::{AvailableVectorSizes, MatmulElems, MatmulSetupError},
14 routines::BlueprintStrategy,
15};
16use cubek_std::{InputBinding, MatrixLayout};
17
18#[allow(clippy::result_large_err, clippy::too_many_arguments)]
24pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
25 client: &ComputeClient<R>,
26 out_grad: InputBinding<R>,
27 weights: InputBinding<R>,
28 in_grad: TensorBinding<R>,
29 args: ConvolutionArgs<N_SPATIAL>,
30 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
31 dtypes: MatmulElems,
32) -> Result<(), ConvSetupError>
33where
34 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
35{
36 let ConvolutionArgs {
37 stride,
38 padding,
39 dilation,
40 } = args;
41
42 let dimensionality = match N_SPATIAL {
43 1 => Dimensionality::Dim1,
44 2 => Dimensionality::Dim2,
45 3 => Dimensionality::Dim3,
46 other => unimplemented!("Unsupported dimensionality {other}"),
47 };
48
49 launch_with_routine::<R, Rt>(
50 client,
51 out_grad,
52 weights,
53 in_grad,
54 (&stride, &padding, &dilation),
55 dimensionality,
56 blueprint_strategy,
57 dtypes,
58 )
59}
60
61#[allow(clippy::too_many_arguments)]
62fn launch_with_routine<R: Runtime, Rt: Routine>(
63 client: &ComputeClient<R>,
64 out_grad: InputBinding<R>,
65 weights: InputBinding<R>,
66 in_grad: TensorBinding<R>,
67 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
68 dimensionality: Dimensionality,
69 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
70 dtypes: MatmulElems,
71) -> Result<(), ConvSetupError>
72where
73 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
74{
75 let rank = in_grad.shape.len();
76 let dim_c = rank - 1;
77
78 let n = in_grad.shape[0];
79 let c = in_grad.shape[dim_c];
80
81 let out_c = out_grad.shape()[dim_c];
82
83 let in_shape = &in_grad.shape[1..dim_c];
84 let kernel_shape = &weights.shape()[1..dim_c];
85 let out_shape = &out_grad.shape()[1..dim_c];
86
87 let op = ConvolutionOperation::BackwardData;
88
89 let out_grad_tmp = out_grad.clone();
90 let weights_tmp = weights.clone();
91
92 let out_grad_data =
93 Rt::correct_layout(client, out_grad_tmp.into_data(), dtypes.lhs_global, op)?;
94 let weights_data = Rt::correct_layout(client, weights_tmp.into_data(), dtypes.rhs_global, op)?;
95
96 let mut out_grad = out_grad.clone();
97 let mut weights = weights.clone();
98
99 *out_grad.data_mut() = out_grad_data;
100 *weights.data_mut() = weights_data;
101
102 let address_type = out_grad
103 .required_address_type()
104 .max(weights.required_address_type())
105 .max(in_grad.required_address_type(dtypes.acc_global.size()));
106
107 let problem = ConvolutionProblem {
108 m: n * in_shape.iter().product::<usize>(),
109 n: c,
110 k: out_c * kernel_shape.iter().product::<usize>(),
111
112 lhs_strides: out_grad.data().strides.clone(),
113 rhs_strides: weights.data().strides.clone(),
114 lhs_layout: MatrixLayout::RowMajor,
115 rhs_layout: MatrixLayout::RowMajor,
116 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
117 stride: stride.iter().map(|it| *it as u32).collect(),
118 padding: padding.iter().map(|it| *it as i32).collect(),
119 dilation: dilation.iter().map(|it| *it as u32).collect(),
120
121 batches: n,
122 in_shape: in_shape.into(),
123 out_shape: out_shape.into(),
124 channels: c,
125 out_channels: out_c,
126
127 padded_channels: out_c,
128 operation: op,
129
130 dimensionality,
131 global_dtypes: dtypes.as_global_elems(),
132 address_type,
133 };
134
135 launch_kernel::<R, Rt>(
136 client,
137 out_grad,
138 weights,
139 in_grad,
140 problem,
141 blueprint_strategy,
142 dtypes,
143 )
144}
145
146#[allow(clippy::result_large_err, clippy::too_many_arguments)]
147pub fn launch_kernel<R: Runtime, Rt: Routine>(
148 client: &ComputeClient<R>,
149 out_grad: InputBinding<R>,
150 weights: InputBinding<R>,
151 in_grad: TensorBinding<R>,
152 problem: ConvolutionProblem,
153 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
154 dtypes: MatmulElems,
155) -> Result<(), ConvSetupError>
156where
157 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
158{
159 let vector_sizes = AvailableVectorSizes::from_type_sizes(
162 client,
163 out_grad.data_elem_size(),
164 weights.data_elem_size(),
165 dtypes.acc_global.size(),
166 )
167 .filter_lhs_with_tensor(
168 &out_grad.data().strides,
169 &out_grad.data().shape,
170 MatrixLayout::RowMajor,
171 )
172 .filter_rhs_with_tensor(
173 &weights.data().strides,
174 &weights.data().shape,
175 MatrixLayout::RowMajor,
176 )
177 .filter_out_with_tensor(&in_grad.strides, &in_grad.shape);
178
179 let vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;
180
181 launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
182 client,
183 out_grad,
184 weights,
185 in_grad,
186 problem,
187 vector_sizes,
188 blueprint_strategy,
189 &dtypes,
190 )
191}
192
193#[allow(dead_code)]
197pub(crate) fn unsupported_tma_error() -> ConvSetupError {
198 ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(Box::new(
199 "Data backprop doesn't yet work with current TMA tiling strategy",
200 )))
201}