cubecl_linalg/convolution/
args.rs1use std::any::TypeId;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use crate::{
7 convolution::algorithm::simple_tma::calculate_upper_corner,
8 matmul::components::{
9 MatmulSelection,
10 global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
11 },
12};
13
14use super::base::ConvolutionProblem;
15
16pub trait ConvInputsLaunch: LaunchArg {
17 fn create<'a, R: Runtime>(
18 lhs: &'a TensorHandleRef<'a, R>,
19 rhs: &'a TensorHandleRef<'a, R>,
20 selection: &MatmulSelection,
21 problem: &ConvolutionProblem,
22 ) -> Self::RuntimeArg<'a, R>;
23}
24
25impl<EI: Numeric> ConvInputsLaunch for TensorInputs<EI> {
26 fn create<'a, R: Runtime>(
27 lhs: &'a TensorHandleRef<'a, R>,
28 rhs: &'a TensorHandleRef<'a, R>,
29 _selection: &MatmulSelection,
30 problem: &ConvolutionProblem,
31 ) -> Self::RuntimeArg<'a, R> {
32 TensorInputsLaunch::new(
33 lhs.as_tensor_arg(problem.lhs_line_size),
34 rhs.as_tensor_arg(problem.rhs_line_size),
35 )
36 }
37}
38
39impl<EI: Numeric> ConvInputsLaunch for TensorMapInputs<EI> {
40 fn create<'a, R: Runtime>(
41 lhs: &'a TensorHandleRef<'a, R>,
42 rhs: &'a TensorHandleRef<'a, R>,
43 selection: &MatmulSelection,
44 problem: &ConvolutionProblem,
45 ) -> Self::RuntimeArg<'a, R> {
46 let stage_m = selection.tile_count.m * selection.tile_shape.m;
47 let stage_n = selection.tile_count.n * selection.tile_shape.n;
48 let stage_size_rhs = vec![stage_n, 1, selection.tile_shape.k];
49
50 let elem_size = size_of::<EI>();
51
52 fn prefetch(bytes: usize) -> TensorMapPrefetch {
53 match bytes {
54 ..64 => TensorMapPrefetch::None,
55 64..128 => TensorMapPrefetch::B64,
56 128..256 => TensorMapPrefetch::B128,
57 256.. => TensorMapPrefetch::B256,
58 }
59 }
60
61 let prefetch_lhs = prefetch(selection.tile_shape.k as usize * elem_size);
62 let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);
63
64 let elem = if TypeId::of::<EI>() == TypeId::of::<f32>() {
67 tf32::as_elem_native_unchecked()
68 } else {
69 EI::as_elem_native_unchecked()
70 };
71
72 let lhs = TensorMapArg::new(
73 TensorMapFormat::Im2col {
74 pixel_box_lower_corner: vec![-problem.padding.0, -problem.padding.1],
75 pixel_box_upper_corner: calculate_upper_corner(
76 problem.padding,
77 problem.kernel_size,
78 problem.dilation,
79 ),
80 channels_per_pixel: selection.tile_shape.k,
81 pixels_per_column: stage_m,
82 },
83 lhs.as_tensor_arg(problem.lhs_line_size),
84 elem,
85 )
86 .with_elem_stride(vec![
87 1,
88 problem.stride.0 as usize,
89 problem.stride.1 as usize,
90 1,
91 ])
92 .with_prefetch(prefetch_lhs);
93
94 let rhs = TensorMapArg::new(
95 TensorMapFormat::Tiled {
96 tile_size: stage_size_rhs,
97 },
98 rhs.as_tensor_arg(1),
99 EI::as_elem_native_unchecked(),
100 )
101 .with_prefetch(prefetch_rhs);
102
103 TensorMapInputsLaunch::new(lhs, rhs)
104 }
105}