cubecl_convolution/kernels/layered/algorithm/
mod.rs

1use cubecl_matmul::components::{
2    AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulIdent, MatmulLineSizes,
3    MatmulSelection, MatmulSetupError, MultiRowStrategy,
4    global::{LoadSpecializationConfig, args::MatmulArgs, read::ReaderMode},
5    stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6    tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::prelude::*;
12
13use crate::components::{
14    ConvolutionProblem,
15    global::{GlobalConfig, GlobalConvolutionFamily},
16};
17
18pub mod multi_stage_tma;
19pub mod simple;
20pub mod simple_tma;
21
22/// Specifications for a convolution algorithm
23pub trait Algorithm {
24    type TileMatmul: TileMatmulFamily;
25    type StageMatmul: StageMatmulFamily;
26    type GlobalConvolution: GlobalConvolutionFamily;
27
28    type Args: MatmulArgs;
29
30    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
31        let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
32        let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
33        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
34        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
35
36        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
37    }
38
39    fn num_stages() -> NumStages;
40
41    fn multi_row_strategy() -> MultiRowStrategy {
42        MultiRowStrategy::Never
43    }
44
45    fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
46        LoadingPrecomputeStrategy::Never
47    }
48
49    fn reader_mode() -> ReaderMode {
50        ReaderMode::Relaxed
51    }
52
53    fn load_specialization() -> LoadSpecializationConfig {
54        LoadSpecializationConfig::default()
55    }
56
57    fn partition_buffering_strategy() -> PartitionBuffering {
58        PartitionBuffering::Double
59    }
60
61    /// Make a convolution config from a convolution problem, and launch options
62    fn setup<R: Runtime>(
63        client: &ComputeClient<R>,
64        problem: &ConvolutionProblem,
65        selection: &MatmulSelection,
66        line_sizes: &MatmulLineSizes,
67        dtypes: &MatmulElems,
68    ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
69        Self::GlobalConvolution::setup(client, problem, selection, line_sizes, dtypes)
70    }
71
72    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
73        Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
74            Self::TileMatmul::filter_line_sizes(available_line_sizes),
75        ))
76    }
77
78    fn into_tensor_handle<R: Runtime>(
79        client: &ComputeClient<R>,
80        handle: &TensorHandleRef<'_, R>,
81        ident: MatmulIdent,
82        dtype: StorageType,
83    ) -> Result<TensorHandle<R>, LaunchError>;
84
85    fn selection<R: Runtime>(
86        client: &ComputeClient<R>,
87        problem: &ConvolutionProblem,
88        plane_dim: u32,
89        matmul_elems: &mut MatmulElems,
90    ) -> Result<MatmulSelection, MatmulSetupError>;
91}