use std::any::TypeId;

use cubecl_core::{Runtime, client::ComputeClient, prelude::*, tensor_line_size_parallel};
use half::f16;

use crate::matmul::{
    components::global::args::{ConcreteOutputFactory, MatmulArgs},
    kernels::MatmulLaunchError,
};
use crate::{
    convolution::base::ConvolutionLaunch,
    matmul::components::{self, InputIdent, MatmulPrecision, MatmulSelection},
};

use super::{
    ConvLaunchError,
    algorithm::{Algorithm, StageInput},
    args::ConvInputsLaunch,
    base::ConvolutionProblem,
    selection::select_matmul,
};

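// Shorthands for the algorithm's concrete input/output argument types at a given precision.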
type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<<MP as MatmulPrecision>::EI>;
type Output<Alg, MP> =
    <<Alg as Algorithm>::Args as MatmulArgs>::Output<<MP as MatmulPrecision>::EO>;

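/// Stride, padding and dilation settings of a 2D convolution.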
#[derive(Clone)]
pub struct ConvolutionArgs {
    pub stride: (usize, usize),
    pub padding: (usize, usize),
    pub dilation: (usize, usize),
}

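/// Launch a 2D convolution on an NHWC input by mapping it onto a matmul problem (implicit GEMM)
/// and dispatching it through the selected [`Algorithm`].
///
/// * `input` - input feature map, shape `[n, h, w, c]`
/// * `weight` - convolution filters, shape `[out_c, kh, kw, _]` (output channels leading)
/// * `bias` - optional per-channel bias
/// * `out` - output feature map, shape `[n, out_h, out_w, out_c]`
/// * `args` - stride, padding and dilation settings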
#[allow(clippy::result_large_err)]
pub fn launch_conv2d_nhwc<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let ConvolutionArgs {
        stride,
        padding,
        dilation,
    } = args;

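    // Input is NHWC; the weight handle has output channels first: [out_c, kh, kw, _].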
    let [n, h, w, c] = input.shape.try_into().unwrap();
    let [out_c, kh, kw, _] = weight.shape.try_into().unwrap();
    let out_h = out.shape[1];
    let out_w = out.shape[2];

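    // Hand the operands to the algorithm, which returns owned handles prepared for the
    // matmul Lhs (input) and Rhs (weight) sides.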
    let input = Alg::into_tensor_handle::<R, MP::EI>(client, input, InputIdent::Lhs);
    let weight = Alg::into_tensor_handle::<R, MP::EI>(client, weight, InputIdent::Rhs);

    let ei_elem = MP::EI::as_elem_native_unchecked();
    let eo_elem = MP::EO::as_elem_native_unchecked();

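    // Pick the vectorization width (line size) of each tensor along its innermost (channel) axis.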
    let lhs_line_size =
        tensor_line_size_parallel(R::line_size_elem(&ei_elem), &input.shape, &input.strides, 3);
    let rhs_line_size = tensor_line_size_parallel(
        R::line_size_elem(&ei_elem),
        &weight.shape,
        &weight.strides,
        3,
    );

    let out_line_size =
        tensor_line_size_parallel(R::line_size_elem(&eo_elem), out.shape, out.strides, 3);

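    // Plane size (warp/subgroup width) reported by the hardware, defaulting to 32.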
    let plane_dim = client
        .properties()
        .hardware_properties()
        .defined_plane_size()
        .unwrap_or(32);

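    // Implicit GEMM view of the convolution: `m` spans every output position (n * out_h * out_w),
    // `n` is the number of output channels, and `k` unrolls the kernel window over the input
    // channels (c * kh * kw).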
    let problem = ConvolutionProblem {
        m: n * out_h * out_w,
        n: out_c,
        k: c * kh * kw,
        lhs_layout: components::MatrixLayout::RowMajor,
        rhs_layout: components::MatrixLayout::ColMajor,
        lhs_line_size,
        rhs_line_size,
        out_line_size,
        kernel_size: (kh as u32, kw as u32),
        stride: (stride.0 as u32, stride.1 as u32),
        padding: (padding.0 as i32, padding.1 as i32),
        dilation: (dilation.0 as u32, dilation.1 as u32),

        batches: n,
        height: h,
        width: w,
        channels: c,

        out_h,
        out_w,
    };

    let (selection, config_input) = select_matmul::<Alg, R, MP>(client, &problem, plane_dim);

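    // f32 inputs are computed with a reduced intermediate precision (tf32 when the hardware
    // supports it, f16 otherwise) and f32 accumulation; other precisions launch as requested.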
    let launch = if TypeId::of::<MP::EI>() == TypeId::of::<f32>() {
        if tf32::is_supported(client) {
            launch_kernel::<R, (MP::EI, tf32, f32, MP::EO), Alg>
        } else {
            launch_kernel::<R, (MP::EI, f16, f32, MP::EO), Alg>
        }
    } else {
        launch_kernel::<R, MP, Alg>
    };

    launch(
        client,
        &input.as_ref(),
        &weight.as_ref(),
        bias,
        out,
        problem,
        selection,
        config_input,
    )
}

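/// Launch the convolution kernel once the matmul selection is fixed: flattens the output into a
/// 2D matrix view, builds and validates the config, then launches the global convolution.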
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
    config_input: StageInput,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
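    // View the NHWC output as an `[m, n]` matrix: all output positions collapse into rows,
    // output channels become the columns.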
    let out_shape = [out.shape[0..3].iter().product(), out.shape[3]];
    let out_strides = [out.strides[2], out.strides[3]];

    let out = unsafe {
        TensorHandleRef::from_raw_parts(out.handle, &out_strides, &out_shape, out.elem_size)
    };

    let cube_dim = Alg::cube_dim(&selection);
    let cube_count = Alg::cube_count(&selection, &problem);

    let config = Alg::make_config(config_input, &problem, &cube_dim, &cube_count)
        .map_err(MatmulLaunchError::InvalidConfig)?;

    Alg::check_availability::<R, MP>(client, &config)?;

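    // Assemble the concrete launch arguments for the inputs, the output and the optional bias.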
    let input = <Input<Alg, MP> as ConvInputsLaunch>::create(input, weight, &selection, &problem);
    let output = <Output<Alg, MP> as ConcreteOutputFactory>::create(
        &out,
        &selection,
        &problem.as_matmul_problem(),
    );
    let bias = bias
        .as_ref()
        .map(|bias| bias.as_tensor_arg(problem.out_line_size));

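    // Launch the global convolution kernel with the prepared arguments.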
    unsafe {
        Alg::GlobalConvolution::launch_unchecked::<(MP, Alg::Args), R>(
            client, cube_dim, cube_count, input, bias, output, &problem, config,
        );
    }

    Ok(())
}