cubecl_linalg/convolution/algorithm/
mod.rs1use crate::{
2 matmul::{
3 components::{
4 CompleteStageTiling, InputIdent, InvalidConfigError, MatmulPrecision, MatmulSelection,
5 global::args::MatmulArgs,
6 stage::{StageBuffering, StageMatmulFamily, StageVectorization},
7 tile::TileMatmulFamily,
8 },
9 kernels::MatmulAvailabilityError,
10 },
11 tensor::TensorHandle,
12};
13use cubecl_core::prelude::*;
14
15use super::base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem};
16
17pub mod simple;
18pub mod simple_tma;
19
20pub type StageInput = (CompleteStageTiling, StageBuffering, StageVectorization);
21
22pub trait Algorithm {
24 type TileMatmul: TileMatmulFamily;
25 type StageMatmul: StageMatmulFamily<Input = StageInput>;
26 type GlobalConvolution: ConvolutionFamily<Input = StageInput>;
27
28 type Args: MatmulArgs;
29
30 fn cube_dim(selection: &MatmulSelection) -> CubeDim;
31 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount;
32
33 fn make_config(
35 input: <Self::GlobalConvolution as ConvolutionConfigFactory>::Input,
36 problem: &ConvolutionProblem,
37 cube_dim: &CubeDim,
38 cube_count: &CubeCount,
39 ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, InvalidConfigError>
40 {
41 let config = Self::GlobalConvolution::make_config(input, problem, cube_dim, cube_count);
42 Self::GlobalConvolution::check_config(&config)?;
43 Ok(config)
44 }
45
46 fn check_availability<R: Runtime, MP: MatmulPrecision>(
47 client: &ComputeClient<R::Server, R::Channel>,
48 config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
49 ) -> Result<(), MatmulAvailabilityError> {
50 <Self::GlobalConvolution as ConvolutionConfigFactory>::check_availability::<R, MP>(
51 client, config,
52 )
53 }
54
55 fn into_tensor_handle<R: Runtime, E: Numeric>(
56 client: &ComputeClient<R::Server, R::Channel>,
57 handle: &TensorHandleRef<'_, R>,
58 ident: InputIdent,
59 ) -> TensorHandle<R, E>;
60}