1use std::any::TypeId;
2
3use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
4use cubecl_matmul::{MatmulInputHandleRef, components::AccG};
5use cubecl_runtime::TypeUsage;
6use half::f16;
7
8use crate::{
9    components::{ConvGemmConfig as _, ConvSetupError},
10    kernels::layered::selector::launch_kernel_concrete,
11};
12use crate::{
13    components::{
14        ConvolutionProblem, Dimensionality,
15        global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
16    },
17    kernels::layered::algorithm::Algorithm,
18};
19use cubecl_matmul::components::global::args::MatmulArgs;
20use cubecl_matmul::components::{
21    self, AvailableLineSizes, LhsG, MatmulElems, MatmulIdent, MatmulPrecision, MatmulSelection,
22    MatrixPrecision, RhsG,
23};
24
25type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<
26    <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global,
27    <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global,
28    <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global,
29>;
30type Output<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Output<
31    <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global,
32>;
33
/// Convolution hyper-parameters with one entry per spatial dimension.
///
/// These values are forwarded (through `launch_conv`) into the convolution problem
/// description. There, `stride` and `dilation` are cast to `u32` and `padding` to `i32`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Convolution stride for each spatial axis.
    pub stride: [usize; N_SPATIAL],
    /// Padding for each spatial axis.
    pub padding: [usize; N_SPATIAL],
    /// Dilation for each spatial axis.
    pub dilation: [usize; N_SPATIAL],
}
40
41#[allow(clippy::result_large_err)]
50pub fn launch_conv<R: Runtime, MP: MatmulPrecision, Alg: Algorithm, const N_SPATIAL: usize>(
51    client: &ComputeClient<R::Server>,
52    input: &MatmulInputHandleRef<'_, R>,
53    weight: &MatmulInputHandleRef<'_, R>,
54    bias: &Option<TensorHandleRef<'_, R>>,
55    out: &TensorHandleRef<'_, R>,
56    args: ConvolutionArgs<N_SPATIAL>,
57) -> Result<(), ConvSetupError>
58where
59    Input<Alg, MP>: ConcreteInputsFactory,
60    Output<Alg, MP>: ConcreteOutputFactory,
61{
62    let ConvolutionArgs {
63        stride,
64        padding,
65        dilation,
66    } = args;
67
68    let dimensionality = match N_SPATIAL {
69        1 => Dimensionality::Dim1,
70        2 => Dimensionality::Dim2,
71        3 => Dimensionality::Dim3,
72        other => unimplemented!("Unsupported dimensionality {other}"),
73    };
74
75    launch::<R, MP, Alg>(
76        client,
77        input,
78        weight,
79        bias,
80        out,
81        (&stride, &padding, &dilation),
82        dimensionality,
83    )
84}
85
86fn launch<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
87    client: &ComputeClient<R::Server>,
88    input: &MatmulInputHandleRef<'_, R>,
89    weight: &MatmulInputHandleRef<'_, R>,
90    bias: &Option<TensorHandleRef<'_, R>>,
91    out: &TensorHandleRef<'_, R>,
92    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
93    dimensionality: Dimensionality,
94) -> Result<(), ConvSetupError>
95where
96    Input<Alg, MP>: ConcreteInputsFactory,
97    Output<Alg, MP>: ConcreteOutputFactory,
98{
99    let rank = input.data().shape.len();
100    let dim_c = rank - 1;
101
102    let n = input.data().shape[0];
103    let c = input.data().shape[dim_c];
104
105    let out_c = weight.data().shape[0];
106
107    let in_shape = &input.data().shape[1..dim_c];
108    let kernel_shape = &weight.data().shape[1..dim_c];
109    let out_shape = &out.shape[1..dim_c];
110
111    let input_data = Alg::into_tensor_handle::<R, LhsG<MP>>(client, input.data(), MatmulIdent::Lhs);
112    let weight_data =
113        Alg::into_tensor_handle::<R, RhsG<MP>>(client, weight.data(), MatmulIdent::Rhs);
114
115    let mut input = *input;
116    let mut weight = *weight;
117
118    *input.data_mut() = input_data.as_ref();
119    *weight.data_mut() = weight_data.as_ref();
120
121    let plane_dim = client.properties().hardware.plane_size_max;
122
123    let problem = ConvolutionProblem {
124        m: n * out_shape.iter().product::<usize>(),
125        n: out_c,
126        k: c * kernel_shape.iter().product::<usize>(),
127        lhs_layout: components::MatrixLayout::RowMajor,
128        rhs_layout: components::MatrixLayout::ColMajor,
129        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
130        stride: stride.iter().map(|it| *it as u32).collect(),
131        padding: padding.iter().map(|it| *it as i32).collect(),
132        dilation: dilation.iter().map(|it| *it as u32).collect(),
133
134        batches: n,
135        shape: in_shape.to_vec(),
136        out_shape: out_shape.to_vec(),
137        channels: c,
138
139        dimensionality,
140    };
141
142    let selection = Alg::selection::<R>(client, &problem, plane_dim, MatmulElems::new::<MP>())?;
143
144    let lhs_is_f32 = TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>();
145    let rhs_is_f32 = TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>();
146
147    let launch = if lhs_is_f32 || rhs_is_f32 {
148        if tf32::supported_uses(client).contains(TypeUsage::Conversion) {
149            if lhs_is_f32 && rhs_is_f32 {
150                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, tf32, tf32, f32), Alg>
151            } else if lhs_is_f32 {
152                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, tf32, f32, f32), Alg>
153            } else {
154                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f32, tf32, f32), Alg>
155            }
156        } else if lhs_is_f32 && rhs_is_f32 {
157            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f16, f16, f32), Alg>
158        } else if lhs_is_f32 {
159            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f16, f32, f32), Alg>
160        } else {
161            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f32, f16, f32), Alg>
162        }
163    } else {
164        launch_kernel::<R, MP, Alg>
165    };
166
167    launch(client, &input, &weight, bias, out, problem, selection)
168}
169
170#[allow(clippy::result_large_err, clippy::too_many_arguments)]
171pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
172    client: &ComputeClient<R::Server>,
173    input: &MatmulInputHandleRef<'_, R>,
174    weight: &MatmulInputHandleRef<'_, R>,
175    bias: &Option<TensorHandleRef<'_, R>>,
176    out: &TensorHandleRef<'_, R>,
177    problem: ConvolutionProblem,
178    selection: MatmulSelection,
179) -> Result<(), ConvSetupError>
180where
181    Input<Alg, MP>: ConcreteInputsFactory,
182    Output<Alg, MP>: ConcreteOutputFactory,
183{
184    let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
185        input.data().elem_size,
186        weight.data().elem_size,
187        out.elem_size,
188    )
189    .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
190    .filter_rhs_with_tensor(
191        weight.data().strides,
192        weight.data().shape,
193        problem.rhs_layout,
194    )
195    .filter_out_with_tensor(out.strides, out.shape);
196
197    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
198
199    let config = Alg::setup::<R, MP>(client, &problem, &selection, &line_sizes)?;
200
201    let line_sizes = config.line_sizes();
202
203    launch_kernel_concrete::<(MP, Alg::Args), R, Alg>(
204        client, input, weight, bias, out, problem, line_sizes, selection,
205    )
206}