cubecl_matmul/kernels/layered/algorithm/
vecmat.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::{
4    components::{
5        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
6        PartitionSize, TileSize, TilingScheme,
7        batch::{
8            CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection,
9            PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul, SmAllocation,
10        },
11        global::{
12            PlaneWriterFamily,
13            multi_stage::double_buffering::DoubleBufferingMatmulFamily,
14            read::{
15                sync_full_cyclic::SyncFullCyclicLoading,
16                sync_partial_cyclic::SyncPartialCyclicLoading,
17            },
18            single_stage::simple::SimpleMatmulFamily,
19        },
20        stage::{
21            ColMajorTilingOrder, FilledStageFamily, PartitionBuffering, PlaneMatmulFamily,
22            RowMajorTilingOrder, StridedStageFamily,
23        },
24        tile::{io::Filled, plane_vec_mat_inner_product::PlaneVecMatInnerProduct},
25    },
26    kernels::layered::Algorithm,
27};
28
29pub struct SimpleVecMatAlgorithm {}
30
31impl Algorithm for SimpleVecMatAlgorithm {
32    type SelectionArgs = ();
33    type TileMatmul = PlaneVecMatInnerProduct<Filled>;
34    type StageMatmul = PlaneMatmulFamily<
35        Self::TileMatmul,
36        StridedStageFamily,
37        StridedStageFamily,
38        FilledStageFamily,
39    >;
40    type GlobalMatmul = SimpleMatmulFamily<
41        Self::StageMatmul,
42        SyncFullCyclicLoading<RowMajorTilingOrder>,
43        SyncFullCyclicLoading<ColMajorTilingOrder>,
44        PlaneWriterFamily,
45    >;
46    type BatchMatmul =
47        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
48
49    fn selection<R: Runtime>(
50        client: &ComputeClient<R::Server>,
51        problem: &MatmulProblem,
52        plane_dim: u32,
53        line_sizes: &MatmulLineSizes,
54        _elems: MatmulElems,
55        _args: &Self::SelectionArgs,
56    ) -> Result<MatmulSelection, MatmulSetupError> {
57        Ok(selection_vecmat::<R>(
58            client,
59            problem,
60            (1, line_sizes.out as u32, plane_dim * line_sizes.lhs as u32).into(),
61            plane_dim,
62        ))
63    }
64}
65
66pub struct DoubleVecMatAlgorithm {}
67
68impl Algorithm for DoubleVecMatAlgorithm {
69    type SelectionArgs = ();
70    type TileMatmul = PlaneVecMatInnerProduct<Filled>;
71    type StageMatmul = PlaneMatmulFamily<
72        Self::TileMatmul,
73        StridedStageFamily,
74        StridedStageFamily,
75        FilledStageFamily,
76    >;
77    type GlobalMatmul = DoubleBufferingMatmulFamily<
78        Self::StageMatmul,
79        SyncPartialCyclicLoading<RowMajorTilingOrder>,
80        SyncPartialCyclicLoading<ColMajorTilingOrder>,
81        PlaneWriterFamily,
82    >;
83    type BatchMatmul =
84        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
85
86    fn selection<R: Runtime>(
87        client: &ComputeClient<R::Server>,
88        problem: &MatmulProblem,
89        plane_dim: u32,
90        line_sizes: &MatmulLineSizes,
91        _elems: MatmulElems,
92        _args: &Self::SelectionArgs,
93    ) -> Result<MatmulSelection, MatmulSetupError> {
94        Ok(selection_vecmat::<R>(
95            client,
96            problem,
97            (1, line_sizes.out as u32, plane_dim * line_sizes.lhs as u32).into(),
98            plane_dim,
99        ))
100    }
101}
102
103fn selection_vecmat<R: Runtime>(
104    client: &ComputeClient<R::Server>,
105    problem: &MatmulProblem,
106    tile_size: TileSize,
107    plane_dim: u32,
108) -> MatmulSelection {
109    let tiling_scheme = TilingScheme::builder()
110        .with_tile_size(tile_size)
111        .with_partition_size(PartitionSize::new(1, 1, 1))
112        .with_stage_size((1, 1, 1).into())
113        .build()
114        .unwrap();
115    let cube_count_plan = match client.properties().hardware.num_streaming_multiprocessors {
116        Some(num_sms) => CubeCountPlanSelection::Sm {
117            num_sms,
118            sm_usage: SmAllocation::Exact,
119            cubes_first: true,
120        },
121        None => CubeCountPlanSelection::FromProblem,
122    };
123
124    let hypercube = HypercubeSelection::builder(&tiling_scheme)
125        .global_order(GlobalOrderSelection::SwizzleRow {
126            m: problem.m as u32,
127            w: 2,
128        })
129        .cube_count_plan(cube_count_plan)
130        .build();
131
132    MatmulSelection::builder(tiling_scheme, plane_dim)
133        .partition_buffering(PartitionBuffering::Single)
134        .hypercube_config(hypercube)
135        .build()
136}