cubecl_linalg/convolution/algorithm/
simple_tma.rs1use std::marker::PhantomData;
2
3use cubecl_core::{
4 CubeCount, CubeDim, Runtime,
5 client::ComputeClient,
6 prelude::{Numeric, TensorHandleRef},
7};
8
9use crate::{
10 convolution::{
11 base::{ConvolutionConfigFactory, ConvolutionProblem},
12 homogeneous::simple_tma::SimpleTmaConvolutionFamily,
13 },
14 matmul::components::{
15 InputIdent, InvalidConfigError, MatmulSelection,
16 global::args::TensorMapArgs,
17 stage::{FullReaderFamily, plane_matmul::PlaneMatmulFamily},
18 tile::TileMatmulFamily,
19 },
20 tensor::{TensorHandle, into_contiguous_pitched},
21};
22
23use super::Algorithm;
24
/// Alignment, in bytes, that every outer tensor stride must satisfy for a handle to be
/// passed to TMA as-is; handles that fail the check in `has_valid_layout` are copied into
/// a pitched contiguous layout first.
pub const TMA_STRIDE_ALIGN: usize = 16;

27pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
29 _tmm: PhantomData<TMM>,
30}
31
32impl<TMM: TileMatmulFamily> Algorithm for SimpleTmaConvAlgorithm<TMM> {
33 type TileMatmul = TMM;
34 type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily>;
35 type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
36
37 type Args = TensorMapArgs;
38
39 fn cube_dim(selection: &MatmulSelection) -> CubeDim {
40 CubeDim::new(selection.plane_dim, selection.tile_count.m, 1)
41 }
42
43 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
44 let m_stage = selection.tile_count.m * selection.tile_shape.m;
45 let n_stage = selection.tile_count.n * selection.tile_shape.n;
46 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
47 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
48
49 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
50 }
51
52 fn make_config(
53 input: <Self::GlobalConvolution as ConvolutionConfigFactory>::Input,
54 problem: &ConvolutionProblem,
55 cube_dim: &CubeDim,
56 cube_count: &CubeCount,
57 ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, InvalidConfigError>
58 {
59 check_problem_tma(problem)?;
60
61 let config = Self::GlobalConvolution::make_config(input, problem, cube_dim, cube_count);
62 Self::GlobalConvolution::check_config(&config)?;
63 Ok(config)
64 }
65
66 fn check_availability<R: Runtime, MP: crate::matmul::components::MatmulPrecision>(
67 client: &ComputeClient<R::Server, R::Channel>,
68 config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
69 ) -> Result<(), crate::matmul::kernels::MatmulAvailabilityError> {
70 <Self::GlobalConvolution as ConvolutionConfigFactory>::check_availability::<R, MP>(
71 client, config,
72 )?;
73
74 if !client
75 .properties()
76 .feature_enabled(cubecl_core::Feature::Tma(cubecl_core::TmaFeature::Base))
77 {
78 return Err(crate::matmul::kernels::MatmulAvailabilityError::TmaUnavailable);
79 }
80
81 Ok(())
82 }
83
84 fn into_tensor_handle<R: Runtime, E: Numeric>(
85 client: &ComputeClient<R::Server, R::Channel>,
86 handle: &TensorHandleRef<'_, R>,
87 ident: InputIdent,
88 ) -> TensorHandle<R, E> {
89 let mut handle = if has_valid_layout(handle, ident) {
90 TensorHandle::from_ref(handle)
91 } else {
92 into_contiguous_pitched(client, handle)
93 };
94 match ident {
95 InputIdent::Lhs => handle,
96 InputIdent::Rhs => {
97 handle.shape = vec![
98 handle.shape[0],
99 handle.shape[1] * handle.shape[2],
100 handle.shape[3],
101 ];
102 handle.strides = vec![handle.strides[0], handle.strides[2], handle.strides[3]];
103 handle
104 }
105 }
106 }
107}
108
109fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
110 let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
111
112 let aligned = handle.strides[..3]
113 .iter()
114 .all(|stride| stride % stride_align == 0);
115
116 let valid_layout = match ident {
117 InputIdent::Lhs => handle.strides[3] == 1,
118 InputIdent::Rhs => {
119 let c_major = handle.strides[3] == 1;
120 let kernel_contig = handle.strides[2] * handle.shape[2] == handle.strides[1];
121 c_major && kernel_contig
122 }
123 };
124
125 valid_layout && aligned
126}
127
128fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
129 fn check_range(
130 value: isize,
131 name: &str,
132 min: isize,
133 max: isize,
134 ) -> Result<(), InvalidConfigError> {
135 if value < min || value > max {
136 Err(Box::new(format!(
137 "value {name} outside of valid range ({min}, {max})"
138 )))
139 } else {
140 Ok(())
141 }
142 }
143
144 let corner = calculate_upper_corner(problem.padding, problem.kernel_size, problem.dilation);
145 check_range(corner[0] as isize, "corner_h", -128, 127)?;
146 check_range(corner[1] as isize, "corner_w", -128, 127)?;
147
148 let offset_h = (problem.kernel_size.0 - 1) * problem.dilation.0;
149 let offset_w = (problem.kernel_size.1 - 1) * problem.dilation.1;
150 check_range(offset_h as isize, "kernel size h", 0, 255)?;
151 check_range(offset_w as isize, "kernel size w", 0, 255)?;
152
153 check_range(problem.stride.0 as isize, "stride_h", 1, 8)?;
154 check_range(problem.stride.1 as isize, "stride_w", 1, 8)?;
155
156 Ok(())
157}
158
/// Computes the upper corner `padding - (kernel_size - 1) * dilation` for the h and w axes.
///
/// Returns `[corner_h, corner_w]`. The subtraction `kernel_size - 1` is done on `u32`,
/// so a zero kernel size underflows (panicking in debug builds). Callers are expected
/// to validate the kernel size first.
pub fn calculate_upper_corner(
    padding: (i32, i32),
    kernel_size: (u32, u32),
    dilation: (u32, u32),
) -> Vec<i32> {
    let axis_corner =
        |pad: i32, size: u32, dil: u32| -> i32 { pad - (size - 1) as i32 * dil as i32 };

    vec![
        axis_corner(padding.0, kernel_size.0, dilation.0),
        axis_corner(padding.1, kernel_size.1, dilation.1),
    ]
}