// cubecl_matmul/components/tile/accelerated/setup.rs

1use crate::components::tile::accelerated::config::AcceleratedConfig;
2use crate::components::tile::accelerated::matmul::AcceleratedMatmul;
3use crate::components::tile::{
4    TileMatmulFamily,
5    accelerated::reader::{CmmaFragmentReader, CmmaStageReader},
6};
7use crate::components::{InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection};
8use crate::components::{error::MatmulSetupError, tile::io::Strided};
9use crate::components::{resource::ComputeResources, tile::io::TileKind};
10use cubecl_core::prelude::*;
11
12impl<Tile: TileKind> TileMatmulFamily for AcceleratedMatmul<Tile>
13where
14    CmmaStageReader<Tile>: CmmaFragmentReader<TileKind = Tile>,
15{
16    type Matmul<L: Numeric, R: Numeric, A: Numeric> = AcceleratedMatmul<Tile>;
17    type LhsTile = Strided;
18    type RhsTile = Strided;
19    type AccTile = Tile;
20    type OutTile = Strided;
21
22    type Config = AcceleratedConfig;
23
24    fn requires_accelerator() -> bool {
25        true
26    }
27
28    fn computation_resources() -> Result<ComputeResources, InvalidConfigError> {
29        Ok(ComputeResources::Planes(1))
30    }
31
32    fn setup<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, R: Runtime>(
33        client: &ComputeClient<R::Server>,
34        problem: &MatmulProblem,
35        selection: &MatmulSelection,
36        matmul_line_sizes: &MatmulLineSizes,
37    ) -> Result<Self::Config, MatmulSetupError> {
38        AcceleratedConfig::new::<Lhs, Rhs, Acc, R>(
39            client,
40            selection.tiling_scheme.tile_size,
41            selection.plane_dim,
42            problem.lhs_layout,
43            problem.rhs_layout,
44            matmul_line_sizes.lhs as u32,
45            matmul_line_sizes.rhs as u32,
46            matmul_line_sizes.out as u32,
47            matmul_line_sizes.lhs as u32,
48            matmul_line_sizes.rhs as u32,
49        )
50    }
51}