cubecl_convolution/algorithm/
mod.rs1use cubecl_matmul::components::{
2 AvailableLineSizes, InputIdent, LoadingPrecomputeStrategy, MatmulLineSizes, MatmulPrecision,
3 MatmulSelection, MatmulSetupError, MultiRowStrategy,
4 global::{LoadSpecializationConfig, args::MatmulArgs, load::LoaderMode},
5 stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6 tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::{ir::Elem, prelude::*};
12
13use super::base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem};
14
15pub mod multi_stage_tma;
16pub mod simple;
17pub mod simple_tma;
18
19pub trait Algorithm {
21 type TileMatmul: TileMatmulFamily;
22 type StageMatmul: StageMatmulFamily;
23 type GlobalConvolution: ConvolutionFamily;
24
25 type Args: MatmulArgs;
26
27 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
28 let m_stage = selection.tiling_scheme.elements_in_stage_m();
29 let n_stage = selection.tiling_scheme.elements_in_stage_n();
30 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
31 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
32
33 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
34 }
35
36 fn num_stages() -> NumStages;
37
38 fn multi_row_strategy() -> MultiRowStrategy {
39 MultiRowStrategy::Never
40 }
41
42 fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
43 LoadingPrecomputeStrategy::Never
44 }
45
46 fn loader_mode() -> LoaderMode {
47 LoaderMode::Relaxed
48 }
49
50 fn load_specialization() -> LoadSpecializationConfig {
51 LoadSpecializationConfig::default()
52 }
53
54 fn partition_buffering_strategy() -> PartitionBuffering {
55 PartitionBuffering::Double
56 }
57
58 fn setup<R: Runtime, MP: MatmulPrecision>(
60 client: &ComputeClient<R::Server, R::Channel>,
61 problem: &ConvolutionProblem,
62 selection: &MatmulSelection,
63 line_sizes: &MatmulLineSizes,
64 ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, MatmulSetupError>
65 {
66 Self::GlobalConvolution::setup::<R, MP>(client, problem, selection, line_sizes)
67 }
68
69 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
70 Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
71 Self::TileMatmul::filter_line_sizes(available_line_sizes),
72 ))
73 }
74
75 fn into_tensor_handle<R: Runtime, E: Numeric>(
76 client: &ComputeClient<R::Server, R::Channel>,
77 handle: &TensorHandleRef<'_, R>,
78 ident: InputIdent,
79 ) -> TensorHandle<R, E>;
80
81 fn selection<R: Runtime>(
82 client: &ComputeClient<R::Server, R::Channel>,
83 problem: &ConvolutionProblem,
84 plane_dim: u32,
85 elem_stage: Elem,
86 elem_acc: Elem,
87 ) -> MatmulSelection;
88}