// cubek_convolution/kernels/backward_weight/launch.rs
use crate::components::{ConvolutionProblem, Dimensionality};
use crate::kernels::backward_weight::selector::launch_kernel_concrete;
use crate::launch::ConvolutionArgs;
use crate::{backward_weight::args::ConcreteArgs, components::ConvSetupError};
use crate::{
    components::{ConvolutionOperation, global::args::RuntimeArgs},
    routines::Routine,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::{
    definition::{AvailableVectorSizes, MatmulElems},
    routines::BlueprintStrategy,
};
use cubek_std::{InputBinding, MatrixLayout};

16#[allow(clippy::result_large_err, clippy::too_many_arguments)]
21pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
22 client: &ComputeClient<R>,
23 input: InputBinding<R>,
24 out_grad: InputBinding<R>,
25 weight_grad: TensorBinding<R>,
26 args: ConvolutionArgs<N_SPATIAL>,
27 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
28 dtypes: MatmulElems,
29) -> Result<(), ConvSetupError>
30where
31 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
32{
33 let ConvolutionArgs {
34 stride,
35 padding,
36 dilation,
37 } = args;
38
39 let dimensionality = match N_SPATIAL {
40 1 => Dimensionality::Dim1,
41 2 => Dimensionality::Dim2,
42 3 => Dimensionality::Dim3,
43 other => unimplemented!("Unsupported dimensionality {other}"),
44 };
45
46 launch_with_routine::<R, Rt>(
47 client,
48 input,
49 out_grad,
50 weight_grad,
51 (&stride, &padding, &dilation),
52 dimensionality,
53 blueprint_strategy,
54 dtypes,
55 )
56}
57
58#[allow(clippy::too_many_arguments)]
59fn launch_with_routine<R: Runtime, Rt: Routine>(
60 client: &ComputeClient<R>,
61 input: InputBinding<R>,
62 out_grad: InputBinding<R>,
63 weight_grad: TensorBinding<R>,
64 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
65 dimensionality: Dimensionality,
66 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
67 dtypes: MatmulElems,
68) -> Result<(), ConvSetupError>
69where
70 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
71{
72 let rank = input.data().shape.len();
73 let dim_c = rank - 1;
74
75 let n = input.shape()[0];
76 let c = input.shape()[dim_c];
77
78 let out_c = out_grad.shape()[dim_c];
79
80 let in_shape = &input.shape()[1..dim_c];
81 let kernel_shape = &weight_grad.shape[1..dim_c];
82 let out_shape = &out_grad.shape()[1..dim_c];
83
84 let op = ConvolutionOperation::BackwardWeight;
85
86 let input_data = Rt::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
87 let out_grad_data =
88 Rt::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;
89
90 let mut input = input.clone();
91 let mut out_grad = out_grad.clone();
92
93 *input.data_mut() = input_data;
94 *out_grad.data_mut() = out_grad_data;
95
96 let address_type = input
97 .required_address_type()
98 .max(out_grad.required_address_type())
99 .max(weight_grad.required_address_type(dtypes.acc_global.size()));
100
101 let problem = ConvolutionProblem {
102 m: out_c,
103 n: c * kernel_shape.iter().product::<usize>(),
104 k: n * out_shape.iter().product::<usize>(),
105 lhs_strides: input.data().strides.clone(),
106 rhs_strides: out_grad.data().strides.clone(),
107 lhs_layout: MatrixLayout::ColMajor,
108 rhs_layout: MatrixLayout::RowMajor,
109 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
110 stride: stride.iter().map(|it| *it as u32).collect(),
111 padding: padding.iter().map(|it| *it as i32).collect(),
112 dilation: dilation.iter().map(|it| *it as u32).collect(),
113
114 batches: n,
115 in_shape: in_shape.into(),
116 out_shape: out_shape.into(),
117 channels: c,
118 out_channels: out_c,
119
120 padded_channels: c,
121 operation: op,
122
123 dimensionality,
124 global_dtypes: dtypes.as_global_elems(),
125 address_type,
126 };
127
128 launch_kernel::<R, Rt>(
129 client,
130 input,
131 out_grad,
132 weight_grad,
133 problem,
134 blueprint_strategy,
135 dtypes,
136 )
137}
138
139#[allow(clippy::result_large_err, clippy::too_many_arguments)]
140pub fn launch_kernel<R: Runtime, Rt: Routine>(
141 client: &ComputeClient<R>,
142 input: InputBinding<R>,
143 out_grad: InputBinding<R>,
144 weight_grad: TensorBinding<R>,
145 problem: ConvolutionProblem,
146 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
147 dtypes: MatmulElems,
148) -> Result<(), ConvSetupError>
149where
150 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
151{
152 let vector_sizes = AvailableVectorSizes::from_type_sizes(
155 client,
156 input.data_elem_size(),
157 out_grad.data_elem_size(),
158 dtypes.acc_global.size(),
159 )
160 .filter_lhs_with_tensor(
161 &out_grad.data().strides,
162 &out_grad.data().shape,
163 MatrixLayout::RowMajor,
164 )
165 .filter_rhs_with_tensor(
166 &input.data().strides,
167 &input.data().shape,
168 MatrixLayout::RowMajor,
169 )
170 .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);
171
172 let vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;
173
174 launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
175 client,
176 input,
177 out_grad,
178 weight_grad,
179 problem,
180 vector_sizes,
181 blueprint_strategy,
182 &dtypes,
183 )
184}