use crate::components::{ConvolutionProblem, Dimensionality};
use crate::routines::Routine;
use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
use crate::{
    components::ConvolutionOperation, components::global::args::RuntimeArgs,
    forward::args::ConcreteArgs, launch::ConvolutionArgs,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::definition::{AvailableVectorSizes, MatmulElems};
use cubek_matmul::routines::BlueprintStrategy;
use cubek_std::{InputBinding, MatrixLayout};

13#[allow(clippy::result_large_err, clippy::too_many_arguments)]
18pub(crate) fn launch_internal<R: Runtime, const N_SPATIAL: usize, Rt: Routine>(
19 client: &ComputeClient<R>,
20 input: InputBinding<R>,
21 weight: InputBinding<R>,
22 bias: Option<InputBinding<R>>,
23 out: TensorBinding<R>,
24 args: ConvolutionArgs<N_SPATIAL>,
25 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
26 dtypes: MatmulElems,
27) -> Result<(), ConvSetupError>
28where
29 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
30{
31 let ConvolutionArgs {
32 stride,
33 padding,
34 dilation,
35 } = args;
36
37 let dimensionality = match N_SPATIAL {
38 1 => Dimensionality::Dim1,
39 2 => Dimensionality::Dim2,
40 3 => Dimensionality::Dim3,
41 other => unimplemented!("Unsupported dimensionality {other}"),
42 };
43
44 launch_with_routine::<R, Rt>(
45 client,
46 input,
47 weight,
48 bias,
49 out,
50 (&stride, &padding, &dilation),
51 dimensionality,
52 blueprint_strategy,
53 dtypes,
54 )
55}
56
57#[allow(clippy::too_many_arguments)]
58fn launch_with_routine<R: Runtime, Rt: Routine>(
59 client: &ComputeClient<R>,
60 input: InputBinding<R>,
61 weight: InputBinding<R>,
62 bias: Option<InputBinding<R>>,
63 out: TensorBinding<R>,
64 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
65 dimensionality: Dimensionality,
66 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
67 dtypes: MatmulElems,
68) -> Result<(), ConvSetupError>
69where
70 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
71{
72 let rank = input.data().shape.len();
73 let dim_c = rank - 1;
74
75 let n = input.data().shape[0];
76 let c = input.data().shape[dim_c];
77
78 let out_c = weight.data().shape[0];
79
80 let in_shape = &input.data().shape[1..dim_c];
81 let kernel_shape = &weight.data().shape[1..dim_c];
82 let out_shape = &out.shape[1..dim_c];
83
84 let op = ConvolutionOperation::Forward;
85
86 let input_data = Rt::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
87 let weight_data =
88 Rt::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;
89
90 let mut input = input.clone();
91 let mut weight = weight.clone();
92
93 *input.data_mut() = input_data;
94 *weight.data_mut() = weight_data;
95
96 let address_type = input
97 .required_address_type()
98 .max(weight.required_address_type())
99 .max(
100 bias.clone()
101 .map(|bias| bias.required_address_type())
102 .unwrap_or_default(),
103 )
104 .max(out.required_address_type(dtypes.acc_global.size()));
105
106 let problem = ConvolutionProblem {
107 m: n * out_shape.iter().product::<usize>(),
108 n: out_c,
109 k: c * kernel_shape.iter().product::<usize>(),
110 lhs_strides: input.data().strides.clone(),
111 rhs_strides: weight.data().strides.clone(),
112 lhs_layout: MatrixLayout::RowMajor,
113 rhs_layout: MatrixLayout::ColMajor,
114 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
115 stride: stride.iter().map(|it| *it as u32).collect(),
116 padding: padding.iter().map(|it| *it as i32).collect(),
117 dilation: dilation.iter().map(|it| *it as u32).collect(),
118
119 batches: n,
120 in_shape: in_shape.into(),
121 out_shape: out_shape.into(),
122 channels: c,
123 out_channels: out_c,
124
125 padded_channels: c,
126 operation: op,
127
128 dimensionality,
129 global_dtypes: dtypes.as_global_elems(),
130 address_type,
131 };
132
133 launch_kernel::<R, Rt>(
134 client,
135 input,
136 weight,
137 bias,
138 out,
139 problem,
140 blueprint_strategy,
141 dtypes,
142 )
143}
144
145#[allow(clippy::result_large_err, clippy::too_many_arguments)]
146pub fn launch_kernel<R: Runtime, Rt: Routine>(
147 client: &ComputeClient<R>,
148 input: InputBinding<R>,
149 weight: InputBinding<R>,
150 bias: Option<InputBinding<R>>,
151 out: TensorBinding<R>,
152 problem: ConvolutionProblem,
153 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>,
154 dtypes: MatmulElems,
155) -> Result<(), ConvSetupError>
156where
157 Rt::Args: ConcreteArgs<Rt::MatmulRoutine>,
158{
159 let vector_sizes = AvailableVectorSizes::from_type_sizes(
162 client,
163 input.data_elem_size(),
164 weight.data_elem_size(),
165 dtypes.acc_global.size(),
166 )
167 .filter_lhs_with_tensor(
168 &input.data().strides,
169 &input.data().shape,
170 MatrixLayout::RowMajor,
171 )
172 .filter_rhs_with_tensor(
173 &weight.data().strides,
174 &weight.data().shape,
175 MatrixLayout::RowMajor,
176 )
177 .filter_out_with_tensor(&out.strides, &out.shape);
178
179 let mut vector_sizes = Rt::filter_vector_sizes(vector_sizes).pick_max()?;
180
181 if input.scale().is_some() {
184 vector_sizes.lhs = 1;
185 }
186 if weight.scale().is_some() {
187 vector_sizes.rhs = 1;
188 }
189
190 launch_kernel_concrete::<R, Rt::Args, Rt::MatmulRoutine>(
191 client,
192 input,
193 weight,
194 bias,
195 out,
196 problem,
197 vector_sizes,
198 blueprint_strategy,
199 &dtypes,
200 )
201}