cubecl_convolution/kernels/layered/algorithm/
mod.rs1use cubecl_matmul::components::{
2 AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulIdent, MatmulLineSizes,
3 MatmulSelection, MatmulSetupError, MultiRowStrategy,
4 global::{LoadSpecializationConfig, args::MatmulArgs, read::ReaderMode},
5 stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6 tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::prelude::*;
12
13use crate::components::{
14 ConvolutionProblem,
15 global::{GlobalConfig, GlobalConvolutionFamily},
16};
17
18pub mod multi_stage_tma;
19pub mod simple;
20pub mod simple_tma;
21
22pub trait Algorithm {
24 type TileMatmul: TileMatmulFamily;
25 type StageMatmul: StageMatmulFamily;
26 type GlobalConvolution: GlobalConvolutionFamily;
27
28 type Args: MatmulArgs;
29
30 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
31 let m_stage = selection.tiling_scheme.elements_in_stage_m();
32 let n_stage = selection.tiling_scheme.elements_in_stage_n();
33 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
34 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
35
36 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
37 }
38
39 fn num_stages() -> NumStages;
40
41 fn multi_row_strategy() -> MultiRowStrategy {
42 MultiRowStrategy::Never
43 }
44
45 fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
46 LoadingPrecomputeStrategy::Never
47 }
48
49 fn reader_mode() -> ReaderMode {
50 ReaderMode::Relaxed
51 }
52
53 fn load_specialization() -> LoadSpecializationConfig {
54 LoadSpecializationConfig::default()
55 }
56
57 fn partition_buffering_strategy() -> PartitionBuffering {
58 PartitionBuffering::Double
59 }
60
61 fn setup<R: Runtime>(
63 client: &ComputeClient<R::Server>,
64 problem: &ConvolutionProblem,
65 selection: &MatmulSelection,
66 line_sizes: &MatmulLineSizes,
67 dtypes: &MatmulElems,
68 ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
69 Self::GlobalConvolution::setup::<R>(client, problem, selection, line_sizes, dtypes)
70 }
71
72 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
73 Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
74 Self::TileMatmul::filter_line_sizes(available_line_sizes),
75 ))
76 }
77
78 fn into_tensor_handle<R: Runtime>(
79 client: &ComputeClient<R::Server>,
80 handle: &TensorHandleRef<'_, R>,
81 ident: MatmulIdent,
82 dtype: StorageType,
83 ) -> TensorHandle<R>;
84
85 fn selection<R: Runtime>(
86 client: &ComputeClient<R::Server>,
87 problem: &ConvolutionProblem,
88 plane_dim: u32,
89 matmul_elems: &mut MatmulElems,
90 ) -> Result<MatmulSelection, MatmulSetupError>;
91}