cubecl_linalg/convolution/algorithm/mod.rs

1use crate::{
2    matmul::{
3        components::{
4            CompleteStageTiling, InputIdent, InvalidConfigError, MatmulPrecision, MatmulSelection,
5            global::args::MatmulArgs,
6            stage::{StageBuffering, StageMatmulFamily, StageVectorization},
7            tile::TileMatmulFamily,
8        },
9        kernels::MatmulAvailabilityError,
10    },
11    tensor::TensorHandle,
12};
13use cubecl_core::prelude::*;
14
15use super::base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem};
16
17pub mod simple;
18pub mod simple_tma;
19
20pub type StageInput = (CompleteStageTiling, StageBuffering, StageVectorization);
21
22/// Specifications for a convolution algorithm
23pub trait Algorithm {
24    type TileMatmul: TileMatmulFamily;
25    type StageMatmul: StageMatmulFamily<Input = StageInput>;
26    type GlobalConvolution: ConvolutionFamily<Input = StageInput>;
27
28    type Args: MatmulArgs;
29
30    fn cube_dim(selection: &MatmulSelection) -> CubeDim;
31    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount;
32
33    /// Make a convolution config from a convolution problem, and launch options
34    fn make_config(
35        input: <Self::GlobalConvolution as ConvolutionConfigFactory>::Input,
36        problem: &ConvolutionProblem,
37        cube_dim: &CubeDim,
38        cube_count: &CubeCount,
39    ) -> Result<<Self::GlobalConvolution as ConvolutionConfigFactory>::Config, InvalidConfigError>
40    {
41        let config = Self::GlobalConvolution::make_config(input, problem, cube_dim, cube_count);
42        Self::GlobalConvolution::check_config(&config)?;
43        Ok(config)
44    }
45
46    fn check_availability<R: Runtime, MP: MatmulPrecision>(
47        client: &ComputeClient<R::Server, R::Channel>,
48        config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
49    ) -> Result<(), MatmulAvailabilityError> {
50        <Self::GlobalConvolution as ConvolutionConfigFactory>::check_availability::<R, MP>(
51            client, config,
52        )
53    }
54
55    fn into_tensor_handle<R: Runtime, E: Numeric>(
56        client: &ComputeClient<R::Server, R::Channel>,
57        handle: &TensorHandleRef<'_, R>,
58        ident: InputIdent,
59    ) -> TensorHandle<R, E>;
60}