cubecl_convolution/kernels/layered/algorithm/simple_tma.rs

use cubecl_core::{
    CubeCount, Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef,
};
use cubecl_matmul::components::stage::NumStages;
use cubecl_matmul::components::{
    InvalidConfigError, MatmulIdent, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
    tile::TileMatmulFamily,
};
use cubecl_matmul::components::{
    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
};
use cubecl_std::{
    CubeOption,
    tensor::{TensorHandle, into_contiguous_pitched},
};
use std::marker::PhantomData;

use crate::components::{
    ConvolutionProblem, Dimensionality, convolution_matmul_selection,
    global::single_stage::tma::SimpleTmaConvolutionFamily,
};

use super::Algorithm;

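/// Byte alignment required by TMA for tensor strides; `has_valid_layout` checks strides against it.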
pub const TMA_STRIDE_ALIGN: usize = 16;

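/// Convolution algorithm based on a simple single-stage global matmul whose loads go through TMA.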
pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
    _tmm: PhantomData<TMM>,
}

impl<
    TMM: TileMatmulFamily<
            LhsTile = Strided,
            RhsTile = Strided,
            AccTile = CubeOption<Strided>,
            OutTile = Strided,
        >,
> Algorithm for SimpleTmaConvAlgorithm<TMM>
{
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<
        Self::TileMatmul,
        StridedStageFamily,
        StridedStageFamily,
        Option<StridedStageFamily>,
    >;
    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;

    type Args = TensorMapArgs;

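    /// One cube is launched per `(m_stage, n_stage)` tile of the `m x n` output, rounding up.
    /// For example (illustrative numbers): with `m = 1024`, `n = 192` and a `128x64` stage,
    /// this launches `1024.div_ceil(128) * 192.div_ceil(64) = 8 * 3` cubes.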
    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tiling_scheme.elements_in_stage_m();
        let n_stage = selection.tiling_scheme.elements_in_stage_n();
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }

    fn into_tensor_handle<R: Runtime>(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'_, R>,
        ident: MatmulIdent,
        dtype: StorageType,
    ) -> TensorHandle<R> {
        into_tensor_handle_tma(client, handle, ident, dtype)
    }

    fn num_stages() -> NumStages {
        (1, 1).into()
    }

    fn selection<R: Runtime>(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        dtypes: &mut MatmulElems,
    ) -> Result<MatmulSelection, MatmulSetupError> {
        Ok(convolution_matmul_selection::<TMM, R>(
            client, problem, plane_dim, dtypes,
        )?)
    }
}

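/// Prepares a tensor handle for use with TMA: if the layout is not TMA-compatible, the tensor is
/// first copied into a contiguous, pitched layout. For the weights (`Rhs`), the kernel spatial
/// dimensions are then collapsed into a single dimension.
///
/// Worked example (illustrative shapes): a contiguous weight handle of shape `[128, 3, 3, 64]`
/// with strides `[576, 192, 64, 1]` becomes shape `[128, 9, 64]` with strides `[576, 64, 1]`.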
pub(crate) fn into_tensor_handle_tma<R: Runtime>(
    client: &ComputeClient<R::Server>,
    handle: &TensorHandleRef<'_, R>,
    ident: MatmulIdent,
    dtype: StorageType,
) -> TensorHandle<R> {
    let rank = handle.shape.len();
    let dim_c = rank - 1;
    // Reuse the handle if its layout is already TMA-compatible, otherwise copy it into a
    // contiguous, pitched layout.
    let mut handle = if has_valid_layout(handle, ident) {
        TensorHandle::from_ref(handle, dtype)
    } else {
        into_contiguous_pitched(client, handle, dtype)
    };
    match ident {
        MatmulIdent::Lhs => handle,
        MatmulIdent::Rhs => {
            // Collapse the kernel spatial dimensions into a single dimension.
            let k_size = handle.shape[1..dim_c].iter().product();
            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
            handle.strides = vec![
                handle.strides[0],
                handle.strides[dim_c - 1],
                handle.strides[dim_c],
            ];
            handle
        }
        MatmulIdent::Out => unreachable!(),
    }
}

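/// Checks whether a handle can be mapped directly to a TMA tensor map: every stride except the
/// innermost must be a multiple of `TMA_STRIDE_ALIGN` bytes and the innermost (channel) stride
/// must be 1. For the weights (`Rhs`), the kernel spatial dimensions must additionally be
/// contiguous with each other.
///
/// For example, an `f32` tensor (4-byte elements) needs its outer strides to be multiples of
/// 16 / 4 = 4 elements.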
pub(crate) fn has_valid_layout<R: Runtime>(
    handle: &TensorHandleRef<'_, R>,
    ident: MatmulIdent,
) -> bool {
    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
    let rank = handle.shape.len();
    let dim_c = rank - 1;

    let aligned = handle.strides[..dim_c]
        .iter()
        .all(|stride| stride % stride_align == 0);

    let valid_layout = match ident {
        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
        MatmulIdent::Rhs => {
            let c_major = handle.strides[dim_c] == 1;
            let mut kernel_contig = true;
            for i in 1..dim_c - 1 {
                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
            }
            c_major && kernel_contig
        }
        MatmulIdent::Out => unreachable!(),
    };

    valid_layout && aligned
}

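/// Validates that the convolution parameters fit the value ranges TMA supports for the given
/// dimensionality: the upper corner offsets, the dilated kernel extent
/// `(kernel_size - 1) * dilation`, and the strides (which must be between 1 and 8).
///
/// For a 2D problem, for example, each corner offset must lie in `[-128, 127]` and each dilated
/// kernel extent in `[0, 255]`.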
pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
    fn check_range(
        value: isize,
        name: impl FnOnce() -> String,
        min: isize,
        max: isize,
    ) -> Result<(), InvalidConfigError> {
        if value < min || value > max {
            let name = name();
            Err(Box::new(format!(
                "value {name} outside of valid range [{min}, {max}]"
            )))
        } else {
            Ok(())
        }
    }

    let (corner_min, corner_max) = match problem.dimensionality {
        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
    };

    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
    for (i, offs) in corner.iter().enumerate() {
        check_range(
            *offs as isize,
            || format!("corner[{i}]"),
            corner_min,
            corner_max,
        )?;
    }

    let offset_max = match problem.dimensionality {
        Dimensionality::Dim1 => 2isize.pow(16) - 1,
        Dimensionality::Dim2 => 2isize.pow(8) - 1,
        Dimensionality::Dim3 => 2isize.pow(5) - 1,
    };

    for i in 0..problem.kernel_size.len() {
        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
        check_range(
            offset as isize,
            || format!("kernel_size[{i}]"),
            0,
            offset_max,
        )?;
        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
    }

    Ok(())
}

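/// Lower corner offset of the convolution window per spatial dimension, i.e. the negated padding.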
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

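/// Upper corner offset of the convolution window per spatial dimension:
/// `padding - (kernel_size - 1) * dilation`.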
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}
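
// A minimal usage sketch of the corner helpers above, with illustrative values: for padding 1,
// a 3x3 kernel, and dilation 1, the upper corner is `1 - (3 - 1) * 1 = -1` per dimension, which
// here coincides with the negated padding.
#[cfg(test)]
mod corner_example {
    use super::{calculate_lower_corner, calculate_upper_corner};

    #[test]
    fn corners_for_padded_3x3_kernel() {
        assert_eq!(calculate_lower_corner(&[1, 1]), vec![-1, -1]);
        assert_eq!(
            calculate_upper_corner(&[1, 1], &[3, 3], &[1, 1]),
            vec![-1, -1]
        );
    }
}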