cubecl_convolution/kernels/layered/algorithm/
mod.rs1use cubecl_matmul::components::{
2 AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulIdent, MatmulLineSizes,
3 MatmulPrecision, MatmulSelection, MatmulSetupError, MultiRowStrategy,
4 global::{LoadSpecializationConfig, args::MatmulArgs, read::ReaderMode},
5 stage::{NumStages, PartitionBuffering, StageMatmulFamily},
6 tile::TileMatmulFamily,
7};
8
9use cubecl_std::tensor::TensorHandle;
10
11use cubecl_core::prelude::*;
12
13use crate::components::{
14 ConvolutionProblem,
15 global::{GlobalConfig, GlobalConvolutionFamily},
16};
17
18pub mod multi_stage_tma;
19pub mod simple;
20pub mod simple_tma;
21
22pub trait Algorithm {
24 type TileMatmul: TileMatmulFamily;
25 type StageMatmul: StageMatmulFamily;
26 type GlobalConvolution: GlobalConvolutionFamily;
27
28 type Args: MatmulArgs;
29
30 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
31 let m_stage = selection.tiling_scheme.elements_in_stage_m();
32 let n_stage = selection.tiling_scheme.elements_in_stage_n();
33 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
34 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
35
36 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
37 }
38
39 fn num_stages() -> NumStages;
40
41 fn multi_row_strategy() -> MultiRowStrategy {
42 MultiRowStrategy::Never
43 }
44
45 fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
46 LoadingPrecomputeStrategy::Never
47 }
48
49 fn reader_mode() -> ReaderMode {
50 ReaderMode::Relaxed
51 }
52
53 fn load_specialization() -> LoadSpecializationConfig {
54 LoadSpecializationConfig::default()
55 }
56
57 fn partition_buffering_strategy() -> PartitionBuffering {
58 PartitionBuffering::Double
59 }
60
61 fn setup<R: Runtime, MP: MatmulPrecision>(
63 client: &ComputeClient<R::Server>,
64 problem: &ConvolutionProblem,
65 selection: &MatmulSelection,
66 line_sizes: &MatmulLineSizes,
67 ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
68 Self::GlobalConvolution::setup::<R, MP>(client, problem, selection, line_sizes)
69 }
70
71 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
72 Self::GlobalConvolution::filter_line_sizes(Self::StageMatmul::filter_line_sizes(
73 Self::TileMatmul::filter_line_sizes(available_line_sizes),
74 ))
75 }
76
77 fn into_tensor_handle<R: Runtime, E: Numeric>(
78 client: &ComputeClient<R::Server>,
79 handle: &TensorHandleRef<'_, R>,
80 ident: MatmulIdent,
81 ) -> TensorHandle<R, E>;
82
83 fn selection<R: Runtime>(
84 client: &ComputeClient<R::Server>,
85 problem: &ConvolutionProblem,
86 plane_dim: u32,
87 matmul_elems: MatmulElems,
88 ) -> Result<MatmulSelection, MatmulSetupError>;
89}