cubecl_convolution/kernels/layered/algorithm/
simple_tma.rs1use cubecl_core::server::LaunchError;
2use cubecl_core::{
3 CubeCount, Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef,
4};
5use cubecl_matmul::components::{
6 InvalidConfigError, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
7 tile::TileMatmulFamily,
8};
9use cubecl_matmul::components::{
10 MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
11};
12use cubecl_matmul::components::{MatmulLineSizes, stage::NumStages};
13use cubecl_std::{
14 CubeOption,
15 tensor::{TensorHandle, into_contiguous_pitched},
16};
17use std::marker::PhantomData;
18
19use crate::components::{
20 ConvolutionProblem, Dimensionality, convolution_matmul_selection,
21 global::single_stage::tma::SimpleTmaConvolutionFamily,
22};
23
24use super::Algorithm;
25
/// Stride alignment used by the TMA layout check in `has_valid_layout`.
///
/// The check combines this value with the handle's `elem_size`, so it is
/// presumably a byte count — TODO confirm against the TMA descriptor rules.
pub const TMA_STRIDE_ALIGN: usize = 16;
27
28pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
30 _tmm: PhantomData<TMM>,
31}
32
33impl<
34 TMM: TileMatmulFamily<
35 LhsTile = Strided,
36 RhsTile = Strided,
37 AccTile = CubeOption<Strided>,
38 OutTile = Strided,
39 >,
40> Algorithm for SimpleTmaConvAlgorithm<TMM>
41{
42 type TileMatmul = TMM;
43 type StageMatmul = PlaneMatmulFamily<
44 Self::TileMatmul,
45 StridedStageFamily,
46 StridedStageFamily,
47 Option<StridedStageFamily>,
48 >;
49 type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
50
51 type Args = TensorMapArgs;
52
53 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
54 let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
55 let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
56 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
57 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
58
59 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
60 }
61
62 fn into_tensor_handle<R: Runtime>(
63 client: &ComputeClient<R>,
64 handle: &TensorHandleRef<'_, R>,
65 dtype: StorageType,
66 ) -> Result<TensorHandle<R>, LaunchError> {
67 into_tensor_handle_tma(client, handle, dtype)
68 }
69
70 fn num_stages() -> NumStages {
72 (1, 1).into()
73 }
74
75 fn selection<R: Runtime>(
76 client: &ComputeClient<R>,
77 problem: &ConvolutionProblem,
78 plane_dim: u32,
79 line_sizes: &MatmulLineSizes,
80 dtypes: &mut MatmulElems,
81 ) -> Result<MatmulSelection, MatmulSetupError> {
82 Ok(convolution_matmul_selection::<TMM, R>(
83 client, problem, plane_dim, false, line_sizes, dtypes,
84 )?)
85 }
86}
87
88pub(crate) fn into_tensor_handle_tma<R: Runtime>(
89 client: &ComputeClient<R>,
90 handle: &TensorHandleRef<'_, R>,
91 dtype: StorageType,
92) -> Result<TensorHandle<R>, LaunchError> {
93 let handle = if has_valid_layout(handle) {
94 TensorHandle::from_ref(handle, dtype)
95 } else {
96 into_contiguous_pitched(client, handle, dtype)?
97 };
98 Ok(handle)
99}
100
101pub(crate) fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
102 let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
103 let rank = handle.shape.len();
104 let dim_c = rank - 1;
105
106 let aligned = handle.strides[..dim_c]
107 .iter()
108 .all(|stride| stride % stride_align == 0);
109
110 let valid_layout = handle.strides[dim_c] == 1;
111
112 valid_layout && aligned
113}
114
115pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
116 fn check_range(
117 value: isize,
118 name: impl FnOnce() -> String,
119 min: isize,
120 max: isize,
121 ) -> Result<(), InvalidConfigError> {
122 if value < min || value > max {
123 let name = name();
124 Err(Box::new(format!(
125 "value {name} outside of valid range ({min}, {max})"
126 )))
127 } else {
128 Ok(())
129 }
130 }
131
132 let (corner_min, corner_max) = match problem.dimensionality {
133 Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
134 Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
135 Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
136 };
137
138 let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
139 for (i, offs) in corner.iter().enumerate() {
140 check_range(
141 *offs as isize,
142 || format!("corner[{i}]"),
143 corner_min,
144 corner_max,
145 )?;
146 }
147
148 let offset_max = match problem.dimensionality {
149 Dimensionality::Dim1 => 2isize.pow(16) - 1,
150 Dimensionality::Dim2 => 2isize.pow(8) - 1,
151 Dimensionality::Dim3 => 2isize.pow(5) - 1,
152 };
153
154 for i in 0..problem.kernel_size.len() {
155 let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
156 check_range(
157 offset as isize,
158 || format!("kernel size {i}"),
159 0,
160 offset_max,
161 )?;
162 check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
163 }
164
165 Ok(())
166}
167
/// Returns the lower corner for each dimension, which is the negated padding.
///
/// The output has the same length and order as `padding`.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
171
/// Returns the upper corner for each dimension:
/// `padding - (kernel_size - 1) * dilation`.
///
/// The three slices are walked in lockstep; the output is as long as the
/// shortest of them.
///
/// The arithmetic is carried out in `i64` with saturating operations and the
/// result is clamped to the `i32` range. This keeps the function panic-free for
/// a zero kernel size (which underflowed in `u32`) and stops huge
/// kernel/dilation values from wrapping into an innocuous-looking corner.
/// Results for ordinary inputs are unchanged.
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((&padding, &kernel_size), &dilation)| {
            let extent = (i64::from(kernel_size) - 1).saturating_mul(i64::from(dilation));
            let corner = i64::from(padding).saturating_sub(extent);
            // The clamp guarantees the cast below is lossless.
            corner.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        })
        .collect()
}