cubek_convolution/kernels/forward/algorithm/
mod.rs

use cubek_matmul::definition::{
    AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulLineSizes, MatmulSetupError,
    MultiRowStrategy, TilingBlueprint,
};
use cubek_matmul::{
    components::{
        global::{LoadFlows, read::ReaderMode},
        stage::{PartitionBuffering, StageMatmulFamily},
        tile::TileMatmulFamily,
    },
    launch::MatmulArgs,
};

use cubecl::{
    ir::DeviceProperties,
    std::tensor::{TensorHandle, into_contiguous_pitched_ref, is_contiguous_pitched},
};

use cubecl::prelude::*;

use crate::components::{
    ConvolutionOperation, ConvolutionProblem,
    global::{GlobalConfig, GlobalConvolutionFamily},
};

pub mod simple;

/// Specifications for a convolution algorithm
pub trait Algorithm {
    /// The tile-level matmul family used for the innermost computation.
    type TileMatmul: TileMatmulFamily;
    /// The stage-level matmul family that composes tile matmuls.
    type StageMatmul: StageMatmulFamily;
    /// The global-level convolution family that drives the full kernel.
    type GlobalConvolution: GlobalConvolutionFamily;

    /// Argument type used to launch the underlying matmul.
    type Args: MatmulArgs;

    /// Computes the cube count needed to cover the problem, one stage per
    /// cube along `m` and `n`.
    fn cube_count(selection: &TilingBlueprint, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
        let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }
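    // For example, with `problem.m == 100` and a 32-element stage along `m`,
    // `100u32.div_ceil(32) == 4`: the grid rounds up so the partial stage at
    // the edge is still covered.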

    /// Strategy for processing multiple rows per plane; defaults to never.
    fn multi_row_strategy() -> MultiRowStrategy {
        MultiRowStrategy::Never
    }

    /// Whether loading computations are precomputed; defaults to never.
    fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
        LoadingPrecomputeStrategy::Never
    }

    /// Reader mode for global loading; defaults to relaxed.
    fn reader_mode() -> ReaderMode {
        ReaderMode::Relaxed
    }

    /// Load-flow specialization; defaults to [`LoadFlows::default`].
    fn load_specialization() -> LoadFlows {
        LoadFlows::default()
    }

    /// Buffering across stage partitions; defaults to double buffering.
    fn partition_buffering_strategy() -> PartitionBuffering {
        PartitionBuffering::Double
    }

    /// Makes a convolution config from a convolution problem and launch options.
    fn expand_config(
        device_props: &DeviceProperties,
        problem: &ConvolutionProblem,
        selection: &TilingBlueprint,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
        Self::GlobalConvolution::expand_config(device_props, problem, selection, line_sizes, dtypes)
    }

    /// Converts `handle` into a layout usable by this algorithm, copying the
    /// data if the current layout is unsuitable.
    fn into_tensor_handle<R: Runtime>(
        client: &ComputeClient<R>,
        handle: &TensorHandleRef<'_, R>,
        dtype: StorageType,
        operation: ConvolutionOperation,
    ) -> Result<TensorHandle<R>, LaunchError>;

    /// Filters the available line sizes; the default accepts all of them.
    fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
        line_sizes
    }

    /// Selects a tiling blueprint for the given problem, plane size, and
    /// element types.
    fn selection<R: Runtime>(
        client: &ComputeClient<R>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        line_sizes: &MatmulLineSizes,
        matmul_elems: &mut MatmulElems,
    ) -> Result<TilingBlueprint, MatmulSetupError>;
}
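
// A minimal sketch of how an implementor wires these pieces together. All
// `My*` names are hypothetical placeholders, not types from this crate; the
// real implementation lives in the `simple` module.
//
// struct MyAlgorithm;
//
// impl Algorithm for MyAlgorithm {
//     type TileMatmul = MyTileMatmul;
//     type StageMatmul = MyStageMatmul;
//     type GlobalConvolution = MyGlobalConv;
//     type Args = MyArgs;
//
//     fn into_tensor_handle<R: Runtime>(
//         client: &ComputeClient<R>,
//         handle: &TensorHandleRef<'_, R>,
//         dtype: StorageType,
//         _operation: ConvolutionOperation,
//     ) -> Result<TensorHandle<R>, LaunchError> {
//         // Reuse the non-TMA helper below for a plain channels-last layout.
//         into_tensor_handle(client, handle, dtype)
//     }
//
//     fn selection<R: Runtime>(
//         client: &ComputeClient<R>,
//         problem: &ConvolutionProblem,
//         plane_dim: u32,
//         line_sizes: &MatmulLineSizes,
//         matmul_elems: &mut MatmulElems,
//     ) -> Result<TilingBlueprint, MatmulSetupError> {
//         // Choose a blueprint from the problem shape and hardware plane size.
//         todo!()
//     }
// }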

/// Returns `handle` unchanged when its layout is usable by the kernel;
/// otherwise copies it into a contiguous (pitched) tensor.
pub(crate) fn into_tensor_handle<R: Runtime>(
    client: &ComputeClient<R>,
    handle: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let handle = if has_valid_layout(handle) {
        TensorHandle::from_ref(handle, dtype)
    } else {
        into_contiguous_pitched_ref(client, handle, dtype)?
    };
    Ok(handle)
}

/// A layout is valid when the innermost (channel) dimension is contiguous,
/// i.e. has a stride of 1.
fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
    let rank = handle.shape.len();
    let dim_c = rank - 1;
    handle.strides[dim_c] == 1
}
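
// For example, an NHWC tensor of shape `[n, h, w, c]` with strides
// `[h * w * c, w * c, c, 1]` passes the check, while a channels-first tensor
// permuted to the same `[n, h, w, c]` shape has strides
// `[c * h * w, w, 1, h * w]`; its last stride is `h * w`, not 1, so it is
// copied into a contiguous layout first.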

/// TMA requires global-memory strides to be aligned to 16 bytes.
const TMA_STRIDE_ALIGN: usize = 16;

/// Like [`into_tensor_handle`], but also enforces TMA stride alignment, and,
/// for weight-gradient problems, a fully contiguous (pitched) layout.
pub(crate) fn into_tensor_handle_tma<R: Runtime>(
    client: &ComputeClient<R>,
    handle: &TensorHandleRef<'_, R>,
    dtype: StorageType,
    operation: ConvolutionOperation,
) -> Result<TensorHandle<R>, LaunchError> {
    let handle = if has_valid_layout_tma(handle, operation) {
        TensorHandle::from_ref(handle, dtype)
    } else {
        into_contiguous_pitched_ref(client, handle, dtype)?
    };
    Ok(handle)
}
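
// Hypothetical call site, shown for illustration only (the `Forward` variant
// name and the bindings are assumptions): make the input TMA-compatible
// before building tensor-map descriptors.
//
// let input = into_tensor_handle_tma(&client, &input_ref, dtype, ConvolutionOperation::Forward)?;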

/// Checks whether `handle` can be fed to TMA directly: all outer strides must
/// be 16-byte aligned, the innermost dimension must be contiguous, and for
/// weight gradients the whole tensor must be contiguous (pitched).
pub(crate) fn has_valid_layout_tma<R: Runtime>(
    handle: &TensorHandleRef<'_, R>,
    operation: ConvolutionOperation,
) -> bool {
    // Alignment requirement converted from bytes to elements.
    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
    let rank = handle.shape.len();
    let dim_c = rank - 1;

    // Every stride except the innermost must be a multiple of the alignment.
    let aligned = handle.strides[..dim_c]
        .iter()
        .all(|stride| stride % stride_align == 0);

    // The innermost (channel) dimension must be contiguous.
    let valid_layout = handle.strides[dim_c] == 1;

    // Weight-gradient kernels additionally require a contiguous pitched layout.
    let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
        is_contiguous_pitched(handle.shape, handle.strides)
    } else {
        true
    };

    valid_layout && aligned && is_valid_wgrad
}
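
// For example, for `f16` (`elem_size == 2`) the alignment works out to
// `16 / 2 == 8` elements: an outer stride of 20 elements (40 bytes) fails the
// check, while a stride of 32 elements (64 bytes) passes.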