cubecl_convolution/kernels/layered/algorithm/mod.rs

1use cubecl_matmul::components::{
2    AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulIdent, MatmulLineSizes,
3    MatmulPrecision, MatmulSelection, MatmulSetupError, MultiRowStrategy,
4    global::{LoadSpecializationConfig, args::MatmulArgs, read::ReaderMode},
5    stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6    tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::prelude::*;
12
13use crate::components::{
14    ConvolutionProblem,
15    global::{GlobalConfig, GlobalConvolutionFamily},
16};
17
18pub mod multi_stage_tma;
19pub mod simple;
20pub mod simple_tma;
21
22/// Specifications for a convolution algorithm
23pub trait Algorithm {
24    type TileMatmul: TileMatmulFamily;
25    type StageMatmul: StageMatmulFamily;
26    type GlobalConvolution: GlobalConvolutionFamily;
27
28    type Args: MatmulArgs;
29
30    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
31        let m_stage = selection.tiling_scheme.elements_in_stage_m();
32        let n_stage = selection.tiling_scheme.elements_in_stage_n();
33        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
34        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
35
36        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
37    }
38
39    fn num_stages() -> NumStages;
40
41    fn multi_row_strategy() -> MultiRowStrategy {
42        MultiRowStrategy::Never
43    }
44
45    fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
46        LoadingPrecomputeStrategy::Never
47    }
48
49    fn reader_mode() -> ReaderMode {
50        ReaderMode::Relaxed
51    }
52
53    fn load_specialization() -> LoadSpecializationConfig {
54        LoadSpecializationConfig::default()
55    }
56
57    fn partition_buffering_strategy() -> PartitionBuffering {
58        PartitionBuffering::Double
59    }
60
61    /// Make a convolution config from a convolution problem, and launch options
62    fn setup<R: Runtime, MP: MatmulPrecision>(
63        client: &ComputeClient<R::Server>,
64        problem: &ConvolutionProblem,
65        selection: &MatmulSelection,
66        line_sizes: &MatmulLineSizes,
67    ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
68        Self::GlobalConvolution::setup::<R, MP>(client, problem, selection, line_sizes)
69    }
70
71    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
72        Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
73            Self::TileMatmul::filter_line_sizes(available_line_sizes),
74        ))
75    }
76
77    fn into_tensor_handle<R: Runtime, E: Numeric>(
78        client: &ComputeClient<R::Server>,
79        handle: &TensorHandleRef<'_, R>,
80        ident: MatmulIdent,
81    ) -> TensorHandle<R, E>;
82
83    fn selection<R: Runtime>(
84        client: &ComputeClient<R::Server>,
85        problem: &ConvolutionProblem,
86        plane_dim: u32,
87        matmul_elems: MatmulElems,
88    ) -> Result<MatmulSelection, MatmulSetupError>;
89}