cubecl_convolution/algorithm/simple_tma.rs

use std::marker::PhantomData;

use cubecl_core::ir::Elem;
use cubecl_core::{
    CubeCount, Runtime,
    client::ComputeClient,
    prelude::{Numeric, TensorHandleRef},
};
use cubecl_matmul::components::MatmulSelection;

use crate::{
    base::{ConvolutionProblem, Dimensionality},
    homogeneous::simple_tma::SimpleTmaConvolutionFamily,
    selection::convolution_matmul_selection,
};
use cubecl_matmul::components::stage::NumStages;
use cubecl_matmul::components::{
    InputIdent, InvalidConfigError,
    global::args::TensorMapArgs,
    stage::{FullReaderFamily, PlaneMatmulFamily},
    tile::TileMatmulFamily,
};

use cubecl_std::tensor::{TensorHandle, into_contiguous_pitched};

use super::Algorithm;

/// Required alignment (in bytes) of tensor strides for TMA transfers.
pub const TMA_STRIDE_ALIGN: usize = 16;

/// Simple convolution algorithm that loads its inputs through TMA (Tensor Memory Accelerator).
pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
    _tmm: PhantomData<TMM>,
}

impl<TMM: TileMatmulFamily> Algorithm for SimpleTmaConvAlgorithm<TMM> {
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, FullReaderFamily>;
    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;

    type Args = TensorMapArgs;

    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tiling_scheme.elements_in_stage_m();
        let n_stage = selection.tiling_scheme.elements_in_stage_n();
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }

    fn into_tensor_handle<R: Runtime, E: Numeric>(
        client: &ComputeClient<R::Server, R::Channel>,
        handle: &TensorHandleRef<'_, R>,
        ident: InputIdent,
    ) -> TensorHandle<R, E> {
        into_tensor_handle_tma(client, handle, ident)
    }

    // TODO: these are not TMA stages; they are stages in the double-buffering sense used by matmul.
    fn num_stages() -> NumStages {
        (1, 1).into()
    }

    fn selection<R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        elem_stage: Elem,
        elem_acc: Elem,
    ) -> MatmulSelection {
        convolution_matmul_selection::<TMM, R>(client, problem, plane_dim, elem_stage, elem_acc)
    }
}

/// Converts `handle` into a layout usable by TMA, making a contiguous (pitched) copy when the
/// current layout is not valid. For `Rhs`, the kernel spatial dimensions are collapsed into a
/// single dimension so the resulting handle is rank 3.
pub(crate) fn into_tensor_handle_tma<R: Runtime, E: Numeric>(
    client: &ComputeClient<R::Server, R::Channel>,
    handle: &TensorHandleRef<'_, R>,
    ident: InputIdent,
) -> TensorHandle<R, E> {
    let rank = handle.shape.len();
    let dim_c = rank - 1;
    let mut handle = if has_valid_layout(handle, ident) {
        TensorHandle::from_ref(handle)
    } else {
        into_contiguous_pitched(client, handle)
    };
    match ident {
        InputIdent::Lhs => handle,
        InputIdent::Rhs => {
            let k_size = handle.shape[1..dim_c].iter().product();
            handle.shape = vec![handle.shape[0], k_size, handle.shape[dim_c]];
            handle.strides = vec![
                handle.strides[0],
                handle.strides[dim_c - 1],
                handle.strides[dim_c],
            ];
            handle
        }
    }
}

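// Illustrative sketch (not part of the original file): worked numbers for the `Rhs` collapse
// above. For a hypothetical contiguous weight tensor of shape [out_c, kh, kw, in_c] =
// [16, 3, 3, 32], the kernel dims are merged into k = kh * kw = 9, giving shape [16, 9, 32]
// with strides [288, 32, 1] (the merged dim reuses the stride of the innermost kernel dim).
#[cfg(test)]
mod rhs_collapse_example {
    #[test]
    fn collapsed_rhs_shape_and_strides() {
        let shape = [16usize, 3, 3, 32];
        let strides = [288usize, 96, 32, 1];
        let dim_c = shape.len() - 1;

        // Same arithmetic as the `InputIdent::Rhs` branch of `into_tensor_handle_tma`.
        let k_size: usize = shape[1..dim_c].iter().product();
        let collapsed_shape = vec![shape[0], k_size, shape[dim_c]];
        let collapsed_strides = vec![strides[0], strides[dim_c - 1], strides[dim_c]];

        assert_eq!(collapsed_shape, vec![16, 9, 32]);
        assert_eq!(collapsed_strides, vec![288, 32, 1]);
    }
}
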
/// Checks whether `handle` can be used for TMA directly: all strides other than the channel
/// stride must be aligned to [`TMA_STRIDE_ALIGN`] bytes, the channel dimension must have stride
/// 1, and for `Rhs` the kernel dimensions must be contiguous with each other so they can be
/// collapsed into a single dimension.
pub(crate) fn has_valid_layout<R: Runtime>(
    handle: &TensorHandleRef<'_, R>,
    ident: InputIdent,
) -> bool {
    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
    let rank = handle.shape.len();
    let dim_c = rank - 1;

    let aligned = handle.strides[..dim_c]
        .iter()
        .all(|stride| stride % stride_align == 0);

    let valid_layout = match ident {
        InputIdent::Lhs => handle.strides[dim_c] == 1,
        InputIdent::Rhs => {
            let c_major = handle.strides[dim_c] == 1;
            let mut kernel_contig = true;
            for i in 1..dim_c - 1 {
                kernel_contig &= handle.strides[i] == handle.strides[i + 1] * handle.shape[i + 1];
            }
            c_major && kernel_contig
        }
    };

    valid_layout && aligned
}

/// Checks that the convolution parameters (kernel size, dilation, stride, padding) fall within
/// the value ranges supported by TMA; the valid ranges depend on the spatial dimensionality.
pub(crate) fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
    fn check_range(
        value: isize,
        name: impl FnOnce() -> String,
        min: isize,
        max: isize,
    ) -> Result<(), InvalidConfigError> {
        if value < min || value > max {
            let name = name();
            Err(Box::new(format!(
                "value {name} outside of valid range [{min}, {max}]"
            )))
        } else {
            Ok(())
        }
    }

    let (corner_min, corner_max) = match problem.dimensionality {
        Dimensionality::Dim1 => (-(2isize.pow(15)), 2isize.pow(15) - 1),
        Dimensionality::Dim2 => (-(2isize.pow(7)), 2isize.pow(7) - 1),
        Dimensionality::Dim3 => (-(2isize.pow(4)), 2isize.pow(4) - 1),
    };

    let corner = calculate_upper_corner(&problem.padding, &problem.kernel_size, &problem.dilation);
    for (i, offs) in corner.iter().enumerate() {
        check_range(
            *offs as isize,
            || format!("corner[{i}]"),
            corner_min,
            corner_max,
        )?;
    }

    let offset_max = match problem.dimensionality {
        Dimensionality::Dim1 => 2isize.pow(16) - 1,
        Dimensionality::Dim2 => 2isize.pow(8) - 1,
        Dimensionality::Dim3 => 2isize.pow(5) - 1,
    };

    for i in 0..problem.kernel_size.len() {
        let offset = (problem.kernel_size[i] - 1) * problem.dilation[i];
        check_range(
            offset as isize,
            || format!("kernel size {i}"),
            0,
            offset_max,
        )?;
        check_range(problem.stride[i] as isize, || format!("stride[{i}]"), 1, 8)?;
    }

    Ok(())
}

/// Lower corner offset for each spatial dimension: the negated padding.
pub fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

/// Upper corner offset for each spatial dimension: the padding minus the dilated kernel extent,
/// `(kernel_size - 1) * dilation`.
pub fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}
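
// Illustrative sketch (not part of the original file): a minimal unit test showing the corner
// values produced for a hypothetical 3x3 kernel with padding 1 and dilation 1, matching the
// formulas above (lower = -padding, upper = padding - (kernel_size - 1) * dilation).
#[cfg(test)]
mod corner_tests {
    use super::{calculate_lower_corner, calculate_upper_corner};

    #[test]
    fn corners_for_3x3_kernel_padding_1() {
        let padding = [1, 1];
        let kernel_size = [3, 3];
        let dilation = [1, 1];

        // lower = -padding
        assert_eq!(calculate_lower_corner(&padding), vec![-1, -1]);
        // upper = padding - (kernel_size - 1) * dilation = 1 - 2 = -1
        assert_eq!(
            calculate_upper_corner(&padding, &kernel_size, &dilation),
            vec![-1, -1]
        );
    }
}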