cubecl_convolution/algorithm/
simple_tma.rs1use std::marker::PhantomData;
2
3use cubecl_core::ir::Elem;
4use cubecl_core::{
5 CubeCount, Runtime,
6 client::ComputeClient,
7 prelude::{Numeric, TensorHandleRef},
8};
9use cubecl_matmul::components::MatmulSelection;
10
11use crate::{
12 base::{ConvolutionProblem, Dimensionality},
13 homogeneous::simple_tma::SimpleTmaConvolutionFamily,
14 selection::convolution_matmul_selection,
15};
16use cubecl_matmul::components::stage::NumStages;
17use cubecl_matmul::components::{
18 InputIdent, InvalidConfigError,
19 global::args::TensorMapArgs,
20 stage::{FullReaderFamily, PlaneMatmulFamily},
21 tile::TileMatmulFamily,
22};
23
24use cubecl_std::tensor::{TensorHandle, into_contiguous_pitched};
25
26use super::Algorithm;
27
/// Alignment, in bytes, required of every stride except the innermost one for a tensor to be
/// consumed directly by TMA (see `has_valid_layout`).
pub const TMA_STRIDE_ALIGN: usize = 16;
29
30pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
32 _tmm: PhantomData<TMM>,
33}
34
35impl<TMM: TileMatmulFamily> Algorithm for SimpleTmaConvAlgorithm<TMM> {
36 type TileMatmul = TMM;
37 type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, FullReaderFamily>;
38 type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
39
40 type Args = TensorMapArgs;
41
42 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
43 let m_stage = selection.tiling_scheme.elements_in_stage_m();
44 let n_stage = selection.tiling_scheme.elements_in_stage_n();
45 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
46 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
47
48 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
49 }
50
51 fn into_tensor_handle<R: Runtime, E: Numeric>(
52 client: &ComputeClient<R::Server, R::Channel>,
53 handle: &TensorHandleRef<'_, R>,
54 ident: InputIdent,
55 ) -> TensorHandle<R, E> {
56 into_tensor_handle_tma(client, handle, ident)
57 }
58
59 fn num_stages() -> NumStages {
61 (1, 1).into()
62 }
63
64 fn selection<R: Runtime>(
65 client: &ComputeClient<R::Server, R::Channel>,
66 problem: &ConvolutionProblem,
67 plane_dim: u32,
68 elem_stage: Elem,
69 elem_acc: Elem,
70 ) -> MatmulSelection {
71 convolution_matmul_selection::<TMM, R>(client, problem, plane_dim, elem_stage, elem_acc)
72 }
73}
74
75pub(crate) fn into_tensor_handle_tma<R: Runtime, E: Numeric>(
76 client: &ComputeClient<R::Server, R::Channel>,
77 handle: &TensorHandleRef<'_, R>,
78 ident: InputIdent,
79) -> TensorHandle<R, E> {
80 let rank = handle.shape.len();
81 let dim_c = rank - 1;
82 let mut handle = if has_valid_layout(handle, ident) {
83 TensorHandle::from_ref(handle)
84 } else {
85 into_contiguous_pitched(client, handle)
86 };
87 match ident {
88 InputIdent::Lhs => handle,
89 InputIdent::Rhs => {
90 let k_size = handle.shape[1..dim_c].iter().product();
91 handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
92 handle.strides = vec![
93 handle.strides[0],
94 handle.strides[dim_c - 1],
95 handle.strides[dim_c],
96 ];
97 handle
98 }
99 }
100}
101
102pub(crate) fn has_valid_layout<R: Runtime>(
103 handle: &TensorHandleRef<'_, R>,
104 ident: InputIdent,
105) -> bool {
106 let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
107 let rank = handle.shape.len();
108 let dim_c = rank - 1;
109
110 let aligned = handle.strides[..dim_c]
111 .iter()
112 .all(|stride| stride % stride_align == 0);
113
114 let valid_layout = match ident {
115 InputIdent::Lhs => handle.strides[dim_c] == 1,
116 InputIdent::Rhs => {
117 let c_major = handle.strides[dim_c] == 1;
118 let mut kernel_contig = true;
119 for i in 1..dim_c - 1 {
120 kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
121 }
122 c_major && kernel_contig
123 }
124 };
125
126 valid_layout && aligned
127}
128
129pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
130 fn check_range(
131 value: isize,
132 name: impl FnOnce() -> String,
133 min: isize,
134 max: isize,
135 ) -> Result<(), InvalidConfigError> {
136 if value < min || value > max {
137 let name = name();
138 Err(Box::new(format!(
139 "value {name} outside of valid range ({min}, {max})"
140 )))
141 } else {
142 Ok(())
143 }
144 }
145
146 let (corner_min, corner_max) = match problem.dimensionality {
147 Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
148 Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
149 Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
150 };
151
152 let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
153 for (i, offs) in corner.iter().enumerate() {
154 check_range(
155 *offs as isize,
156 || format!("corner[{i}]"),
157 corner_min,
158 corner_max,
159 )?;
160 }
161
162 let offset_max = match problem.dimensionality {
163 Dimensionality::Dim1 => 2isize.pow(16) - 1,
164 Dimensionality::Dim2 => 2isize.pow(8) - 1,
165 Dimensionality::Dim3 => 2isize.pow(5) - 1,
166 };
167
168 for i in 0..problem.kernel_size.len() {
169 let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
170 check_range(
171 offset as isize,
172 || format!("kernel size {i}"),
173 0,
174 offset_max,
175 )?;
176 check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
177 }
178
179 Ok(())
180}
181
/// Lower box corner for each spatial dimension: the padding, negated element-wise.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
185
/// Upper box corner for each spatial dimension: `padding - (kernel_size - 1) * dilation`.
///
/// Dimensions are paired position-wise; entries beyond the shortest of the three slices are
/// ignored. A zero `kernel_size` underflows the `u32` subtraction (panics in debug builds).
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for (&pad, (&size, &dil)) in padding.iter().zip(kernel_size.iter().zip(dilation)) {
        let extent = (size - 1) as i32 * dil as i32;
        corner.push(pad - extent);
    }
    corner
}