cubecl_convolution/
args.rs

1use std::any::TypeId;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use crate::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner};
7use cubecl_matmul::components::{
8    MatmulLineSizes, MatmulSelection,
9    global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
10};
11
12use super::base::ConvolutionProblem;
13
14pub trait ConvInputsLaunch: LaunchArg {
15    fn create<'a, R: Runtime>(
16        lhs: &'a TensorHandleRef<'a, R>,
17        rhs: &'a TensorHandleRef<'a, R>,
18        selection: &MatmulSelection,
19        problem: &ConvolutionProblem,
20        line_sizes: &MatmulLineSizes,
21    ) -> Self::RuntimeArg<'a, R>;
22}
23
24impl<EI: Numeric> ConvInputsLaunch for TensorInputs<EI> {
25    fn create<'a, R: Runtime>(
26        lhs: &'a TensorHandleRef<'a, R>,
27        rhs: &'a TensorHandleRef<'a, R>,
28        _selection: &MatmulSelection,
29        _problem: &ConvolutionProblem,
30        line_sizes: &MatmulLineSizes,
31    ) -> Self::RuntimeArg<'a, R> {
32        TensorInputsLaunch::new(
33            lhs.as_tensor_arg(line_sizes.lhs),
34            None.into(),
35            rhs.as_tensor_arg(line_sizes.rhs),
36            None.into(),
37        )
38    }
39}
40
41impl<EI: Numeric> ConvInputsLaunch for TensorMapInputs<EI> {
42    fn create<'a, R: Runtime>(
43        lhs: &'a TensorHandleRef<'a, R>,
44        rhs: &'a TensorHandleRef<'a, R>,
45        selection: &MatmulSelection,
46        problem: &ConvolutionProblem,
47        line_sizes: &MatmulLineSizes,
48    ) -> Self::RuntimeArg<'a, R> {
49        let tiling_scheme = selection.tiling_scheme;
50        let stage_m = tiling_scheme.elements_in_stage_m();
51        let stage_n = tiling_scheme.elements_in_stage_n();
52        let tile_size_k = tiling_scheme.elements_in_tile_k();
53        let stage_size_rhs = vec![stage_n, 1, tile_size_k];
54
55        let elem_size = size_of::<EI>();
56
57        fn prefetch(bytes: usize) -> TensorMapPrefetch {
58            match bytes {
59                ..64 => TensorMapPrefetch::None,
60                64..128 => TensorMapPrefetch::B64,
61                128..256 => TensorMapPrefetch::B128,
62                256.. => TensorMapPrefetch::B256,
63            }
64        }
65
66        let prefetch_lhs = prefetch(tile_size_k as usize * elem_size);
67        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);
68
69        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
70        // It shouldn't matter, but it's better to be safe.
71        let elem = if TypeId::of::<EI>() == TypeId::of::<f32>() {
72            tf32::as_elem_native_unchecked()
73        } else {
74            EI::as_elem_native_unchecked()
75        };
76
77        let mut elem_stride = vec![1; 2 + problem.stride.len()];
78
79        for (i, stride) in problem.stride.iter().enumerate() {
80            elem_stride[i + 1] = *stride as usize;
81        }
82
83        let lhs = TensorMapArg::new(
84            TensorMapFormat::Im2col {
85                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
86                pixel_box_upper_corner: calculate_upper_corner(
87                    &problem.padding,
88                    &problem.kernel_size,
89                    &problem.dilation,
90                ),
91                channels_per_pixel: tile_size_k,
92                pixels_per_column: stage_m,
93            },
94            lhs.as_tensor_arg(line_sizes.lhs),
95            elem,
96        )
97        .with_elem_stride(elem_stride)
98        .with_prefetch(prefetch_lhs);
99
100        let rhs = TensorMapArg::new(
101            TensorMapFormat::Tiled {
102                tile_size: stage_size_rhs,
103            },
104            rhs.as_tensor_arg(1),
105            EI::as_elem_native_unchecked(),
106        )
107        .with_prefetch(prefetch_rhs);
108
109        TensorMapInputsLaunch::new(lhs, rhs)
110    }
111}