// cubecl_convolution/kernels/layered/algorithm/simple_tma.rs

1use cubecl_core::server::LaunchError;
2use cubecl_core::{
3    CubeCount, Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef,
4};
5use cubecl_matmul::components::stage::NumStages;
6use cubecl_matmul::components::{
7    InvalidConfigError, MatmulIdent, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
8    tile::TileMatmulFamily,
9};
10use cubecl_matmul::components::{
11    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
12};
13use cubecl_std::{
14    CubeOption,
15    tensor::{TensorHandle, into_contiguous_pitched},
16};
17use std::marker::PhantomData;
18
19use crate::components::{
20    ConvolutionProblem, Dimensionality, convolution_matmul_selection,
21    global::single_stage::tma::SimpleTmaConvolutionFamily,
22};
23
24use super::Algorithm;
25
/// Alignment that tensor strides are validated against in `has_valid_layout`.
///
/// It is combined with `TensorHandleRef::elem_size` there, so it is expressed in the same
/// unit as `elem_size` (presumably bytes — confirm against the TMA descriptor requirements).
pub const TMA_STRIDE_ALIGN: usize = 16;
27
28/// Cmma convolution
29pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
30    _tmm: PhantomData<TMM>,
31}
32
33impl<
34    TMM: TileMatmulFamily<
35            LhsTile = Strided,
36            RhsTile = Strided,
37            AccTile = CubeOption<Strided>,
38            OutTile = Strided,
39        >,
40> Algorithm for SimpleTmaConvAlgorithm<TMM>
41{
42    type TileMatmul = TMM;
43    type StageMatmul = PlaneMatmulFamily<
44        Self::TileMatmul,
45        StridedStageFamily,
46        StridedStageFamily,
47        Option<StridedStageFamily>,
48    >;
49    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
50
51    type Args = TensorMapArgs;
52
53    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
54        let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
55        let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
56        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
57        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
58
59        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
60    }
61
62    fn into_tensor_handle<R: Runtime>(
63        client: &ComputeClient<R>,
64        handle: &TensorHandleRef<'_, R>,
65        ident: MatmulIdent,
66        dtype: StorageType,
67    ) -> Result<TensorHandle<R>, LaunchError> {
68        into_tensor_handle_tma(client, handle, ident, dtype)
69    }
70
71    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
72    fn num_stages() -> NumStages {
73        (1, 1).into()
74    }
75
76    fn selection<R: Runtime>(
77        client: &ComputeClient<R>,
78        problem: &ConvolutionProblem,
79        plane_dim: u32,
80        dtypes: &mut MatmulElems,
81    ) -> Result<MatmulSelection, MatmulSetupError> {
82        Ok(convolution_matmul_selection::<TMM, R>(
83            client, problem, plane_dim, dtypes,
84        )?)
85    }
86}
87
88pub(crate) fn into_tensor_handle_tma<R: Runtime>(
89    client: &ComputeClient<R>,
90    handle: &TensorHandleRef<'_, R>,
91    ident: MatmulIdent,
92    dtype: StorageType,
93) -> Result<TensorHandle<R>, LaunchError> {
94    let rank = handle.shape.len();
95    let dim_c = rank - 1;
96    let mut handle = if has_valid_layout(handle, ident) {
97        TensorHandle::from_ref(handle, dtype)
98    } else {
99        into_contiguous_pitched(client, handle, dtype)?
100    };
101    match ident {
102        MatmulIdent::Lhs => Ok(handle),
103        MatmulIdent::Rhs => {
104            let k_size = handle.shape[1..dim_c].iter().product();
105            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
106            handle.strides = vec![
107                handle.strides[0],
108                handle.strides[dim_c - 1],
109                handle.strides[dim_c],
110            ];
111            Ok(handle)
112        }
113        MatmulIdent::Out => unreachable!(),
114    }
115}
116
117pub(crate) fn has_valid_layout<R: Runtime>(
118    handle: &TensorHandleRef<'_, R>,
119    ident: MatmulIdent,
120) -> bool {
121    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
122    let rank = handle.shape.len();
123    let dim_c = rank - 1;
124
125    let aligned = handle.strides[..dim_c]
126        .iter()
127        .all(|stride| stride % stride_align == 0);
128
129    let valid_layout = match ident {
130        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
131        MatmulIdent::Rhs => {
132            let c_major = handle.strides[dim_c] == 1;
133            let mut kernel_contig = true;
134            for i in 1..dim_c - 1 {
135                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
136            }
137            c_major && kernel_contig
138        }
139        MatmulIdent::Out => unreachable!(),
140    };
141
142    valid_layout && aligned
143}
144
145pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
146    fn check_range(
147        value: isize,
148        name: impl FnOnce() -> String,
149        min: isize,
150        max: isize,
151    ) -> Result<(), InvalidConfigError> {
152        if value < min || value > max {
153            let name = name();
154            Err(Box::new(format!(
155                "value {name} outside of valid range ({min}, {max})"
156            )))
157        } else {
158            Ok(())
159        }
160    }
161
162    let (corner_min, corner_max) = match problem.dimensionality {
163        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
164        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
165        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
166    };
167
168    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
169    for (i, offs) in corner.iter().enumerate() {
170        check_range(
171            *offs as isize,
172            || format!("corner[{i}]"),
173            corner_min,
174            corner_max,
175        )?;
176    }
177
178    let offset_max = match problem.dimensionality {
179        Dimensionality::Dim1 => 2isize.pow(16) - 1,
180        Dimensionality::Dim2 => 2isize.pow(8) - 1,
181        Dimensionality::Dim3 => 2isize.pow(5) - 1,
182    };
183
184    for i in 0..problem.kernel_size.len() {
185        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
186        check_range(
187            offset as isize,
188            || format!("kernel size {i}"),
189            0,
190            offset_max,
191        )?;
192        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
193    }
194
195    Ok(())
196}
197
/// Returns the lower corner for the given per-dimension padding: each padding value negated,
/// in the same order.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
201
/// Returns, per dimension, `padding - (kernel_size - 1) * dilation`.
///
/// The three slices are zipped, so the result has the length of the shortest one.
///
/// The value is computed in 64-bit signed arithmetic and saturated to the `i32` range, so a
/// kernel size of 0 or very large kernel/dilation values neither panic nor wrap around into a
/// plausible-looking corner; out-of-range inputs produce `i32::MIN`/`i32::MAX` instead.
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((&padding, &kernel_size), &dilation)| {
            // Distance covered by the kernel taps; -dilation for an empty kernel.
            let span = (i64::from(kernel_size) - 1).saturating_mul(i64::from(dilation));
            let corner = i64::from(padding).saturating_sub(span);
            corner.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        })
        .collect()
}