cubecl_matmul/components/tile/register/setup.rs

1use crate::components::tile::register::config::RegisterConfig;
2use crate::components::tile::register::matmul::RegisterMatmul;
3use crate::components::tile::{TileMatmulFamily, io::Strided};
4use crate::components::{
5    AvailableLineSizes, InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection,
6};
7use crate::components::{error::MatmulSetupError, tile::io::TileKind};
8use crate::components::{
9    resource::ComputeResources,
10    tile::register::reader::{RegisterFragmentReader, RegisterStageReader},
11};
12use cubecl_core::prelude::*;
13
14impl<AccTile: TileKind> TileMatmulFamily for RegisterMatmul<AccTile>
15where
16    RegisterStageReader<AccTile>: RegisterFragmentReader<TileKind = AccTile>,
17{
18    type Matmul<L: Numeric, R: Numeric, A: Numeric> = RegisterMatmul<AccTile>;
19    type Config = RegisterConfig;
20
21    type LhsTile = Strided;
22    type RhsTile = Strided;
23    type AccTile = AccTile;
24    type OutTile = Strided;
25
26    fn requires_accelerator() -> bool {
27        false
28    }
29
30    fn computation_resources() -> Result<ComputeResources, InvalidConfigError> {
31        Ok(ComputeResources::Units(1))
32    }
33
34    fn setup<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, R: Runtime>(
35        client: &ComputeClient<R::Server>,
36        problem: &MatmulProblem,
37        selection: &MatmulSelection,
38        matmul_line_sizes: &MatmulLineSizes,
39    ) -> Result<Self::Config, MatmulSetupError> {
40        RegisterConfig::new::<Lhs, Rhs, Acc, R>(
41            client,
42            selection.tiling_scheme.tile_size,
43            selection.plane_dim,
44            problem.lhs_layout,
45            problem.rhs_layout,
46            matmul_line_sizes.lhs as u32,
47            matmul_line_sizes.rhs as u32,
48            matmul_line_sizes.out as u32,
49            matmul_line_sizes.lhs as u32,
50            matmul_line_sizes.rhs as u32,
51        )
52    }
53
54    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
55        available_line_sizes
56    }
57}