cubecl_matmul/components/tile/accelerated/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::tile::{TileConfig, TileMatmul, accelerated::reader::CmmaFragmentReader};
4use crate::components::tile::{accelerated::writer::CmmaStageWriter, tile_data::StridedTile};
5use crate::components::tile::{
6 accelerated::{config::AcceleratedConfig, reader::CmmaStageReader},
7 io::{Strided, TileKind},
8};
9use crate::components::{StageIdent, as_cmma_layout};
10use cubecl_core as cubecl;
11use cubecl_core::{cmma, prelude::*};
12use cubecl_std::CubeOption;
13
14pub struct AcceleratedMatmul<Acc: TileKind> {
16 _ty: PhantomData<Acc>,
17}
18
19#[cube]
20impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
21 for AcceleratedMatmul<AccTile>
22where
23 CmmaStageReader<AccTile>: CmmaFragmentReader<TileKind = AccTile>,
24{
25 type Config = AcceleratedConfig;
26 type LhsFragment = cmma::Matrix<L>;
27 type RhsFragment = cmma::Matrix<R>;
28 type AccFragment = cmma::Matrix<A>;
29
30 type LhsTile = Strided;
31 type RhsTile = Strided;
32 type AccTile = AccTile;
33 type OutTile = Strided;
34
35 fn execute(
36 lhs: &Self::LhsFragment,
37 rhs: &Self::RhsFragment,
38 out: &mut Self::AccFragment,
39 #[comptime] _config: Self::Config,
40 ) {
41 cmma::execute::<L, R, A, A>(lhs, rhs, out, out);
42 }
43
44 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
45 let size = config.tile_size();
46 let layout = config.matrix_layout(StageIdent::Lhs);
47 unsafe {
48 cmma::Matrix::<L>::uninitialized(
49 cmma::MatrixIdent::A,
50 size.m(),
51 size.n(),
52 size.k(),
53 as_cmma_layout(layout),
54 )
55 }
56 }
57
58 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
59 let size = config.tile_size();
60 let layout = config.matrix_layout(StageIdent::Rhs);
61 unsafe {
62 cmma::Matrix::<R>::uninitialized(
63 cmma::MatrixIdent::B,
64 size.m(),
65 size.n(),
66 size.k(),
67 as_cmma_layout(layout),
68 )
69 }
70 }
71
72 fn load_lhs<E: Numeric>(
73 tile: &StridedTile<E>,
74 lhs: &mut Self::LhsFragment,
75 #[comptime] _config: Self::Config,
76 ) {
77 CmmaStageReader::<Self::LhsTile>::load_fragment(tile, lhs, CubeOption::new_None());
78 }
79
80 fn load_rhs<E: Numeric>(
81 tile: &StridedTile<E>,
82 rhs: &mut Self::RhsFragment,
83 #[comptime] _config: Self::Config,
84 ) {
85 CmmaStageReader::<Self::RhsTile>::load_fragment(tile, rhs, CubeOption::new_None());
86 }
87
88 fn load_acc<E: Numeric>(
89 tile: &AccTile::Tile<E>,
90 acc: &mut Self::AccFragment,
91 #[comptime] config: Self::Config,
92 ) {
93 let layout = comptime!(as_cmma_layout(config.matrix_layout(StageIdent::Acc)));
94 CmmaStageReader::<Self::AccTile>::load_fragment(tile, acc, CubeOption::new_Some(layout));
95 }
96
97 fn write_results<E: Numeric>(
98 tile: &mut StridedTile<E, ReadWrite>,
99 out: &Self::AccFragment,
100 #[comptime] _config: Self::Config,
101 ) {
102 let out = cmma::cast::<A, E>(out);
103 CmmaStageWriter::store_fragment(tile, &out);
104 }
105
106 fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
107 let size = config.tile_size();
108 unsafe {
109 cmma::Matrix::<A>::uninitialized(
110 cmma::MatrixIdent::Accumulator,
111 size.m(),
112 size.n(),
113 size.k(),
114 cmma::MatrixLayout::Undefined,
115 )
116 }
117 }
118}