cubek_std/tile/compute/matmul/
mod.rs1pub mod cmma;
7pub mod interleaved;
8pub mod mma;
9pub mod plane_vec;
10pub mod register;
11
12use cubecl::prelude::*;
13
14use crate::tile::{
15 MmaFragment, MmaFragmentExpand, Tile, TileExpand, TileScope,
16 compute::matmul::{
17 cmma::cmma_execute, interleaved::interleaved_execute, mma::mma_execute,
18 plane_vec::planevec_execute, register::register_execute,
19 },
20};
21
22#[cube]
23impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
24 pub fn mma<L: Numeric, R: Numeric>(
26 &mut self,
27 lhs: &Tile<L, Sc, ReadWrite>,
28 rhs: &Tile<R, Sc, ReadWrite>,
29 ) {
30 match (lhs, rhs, self) {
31 (Tile::Cmma(l), Tile::Cmma(r), Tile::Cmma(a)) => {
32 cmma_execute(&l.matrix, &r.matrix, &mut a.matrix);
33 }
34 (Tile::Cmma(l), Tile::Cmma(r), Tile::Bounce(a)) => {
35 cmma_execute(&l.matrix, &r.matrix, &mut a.cmma.matrix);
36 }
37 (Tile::Bounce(l), Tile::Cmma(r), Tile::Bounce(a)) => {
38 cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.cmma.matrix);
39 }
40 (Tile::Bounce(l), Tile::Cmma(r), Tile::Cmma(a)) => {
41 cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.matrix);
42 }
43 (Tile::Mma(l), Tile::Mma(r), Tile::Mma(a)) => match &l.fragment {
44 MmaFragment::Lhs(lf) => match &r.fragment {
45 MmaFragment::Rhs(rf) => match &mut a.fragment {
46 MmaFragment::Acc(af) => {
47 mma_execute(lf, rf, af, a.matrix_layout, a.config);
48 }
49 MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
50 panic!("Mma: expected Acc role for accumulator")
51 }
52 },
53 MmaFragment::Lhs(_) | MmaFragment::Acc(_) => {
54 panic!("Mma: expected Rhs role for rhs")
55 }
56 },
57 MmaFragment::Rhs(_) | MmaFragment::Acc(_) => {
58 panic!("Mma: expected Lhs role for lhs")
59 }
60 },
61 (Tile::Register(l), Tile::Register(r), Tile::Register(a)) => {
62 register_execute(&l.data, &r.data, &mut a.data, a.config);
63 }
64 (Tile::PlaneVec(l), Tile::PlaneVec(r), Tile::PlaneVec(a)) => {
65 planevec_execute(&l.data, &r.data, &mut a.data, a.config);
66 }
67 (Tile::Interleaved(l), Tile::Interleaved(r), Tile::Interleaved(a)) => {
68 interleaved_execute(
69 &l.data,
70 l.matrix_layout,
71 &r.data,
72 r.matrix_layout,
73 &mut a.data,
74 a.matrix_layout,
75 a.config,
76 );
77 }
78 _ => panic!("Unsupported storage combination for mma"),
79 }
80 }
81}