cubecl_linalg/convolution/algorithm/simple_tma.rs

use std::marker::PhantomData;

use cubecl_core::{
    CubeCount, CubeDim, Runtime,
    client::ComputeClient,
    prelude::{Numeric, TensorHandleRef},
};

use crate::{
    convolution::{
        base::{ConvolutionConfigFactory, ConvolutionProblem},
        homogeneous::simple_tma::SimpleTmaConvolutionFamily,
    },
    matmul::components::{
        InputIdent, InvalidConfigError, MatmulSelection,
        global::args::TensorMapArgs,
        stage::{FullReaderFamily, plane_matmul::PlaneMatmulFamily},
        tile::TileMatmulFamily,
    },
    tensor::{TensorHandle, into_contiguous_pitched},
};

use super::Algorithm;

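/// Minimum byte alignment required for tensor strides when loading through TMA.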
pub const TMA_STRIDE_ALIGN: usize = 16;

/// Simple convolution algorithm using TMA (Tensor Memory Accelerator) loads
pub struct SimpleTmaConvAlgorithm<TMM: TileMatmulFamily> {
    _tmm: PhantomData<TMM>,
}

impl<TMM: TileMatmulFamily> Algorithm for SimpleTmaConvAlgorithm<TMM> {
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily>;
    type GlobalConvolution = SimpleTmaConvolutionFamily<Self::StageMatmul>;

    type Args = TensorMapArgs;

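    /// `plane_dim` threads per plane, one plane per `m` tile of the stage.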
    fn cube_dim(selection: &MatmulSelection) -> CubeDim {
        CubeDim::new(selection.plane_dim, selection.tile_count.m, 1)
    }

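    /// One cube per `m_stage` x `n_stage` block of the implicit GEMM output.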
    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tile_count.m * selection.tile_shape.m;
        let n_stage = selection.tile_count.n * selection.tile_shape.n;
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }

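    /// Builds the global convolution config, first rejecting problems that fall outside
    /// the value ranges supported by TMA (see `check_problem_tma`).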
    fn make_config(
        input: <Self::GlobalConvolution as ConvolutionConfigFactory>::Input,
        problem: &ConvolutionProblem,
        cube_dim: &CubeDim,
        cube_count: &CubeCount,
    ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, InvalidConfigError>
    {
        check_problem_tma(problem)?;

        let config = Self::GlobalConvolution::make_config(input, problem, cube_dim, cube_count);
        Self::GlobalConvolution::check_config(&config)?;
        Ok(config)
    }

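    /// In addition to the underlying matmul's availability checks, the runtime must
    /// report the base TMA feature.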
    fn check_availability<R: Runtime, MP: crate::matmul::components::MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
    ) -> Result<(), crate::matmul::kernels::MatmulAvailabilityError> {
        <Self::GlobalConvolution as ConvolutionConfigFactory>::check_availability::<R, MP>(
            client, config,
        )?;

        if !client
            .properties()
            .feature_enabled(cubecl_core::Feature::Tma(cubecl_core::TmaFeature::Base))
        {
            return Err(crate::matmul::kernels::MatmulAvailabilityError::TmaUnavailable);
        }

        Ok(())
    }

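    /// Returns a handle whose layout is TMA-compatible, copying into a contiguous pitched
    /// layout when necessary. For the weights (`Rhs`), the kernel height and width
    /// dimensions are collapsed into a single dimension.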
    fn into_tensor_handle<R: Runtime, E: Numeric>(
        client: &ComputeClient<R::Server, R::Channel>,
        handle: &TensorHandleRef<'_, R>,
        ident: InputIdent,
    ) -> TensorHandle<R, E> {
        // Reuse the handle if its layout already satisfies the TMA requirements,
        // otherwise copy it into a contiguous pitched layout.
        let mut handle = if has_valid_layout(handle, ident) {
            TensorHandle::from_ref(handle)
        } else {
            into_contiguous_pitched(client, handle)
        };
        match ident {
            InputIdent::Lhs => handle,
            InputIdent::Rhs => {
                // Collapse the kernel height/width dimensions (dims 1 and 2) of the
                // weights into a single dimension.
                handle.shape = vec![
                    handle.shape[0],
                    handle.shape[1] * handle.shape[2],
                    handle.shape[3],
                ];
                handle.strides = vec![handle.strides[0], handle.strides[2], handle.strides[3]];
                handle
            }
        }
    }
}

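/// Checks whether a handle can be used by TMA directly: the innermost dimension must be
/// contiguous, the outer strides must be `TMA_STRIDE_ALIGN`-byte aligned, and for the weights
/// the two kernel dimensions must be contiguous with each other so they can be collapsed.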
fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;

    let aligned = handle.strides[..3]
        .iter()
        .all(|stride| stride % stride_align == 0);

    let valid_layout = match ident {
        InputIdent::Lhs => handle.strides[3] == 1,
        InputIdent::Rhs => {
            let c_major = handle.strides[3] == 1;
            let kernel_contig = handle.strides[2] * handle.shape[2] == handle.strides[1];
            c_major && kernel_contig
        }
    };

    valid_layout && aligned
}

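/// Validates that the convolution parameters fall within the value ranges accepted by
/// TMA im2col loads.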
fn check_problem_tma(problem: &ConvolutionProblem) -> Result<(), InvalidConfigError> {
    fn check_range(
        value: isize,
        name: &str,
        min: isize,
        max: isize,
    ) -> Result<(), InvalidConfigError> {
        if value < min || value > max {
            Err(Box::new(format!(
                "value {name} outside of valid range [{min}, {max}]"
            )))
        } else {
            Ok(())
        }
    }

    let corner = calculate_upper_corner(problem.padding, problem.kernel_size, problem.dilation);
    check_range(corner[0] as isize, "corner_h", -128, 127)?;
    check_range(corner[1] as isize, "corner_w", -128, 127)?;

    // Offset of the last kernel element relative to the first, per spatial dimension.
    let offset_h = (problem.kernel_size.0 - 1) * problem.dilation.0;
    let offset_w = (problem.kernel_size.1 - 1) * problem.dilation.1;
    check_range(offset_h as isize, "kernel offset h", 0, 255)?;
    check_range(offset_w as isize, "kernel offset w", 0, 255)?;

    check_range(problem.stride.0 as isize, "stride_h", 1, 8)?;
    check_range(problem.stride.1 as isize, "stride_w", 1, 8)?;

    Ok(())
}

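/// Upper corner offset of the im2col window: the padding minus the dilated kernel extent,
/// per spatial dimension.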
pub fn calculate_upper_corner(
    padding: (i32, i32),
    kernel_size: (u32, u32),
    dilation: (u32, u32),
) -> Vec<i32> {
    let corner_h = padding.0 - (kernel_size.0 - 1) as i32 * dilation.0 as i32;
    let corner_w = padding.1 - (kernel_size.1 - 1) as i32 * dilation.1 as i32;

    vec![corner_h, corner_w]
}