cubecl_convolution/kernels/layered/algorithm/
simple_tma.rs

1use cubecl_core::{
2    CubeCount, Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef,
3};
4use cubecl_matmul::components::stage::NumStages;
5use cubecl_matmul::components::{
6    InvalidConfigError, MatmulIdent, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
7    tile::TileMatmulFamily,
8};
9use cubecl_matmul::components::{
10    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
11};
12use cubecl_std::{
13    CubeOption,
14    tensor::{TensorHandle, into_contiguous_pitched},
15};
16use std::marker::PhantomData;
17
18use crate::components::{
19    ConvolutionProblem, Dimensionality, convolution_matmul_selection,
20    global::single_stage::tma::SimpleTmaConvolutionFamily,
21};
22
23use super::Algorithm;
24
/// Byte alignment that the non-channel strides of a tensor must satisfy for the
/// tensor to be handed to the TMA convolution without first being copied into
/// a pitched layout.
pub const TMA_STRIDE_ALIGN: usize = 16;
26
27/// Cmma convolution
28pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
29    _tmm: PhantomData<TMM>,
30}
31
32impl<
33    TMM: TileMatmulFamily<
34            LhsTile = Strided,
35            RhsTile = Strided,
36            AccTile = CubeOption<Strided>,
37            OutTile = Strided,
38        >,
39> Algorithm for SimpleTmaConvAlgorithm<TMM>
40{
41    type TileMatmul = TMM;
42    type StageMatmul = PlaneMatmulFamily<
43        Self::TileMatmul,
44        StridedStageFamily,
45        StridedStageFamily,
46        Option<StridedStageFamily>,
47    >;
48    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
49
50    type Args = TensorMapArgs;
51
52    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
53        let m_stage = selection.tiling_scheme.elements_in_stage_m();
54        let n_stage = selection.tiling_scheme.elements_in_stage_n();
55        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
56        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
57
58        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
59    }
60
61    fn into_tensor_handle<R: Runtime>(
62        client: &ComputeClient<R::Server>,
63        handle: &TensorHandleRef<'_, R>,
64        ident: MatmulIdent,
65        dtype: StorageType,
66    ) -> TensorHandle<R> {
67        into_tensor_handle_tma(client, handle, ident, dtype)
68    }
69
70    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
71    fn num_stages() -> NumStages {
72        (1, 1).into()
73    }
74
75    fn selection<R: Runtime>(
76        client: &ComputeClient<R::Server>,
77        problem: &ConvolutionProblem,
78        plane_dim: u32,
79        dtypes: &mut MatmulElems,
80    ) -> Result<MatmulSelection, MatmulSetupError> {
81        Ok(convolution_matmul_selection::<TMM, R>(
82            client, problem, plane_dim, dtypes,
83        ))
84    }
85}
86
87pub(crate) fn into_tensor_handle_tma<R: Runtime>(
88    client: &ComputeClient<R::Server>,
89    handle: &TensorHandleRef<'_, R>,
90    ident: MatmulIdent,
91    dtype: StorageType,
92) -> TensorHandle<R> {
93    let rank = handle.shape.len();
94    let dim_c = rank - 1;
95    let mut handle = if has_valid_layout(handle, ident) {
96        TensorHandle::from_ref(handle, dtype)
97    } else {
98        into_contiguous_pitched(client, handle, dtype)
99    };
100    match ident {
101        MatmulIdent::Lhs => handle,
102        MatmulIdent::Rhs => {
103            let k_size = handle.shape[1..dim_c].iter().product();
104            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
105            handle.strides = vec![
106                handle.strides[0],
107                handle.strides[dim_c - 1],
108                handle.strides[dim_c],
109            ];
110            handle
111        }
112        MatmulIdent::Out => unreachable!(),
113    }
114}
115
116pub(crate) fn has_valid_layout<R: Runtime>(
117    handle: &TensorHandleRef<'_, R>,
118    ident: MatmulIdent,
119) -> bool {
120    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
121    let rank = handle.shape.len();
122    let dim_c = rank - 1;
123
124    let aligned = handle.strides[..dim_c]
125        .iter()
126        .all(|stride| stride % stride_align == 0);
127
128    let valid_layout = match ident {
129        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
130        MatmulIdent::Rhs => {
131            let c_major = handle.strides[dim_c] == 1;
132            let mut kernel_contig = true;
133            for i in 1..dim_c - 1 {
134                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
135            }
136            c_major && kernel_contig
137        }
138        MatmulIdent::Out => unreachable!(),
139    };
140
141    valid_layout && aligned
142}
143
144pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
145    fn check_range(
146        value: isize,
147        name: impl FnOnce() -> String,
148        min: isize,
149        max: isize,
150    ) -> Result<(), InvalidConfigError> {
151        if value < min || value > max {
152            let name = name();
153            Err(Box::new(format!(
154                "value {name} outside of valid range ({min}, {max})"
155            )))
156        } else {
157            Ok(())
158        }
159    }
160
161    let (corner_min, corner_max) = match problem.dimensionality {
162        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
163        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
164        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
165    };
166
167    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
168    for (i, offs) in corner.iter().enumerate() {
169        check_range(
170            *offs as isize,
171            || format!("corner[{i}]"),
172            corner_min,
173            corner_max,
174        )?;
175    }
176
177    let offset_max = match problem.dimensionality {
178        Dimensionality::Dim1 => 2isize.pow(16) - 1,
179        Dimensionality::Dim2 => 2isize.pow(8) - 1,
180        Dimensionality::Dim3 => 2isize.pow(5) - 1,
181    };
182
183    for i in 0..problem.kernel_size.len() {
184        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
185        check_range(
186            offset as isize,
187            || format!("kernel size {i}"),
188            0,
189            offset_max,
190        )?;
191        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
192    }
193
194    Ok(())
195}
196
/// Returns the lower corner for the given padding: each component is the
/// negated padding of the same dimension.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
200
/// Returns the upper corner for the given convolution parameters.
///
/// Component `d` is `padding[d] - (kernel_size[d] - 1) * dilation[d]`. The
/// result has as many components as the shortest of the three input slices.
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    let dims = padding.len().min(kernel_size.len()).min(dilation.len());

    let mut corner = Vec::with_capacity(dims);
    for dim in 0..dims {
        // Distance covered by the dilated kernel past its first element.
        let extent = (kernel_size[dim] - 1) as i32 * dilation[dim] as i32;
        corner.push(padding[dim] - extent);
    }
    corner
}
210}