cubecl_matmul/components/tile/mma/
setup.rs1use crate::components::tile::{
2 TileMatmulFamily,
3 mma::{
4 MmaMatmul,
5 config::MmaMatmulConfig,
6 reader::{MmaFragmentReader, MmaStageReader},
7 },
8};
9use crate::components::{InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection};
10use crate::components::{error::MatmulSetupError, tile::io::Strided};
11use crate::components::{resource::ComputeResources, tile::io::TileKind};
12use cubecl_core::prelude::*;
13
14impl<Tile: TileKind> TileMatmulFamily for MmaMatmul<Tile>
15where
16 MmaStageReader<Tile>: MmaFragmentReader<TileKind = Tile>,
17{
18 type Matmul<L: Numeric, R: Numeric, A: Numeric> = MmaMatmul<Tile>;
19 type LhsTile = Strided;
20 type RhsTile = Strided;
21 type AccTile = Tile;
22 type OutTile = Strided;
23
24 type Config = MmaMatmulConfig;
25
26 fn requires_accelerator() -> bool {
27 true
28 }
29
30 fn computation_resources() -> Result<ComputeResources, InvalidConfigError> {
31 Ok(ComputeResources::Planes(1))
32 }
33
34 fn setup<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, R: Runtime>(
35 client: &ComputeClient<R::Server>,
36 problem: &MatmulProblem,
37 selection: &MatmulSelection,
38 matmul_line_sizes: &MatmulLineSizes,
39 ) -> Result<Self::Config, MatmulSetupError> {
40 MmaMatmulConfig::new::<Lhs, Rhs, Acc, R>(
41 client,
42 selection.tiling_scheme.tile_size,
43 selection.plane_dim,
44 problem.lhs_layout,
45 problem.rhs_layout,
46 matmul_line_sizes.lhs as u32,
47 matmul_line_sizes.rhs as u32,
48 matmul_line_sizes.out as u32,
49 matmul_line_sizes.lhs as u32,
50 matmul_line_sizes.rhs as u32,
51 )
52 }
53}