// cubecl_convolution/kernels/layered/algorithm/simple_tma.rs

use cubecl_core::server::LaunchError;
use cubecl_core::{
    CubeCount, Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef,
};
use cubecl_matmul::components::stage::NumStages;
use cubecl_matmul::components::{
    InvalidConfigError, MatmulIdent, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
    tile::TileMatmulFamily,
};
use cubecl_matmul::components::{
    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
};
use cubecl_std::{
    CubeOption,
    tensor::{TensorHandle, into_contiguous_pitched},
};
use std::marker::PhantomData;

use crate::components::{
    ConvolutionProblem, Dimensionality, convolution_matmul_selection,
    global::single_stage::tma::SimpleTmaConvolutionFamily,
};

use super::Algorithm;

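/// Stride alignment (in bytes) required for tensors loaded through TMA; global
/// strides must be multiples of 16 bytes.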
pub const TMA_STRIDE_ALIGN: usize = 16;

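/// Simple convolution algorithm that loads its inputs through the Tensor Memory
/// Accelerator (TMA), using a single-stage global convolution and a plane-level
/// stage matmul.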
pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
    _tmm: PhantomData<TMM>,
}

impl<
    TMM: TileMatmulFamily<
        LhsTile = Strided,
        RhsTile = Strided,
        AccTile = CubeOption<Strided>,
        OutTile = Strided,
    >,
> Algorithm for SimpleTmaConvAlgorithm<TMM>
{
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<
        Self::TileMatmul,
        StridedStageFamily,
        StridedStageFamily,
        Option<StridedStageFamily>,
    >;
    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;

    type Args = TensorMapArgs;

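    /// One cube per stage-sized tile along the `m` and `n` dimensions of the
    /// implicit GEMM.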
    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
        let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }

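    /// Prepares an input handle for TMA loading; see [`into_tensor_handle_tma`].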
    fn into_tensor_handle<R: Runtime>(
        client: &ComputeClient<R>,
        handle: &TensorHandleRef<'_, R>,
        ident: MatmulIdent,
        dtype: StorageType,
    ) -> Result<TensorHandle<R>, LaunchError> {
        into_tensor_handle_tma(client, handle, ident, dtype)
    }

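    /// Single-stage pipelining on both inputs (no multi-stage buffering).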
    fn num_stages() -> NumStages {
        (1, 1).into()
    }

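    /// Delegates to the shared convolution matmul selection heuristic.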
    fn selection<R: Runtime>(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        dtypes: &mut MatmulElems,
    ) -> Result<MatmulSelection, MatmulSetupError> {
        Ok(convolution_matmul_selection::<TMM, R>(
            client, problem, plane_dim, dtypes,
        )?)
    }
}

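/// Returns a handle with a layout usable by TMA, copying the tensor into a
/// contiguous pitched buffer when its current layout is not valid
/// (see [`has_valid_layout`]). For `Rhs`, the kernel spatial dimensions are
/// collapsed into a single dimension so the weights can be described by a
/// rank-3 tensor map.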
pub(crate) fn into_tensor_handle_tma<R: Runtime>(
    client: &ComputeClient<R>,
    handle: &TensorHandleRef<'_, R>,
    ident: MatmulIdent,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let rank = handle.shape.len();
    let dim_c = rank - 1;
    let mut handle = if has_valid_layout(handle, ident) {
        TensorHandle::from_ref(handle, dtype)
    } else {
        into_contiguous_pitched(client, handle, dtype)?
    };
    match ident {
        MatmulIdent::Lhs => Ok(handle),
        MatmulIdent::Rhs => {
            let k_size = handle.shape[1..dim_c].iter().product();
            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
            handle.strides = vec![
                handle.strides[0],
                handle.strides[dim_c - 1],
                handle.strides[dim_c],
            ];
            Ok(handle)
        }
        MatmulIdent::Out => unreachable!(),
    }
}

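/// Whether the handle can be passed to TMA as-is: every stride except the
/// innermost must be aligned to [`TMA_STRIDE_ALIGN`] bytes, the innermost
/// dimension must be contiguous, and for `Rhs` the kernel spatial dimensions
/// must additionally be contiguous with each other so they can be collapsed.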
pub(crate) fn has_valid_layout<R: Runtime>(
    handle: &TensorHandleRef<'_, R>,
    ident: MatmulIdent,
) -> bool {
    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
    let rank = handle.shape.len();
    let dim_c = rank - 1;

    let aligned = handle.strides[..dim_c]
        .iter()
        .all(|stride| stride % stride_align == 0);

    let valid_layout = match ident {
        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
        MatmulIdent::Rhs => {
            let c_major = handle.strides[dim_c] == 1;
            let mut kernel_contig = true;
            for i in 1..dim_c - 1 {
                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
            }
            c_major && kernel_contig
        }
        MatmulIdent::Out => unreachable!(),
    };

    valid_layout && aligned
}

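/// Checks that the convolution problem fits within the value ranges a TMA
/// im2col descriptor can encode for the given dimensionality: the box corner
/// offsets, the dilated kernel offsets, and the strides all have limited bit
/// widths.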
pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
    fn check_range(
        value: isize,
        name: impl FnOnce() -> String,
        min: isize,
        max: isize,
    ) -> Result<(), InvalidConfigError> {
        if value < min || value > max {
            let name = name();
            Err(Box::new(format!(
                "value {name} outside of valid range ({min}, {max})"
            )))
        } else {
            Ok(())
        }
    }

    let (corner_min, corner_max) = match problem.dimensionality {
        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
    };

    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
    for (i, offs) in corner.iter().enumerate() {
        check_range(
            *offs as isize,
            || format!("corner[{i}]"),
            corner_min,
            corner_max,
        )?;
    }

    let offset_max = match problem.dimensionality {
        Dimensionality::Dim1 => 2isize.pow(16) - 1,
        Dimensionality::Dim2 => 2isize.pow(8) - 1,
        Dimensionality::Dim3 => 2isize.pow(5) - 1,
    };

    for i in 0..problem.kernel_size.len() {
        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
        check_range(
            offset as isize,
            || format!("kernel size {i}"),
            0,
            offset_max,
        )?;
        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
    }

    Ok(())
}

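/// Lower corner offsets of the im2col box: the negated padding along each
/// spatial dimension.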
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

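/// Upper corner offsets of the im2col box: the padding minus the dilated
/// kernel extent, `(kernel_size - 1) * dilation`, along each spatial dimension.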
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}