Skip to main content

cubek_std/tile/compute/matmul/
mod.rs

1//! Per-flavor tile matmul compute: `*_execute`, `*_load_*`, `*_write_to_shared`,
2//! plus the fragment readers/writers for each flavor. Tile data and matmul
3//! configs live alongside the corresponding data structures in
4//! [`crate::tile::data`].
5
6pub mod cmma;
7pub mod interleaved;
8pub mod mma;
9pub mod plane_vec;
10pub mod register;
11
12use cubecl::prelude::*;
13
14use crate::tile::{
15    MmaFragment, MmaFragmentExpand, Tile, TileExpand, TileScope,
16    compute::matmul::{
17        cmma::cmma_execute, interleaved::interleaved_execute, mma::mma_execute,
18        plane_vec::planevec_execute, register::register_execute,
19    },
20};
21
22#[cube]
23impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
24    /// Executes `lhs ยท rhs`, accumulating the result into `self`.
25    pub fn mma<L: Numeric, R: Numeric>(
26        &mut self,
27        lhs: &Tile<L, Sc, ReadWrite>,
28        rhs: &Tile<R, Sc, ReadWrite>,
29    ) {
30        match (lhs, rhs, self) {
31            (Tile::Cmma(l), Tile::Cmma(r), Tile::Cmma(a)) => {
32                cmma_execute(&l.matrix, &r.matrix, &mut a.matrix);
33            }
34            (Tile::Cmma(l), Tile::Cmma(r), Tile::Bounce(a)) => {
35                cmma_execute(&l.matrix, &r.matrix, &mut a.cmma.matrix);
36            }
37            (Tile::Bounce(l), Tile::Cmma(r), Tile::Bounce(a)) => {
38                cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.cmma.matrix);
39            }
40            (Tile::Bounce(l), Tile::Cmma(r), Tile::Cmma(a)) => {
41                cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.matrix);
42            }
43            (Tile::Mma(l), Tile::Mma(r), Tile::Mma(a)) => match &l.fragment {
44                MmaFragment::Lhs(lf) => match &r.fragment {
45                    MmaFragment::Rhs(rf) => match &mut a.fragment {
46                        MmaFragment::Acc(af) => {
47                            mma_execute(lf, rf, af, a.matrix_layout, a.config);
48                        }
49                        MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
50                            panic!("Mma: expected Acc role for accumulator")
51                        }
52                    },
53                    MmaFragment::Lhs(_) | MmaFragment::Acc(_) => {
54                        panic!("Mma: expected Rhs role for rhs")
55                    }
56                },
57                MmaFragment::Rhs(_) | MmaFragment::Acc(_) => {
58                    panic!("Mma: expected Lhs role for lhs")
59                }
60            },
61            (Tile::Register(l), Tile::Register(r), Tile::Register(a)) => {
62                register_execute(&l.data, &r.data, &mut a.data, a.config);
63            }
64            (Tile::PlaneVec(l), Tile::PlaneVec(r), Tile::PlaneVec(a)) => {
65                planevec_execute(&l.data, &r.data, &mut a.data, a.config);
66            }
67            (Tile::Interleaved(l), Tile::Interleaved(r), Tile::Interleaved(a)) => {
68                interleaved_execute(
69                    &l.data,
70                    l.matrix_layout,
71                    &r.data,
72                    r.matrix_layout,
73                    &mut a.data,
74                    a.matrix_layout,
75                    a.config,
76                );
77            }
78            _ => panic!("Unsupported storage combination for mma"),
79        }
80    }
81}