cubecl_convolution/algorithm/mod.rs

1use cubecl_matmul::components::{
2    AvailableLineSizes, InputIdent, LoadingPrecomputeStrategy, MatmulLineSizes, MatmulPrecision,
3    MatmulSelection, MatmulSetupError, MultiRowStrategy,
4    global::{LoadSpecializationConfig, args::MatmulArgs, load::LoaderMode},
5    stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6    tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::{ir::Elem, prelude::*};
12
13use super::base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem};
14
15pub mod multi_stage_tma;
16pub mod simple;
17pub mod simple_tma;
18
19/// Specifications for a convolution algorithm
20pub trait Algorithm {
21    type TileMatmul: TileMatmulFamily;
22    type StageMatmul: StageMatmulFamily;
23    type GlobalConvolution: ConvolutionFamily;
24
25    type Args: MatmulArgs;
26
27    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
28        let m_stage = selection.tiling_scheme.elements_in_stage_m();
29        let n_stage = selection.tiling_scheme.elements_in_stage_n();
30        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
31        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
32
33        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
34    }
35
36    fn num_stages() -> NumStages;
37
38    fn multi_row_strategy() -> MultiRowStrategy {
39        MultiRowStrategy::Never
40    }
41
42    fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
43        LoadingPrecomputeStrategy::Never
44    }
45
46    fn loader_mode() -> LoaderMode {
47        LoaderMode::Relaxed
48    }
49
50    fn load_specialization() -> LoadSpecializationConfig {
51        LoadSpecializationConfig::default()
52    }
53
54    fn partition_buffering_strategy() -> PartitionBuffering {
55        PartitionBuffering::Double
56    }
57
58    /// Make a convolution config from a convolution problem, and launch options
59    fn setup<R: Runtime, MP: MatmulPrecision>(
60        client: &ComputeClient<R::Server, R::Channel>,
61        problem: &ConvolutionProblem,
62        selection: &MatmulSelection,
63        line_sizes: &MatmulLineSizes,
64    ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, MatmulSetupError>
65    {
66        Self::GlobalConvolution::setup::<R, MP>(client, problem, selection, line_sizes)
67    }
68
69    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
70        Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
71            Self::TileMatmul::filter_line_sizes(available_line_sizes),
72        ))
73    }
74
75    fn into_tensor_handle<R: Runtime, E: Numeric>(
76        client: &ComputeClient<R::Server, R::Channel>,
77        handle: &TensorHandleRef<'_, R>,
78        ident: InputIdent,
79    ) -> TensorHandle<R, E>;
80
81    fn selection<R: Runtime>(
82        client: &ComputeClient<R::Server, R::Channel>,
83        problem: &ConvolutionProblem,
84        plane_dim: u32,
85        elem_stage: Elem,
86        elem_acc: Elem,
87    ) -> MatmulSelection;
88}