cubecl_linalg/convolution/args.rs

use std::any::TypeId;

use cubecl::prelude::*;
use cubecl_core as cubecl;

use crate::{
    convolution::algorithm::simple_tma::calculate_upper_corner,
    matmul::components::{
        MatmulSelection,
        global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
    },
};

use super::base::ConvolutionProblem;

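/// Constructs the launch arguments for the convolution input tensors.
///
/// Implementations build `Self::RuntimeArg` from the `lhs` (feature map) and `rhs`
/// (weight) tensor handles, using the matmul `selection` and the convolution `problem`
/// to decide how each operand is described at launch.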
pub trait ConvInputsLaunch: LaunchArg {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<'a, R>;
}

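/// Plain tensor inputs: both operands are passed as regular tensor arguments,
/// vectorized with the line sizes taken from the problem.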
impl<EI: Numeric> ConvInputsLaunch for TensorInputs<EI> {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<'a, R> {
        TensorInputsLaunch::new(
            lhs.as_tensor_arg(problem.lhs_line_size),
            rhs.as_tensor_arg(problem.rhs_line_size),
        )
    }
}

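/// TMA (tensor map) inputs: the `lhs` is described with an im2col tensor map so the
/// hardware can gather convolution windows directly, while the `rhs` weights use a
/// tiled tensor map.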
impl<EI: Numeric> ConvInputsLaunch for TensorMapInputs<EI> {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<'a, R> {
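        // Stage dimensions in elements, derived from the selected tile counts and shapes.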
        let stage_m = selection.tile_count.m * selection.tile_shape.m;
        let stage_n = selection.tile_count.n * selection.tile_shape.n;
        let stage_size_rhs = vec![stage_n, 1, selection.tile_shape.k];

        let elem_size = size_of::<EI>();

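        // Pick the TMA prefetch granularity from the number of contiguous bytes
        // loaded per row of the tensor map.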
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

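        // Both operands read `tile_shape.k` contiguous elements per row here.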
        let prefetch_lhs = prefetch(selection.tile_shape.k as usize * elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);

        // f32 is remapped to tf32 for the tensor map to ensure CUDA loads the values
        // correctly. It shouldn't matter, but it's better to be safe.
        let elem = if TypeId::of::<EI>() == TypeId::of::<f32>() {
            tf32::as_elem_native_unchecked()
        } else {
            EI::as_elem_native_unchecked()
        };

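        // Describe the lhs as an im2col tensor map: the pixel box spans the kernel
        // window (shifted by padding and dilation), reading `tile_shape.k` channels per
        // pixel and `stage_m` pixels per column. The element strides carry the
        // convolution strides along the spatial dimensions.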
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: vec![-problem.padding.0, -problem.padding.1],
                pixel_box_upper_corner: calculate_upper_corner(
                    problem.padding,
                    problem.kernel_size,
                    problem.dilation,
                ),
                channels_per_pixel: selection.tile_shape.k,
                pixels_per_column: stage_m,
            },
            lhs.as_tensor_arg(problem.lhs_line_size),
            elem,
        )
        .with_elem_stride(vec![
            1,
            problem.stride.0 as usize,
            problem.stride.1 as usize,
            1,
        ])
        .with_prefetch(prefetch_lhs);

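        // The rhs weights use a plain tiled tensor map with a single stage-sized tile
        // of shape `[stage_n, 1, tile_shape.k]`, loaded without vectorization.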
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.as_tensor_arg(1),
            EI::as_elem_native_unchecked(),
        )
        .with_prefetch(prefetch_rhs);

        TensorMapInputsLaunch::new(lhs, rhs)
    }
}