cubecl_convolution/kernels/layered/algorithm/
simple_tma.rs

1use std::marker::PhantomData;
2
3use cubecl_core::{
4    CubeCount, Runtime,
5    client::ComputeClient,
6    prelude::{Numeric, TensorHandleRef},
7};
8use cubecl_matmul::components::{
9    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
10};
11
12use cubecl_matmul::components::stage::NumStages;
13use cubecl_matmul::components::{
14    InvalidConfigError, MatmulIdent, global::args::TensorMapArgs, stage::PlaneMatmulFamily,
15    tile::TileMatmulFamily,
16};
17
18use cubecl_std::{
19    CubeOption,
20    tensor::{TensorHandle, into_contiguous_pitched},
21};
22
23use crate::components::{
24    ConvolutionProblem, Dimensionality, convolution_matmul_selection,
25    global::single_stage::tma::SimpleTmaConvolutionFamily,
26};
27
28use super::Algorithm;
29
/// Alignment that every stride except the innermost one must satisfy, once the
/// element size is taken into account, for a tensor to pass `has_valid_layout`.
///
/// NOTE(review): presumably expressed in bytes, since it is combined with
/// `elem_size` — confirm against the TMA requirements.
pub const TMA_STRIDE_ALIGN: usize = 16;
31
/// Convolution algorithm marker type, generic over the tile matmul family `TMM`.
///
/// Its `Algorithm` implementation selects `SimpleTmaConvolutionFamily` as the
/// global convolution and `TensorMapArgs` as the argument type.
pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
    // Zero-sized marker: ties `TMM` to the struct without storing a value.
    _tmm: PhantomData<TMM>,
}
36
37impl<
38    TMM: TileMatmulFamily<
39            LhsTile = Strided,
40            RhsTile = Strided,
41            AccTile = CubeOption<Strided>,
42            OutTile = Strided,
43        >,
44> Algorithm for SimpleTmaConvAlgorithm<TMM>
45{
46    type TileMatmul = TMM;
47    type StageMatmul = PlaneMatmulFamily<
48        Self::TileMatmul,
49        StridedStageFamily,
50        StridedStageFamily,
51        Option<StridedStageFamily>,
52    >;
53    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;
54
55    type Args = TensorMapArgs;
56
57    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
58        let m_stage = selection.tiling_scheme.elements_in_stage_m();
59        let n_stage = selection.tiling_scheme.elements_in_stage_n();
60        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
61        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
62
63        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
64    }
65
66    fn into_tensor_handle<R: Runtime, E: Numeric>(
67        client: &ComputeClient<R::Server>,
68        handle: &TensorHandleRef<'_, R>,
69        ident: MatmulIdent,
70    ) -> TensorHandle<R, E> {
71        into_tensor_handle_tma(client, handle, ident)
72    }
73
74    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
75    fn num_stages() -> NumStages {
76        (1, 1).into()
77    }
78
79    fn selection<R: Runtime>(
80        client: &ComputeClient<R::Server>,
81        problem: &ConvolutionProblem,
82        plane_dim: u32,
83        matmul_elems: MatmulElems,
84    ) -> Result<MatmulSelection, MatmulSetupError> {
85        Ok(convolution_matmul_selection::<TMM, R>(
86            client,
87            problem,
88            plane_dim,
89            matmul_elems,
90        ))
91    }
92}
93
94pub(crate) fn into_tensor_handle_tma<R: Runtime, E: Numeric>(
95    client: &ComputeClient<R::Server>,
96    handle: &TensorHandleRef<'_, R>,
97    ident: MatmulIdent,
98) -> TensorHandle<R, E> {
99    let rank = handle.shape.len();
100    let dim_c = rank - 1;
101    let mut handle = if has_valid_layout(handle, ident) {
102        TensorHandle::from_ref(handle)
103    } else {
104        into_contiguous_pitched(client, handle)
105    };
106    match ident {
107        MatmulIdent::Lhs => handle,
108        MatmulIdent::Rhs => {
109            let k_size = handle.shape[1..dim_c].iter().product();
110            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
111            handle.strides = vec![
112                handle.strides[0],
113                handle.strides[dim_c - 1],
114                handle.strides[dim_c],
115            ];
116            handle
117        }
118        MatmulIdent::Out => unreachable!(),
119    }
120}
121
122pub(crate) fn has_valid_layout<R: Runtime>(
123    handle: &TensorHandleRef<'_, R>,
124    ident: MatmulIdent,
125) -> bool {
126    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
127    let rank = handle.shape.len();
128    let dim_c = rank - 1;
129
130    let aligned = handle.strides[..dim_c]
131        .iter()
132        .all(|stride| stride % stride_align == 0);
133
134    let valid_layout = match ident {
135        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
136        MatmulIdent::Rhs => {
137            let c_major = handle.strides[dim_c] == 1;
138            let mut kernel_contig = true;
139            for i in 1..dim_c - 1 {
140                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
141            }
142            c_major && kernel_contig
143        }
144        MatmulIdent::Out => unreachable!(),
145    };
146
147    valid_layout && aligned
148}
149
150pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
151    fn check_range(
152        value: isize,
153        name: impl FnOnce() -> String,
154        min: isize,
155        max: isize,
156    ) -> Result<(), InvalidConfigError> {
157        if value < min || value > max {
158            let name = name();
159            Err(Box::new(format!(
160                "value {name} outside of valid range ({min}, {max})"
161            )))
162        } else {
163            Ok(())
164        }
165    }
166
167    let (corner_min, corner_max) = match problem.dimensionality {
168        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
169        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
170        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
171    };
172
173    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
174    for (i, offs) in corner.iter().enumerate() {
175        check_range(
176            *offs as isize,
177            || format!("corner[{i}]"),
178            corner_min,
179            corner_max,
180        )?;
181    }
182
183    let offset_max = match problem.dimensionality {
184        Dimensionality::Dim1 => 2isize.pow(16) - 1,
185        Dimensionality::Dim2 => 2isize.pow(8) - 1,
186        Dimensionality::Dim3 => 2isize.pow(5) - 1,
187    };
188
189    for i in 0..problem.kernel_size.len() {
190        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
191        check_range(
192            offset as isize,
193            || format!("kernel size {i}"),
194            0,
195            offset_max,
196        )?;
197        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
198    }
199
200    Ok(())
201}
202
/// Returns the negated padding of each dimension, in the same order.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
206
/// Computes `padding - (kernel_size - 1) * dilation` for each dimension.
///
/// The three slices are walked in lockstep; the result has as many entries as the
/// shortest of them.
///
/// The arithmetic is carried out in saturating `i64` and the result is clamped to
/// the `i32` range. A `kernel_size` of 0 therefore yields `padding + dilation`
/// instead of underflowing, and extreme inputs saturate instead of wrapping into
/// a seemingly valid value.
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((&padding, &kernel_size), &dilation)| {
            let extent = (i64::from(kernel_size) - 1).saturating_mul(i64::from(dilation));
            let corner = i64::from(padding).saturating_sub(extent);
            corner.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        })
        .collect()
}
216}