use cubecl;
use cubecl::prelude::*;
use crate::{
StageIdent,
tile::{
MmaFragment, MmaFragmentExpand, Tile, TileExpand, TileScope,
compute::matmul::{
cmma::{cmma_load_from_shared, cmma_load_zeros, cmma_write_to_shared},
interleaved::{
interleaved_load_from_shared, interleaved_load_zeros, interleaved_write_to_shared,
},
mma::{
mma_load_acc_from_shared, mma_load_acc_zeros, mma_load_lhs_from_shared,
mma_load_rhs_from_shared, mma_write_to_shared,
},
plane_vec::{planevec_load_from_shared, planevec_load_zeros, planevec_write_to_shared},
register::{register_load_from_shared, register_load_zeros, register_write_to_shared},
},
data::BounceTile,
},
};
/// Transfers the bounce tile's CMMA matrix into its whitebox `fragment`.
///
/// The transfer is staged through the tile's `smem` buffer:
/// 1. `b.cmma.matrix` is stored into `smem` in row-major order, with a row
///    stride equal to the tile's `n` dimension.
/// 2. `b.fragment` is then filled from `smem`.
///
/// A cube-wide barrier follows each of the two steps. The first makes the
/// staged values visible before they are read back. The second presumably
/// lets `smem` be overwritten safely by whatever runs next.
#[cube]
pub(crate) fn cmma_to_whitebox_fragment<E: Float>(b: &mut BounceTile<E>) {
    // Row-major staging: consecutive rows are `n` elements apart.
    let row_stride = comptime!(b.cmma.tile_size.n());

    // Step 1: spill the CMMA matrix into the staging buffer, then barrier.
    cubecl::cmma::store(
        &mut b.smem,
        &b.cmma.matrix,
        row_stride,
        cubecl::cmma::MatrixLayout::RowMajor,
    );
    sync_cube();

    // Step 2: read the staged values into the whitebox fragment, then barrier.
    b.fragment.load_from_slice(&b.smem.to_slice());
    sync_cube();
}
/// Transfers the bounce tile's whitebox `fragment` back into its CMMA matrix.
///
/// This is the reverse of `cmma_to_whitebox_fragment`:
/// 1. `b.fragment` is written into the tile's `smem` buffer.
/// 2. A cube-wide barrier runs.
/// 3. `b.cmma.matrix` is loaded from `smem` as row-major data, with a row
///    stride equal to the tile's `n` dimension.
///
/// NOTE(review): unlike the reverse direction, no barrier follows the final
/// load. If a caller overwrites `smem` immediately afterwards from other
/// units, it may need its own `sync_cube()`. Confirm at call sites.
#[cube]
pub(crate) fn whitebox_fragment_to_cmma<E: Float>(b: &mut BounceTile<E>) {
    // Stage the whitebox values, and make them visible to every unit.
    b.fragment.store_to(&mut b.smem);
    sync_cube();

    // Row-major staging: consecutive rows are `n` elements apart.
    let row_stride = comptime!(b.cmma.tile_size.n());
    cubecl::cmma::load_with_layout(
        &b.cmma.matrix,
        &b.smem.to_slice(),
        row_stride,
        cubecl::cmma::MatrixLayout::RowMajor,
    );
}
#[cube]
impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
    /// Copies the contents of `source` into `self`, dispatching on the pair of
    /// storage variants `(source, self)`.
    ///
    /// Three families of pairs are handled:
    /// - `SharedMemory` → `Cmma` / `Bounce` / `Mma` / `Register` / `PlaneVec` /
    ///   `Interleaved`: load the destination tile from a shared-memory view.
    /// - `None` → the same destinations: zero-fill the destination. For `Mma`,
    ///   this only applies to the `Acc` fragment role.
    /// - `Cmma` / `Bounce` / `Mma` / `Register` / `PlaneVec` / `Interleaved` →
    ///   `SharedMemory`: write the source tile back into `self`'s shared memory.
    ///   For `Mma`, this only applies to the `Acc` fragment role.
    ///
    /// For `Bounce` tiles, only the inner `cmma` matrix is read or written.
    /// The bounce tile's `fragment` and `smem` fields are not touched here.
    ///
    /// # Type parameters
    /// - `SE`: element type of `source`.
    /// - `SS`: size parameter used to view whichever side is shared memory
    ///   (`shared.view::<SS>()`).
    /// - `L`, `R`, `A`: forwarded only to the `mma_*` helpers. They are presumably
    ///   the element types of the lhs, rhs and accumulator fragments; confirm
    ///   against those helpers.
    /// - `SIO`: visibility of `source`.
    ///
    /// # Parameters
    /// - `source`: the tile to read from.
    /// - `ident`: stage identifier. It is forwarded to:
    ///   - the `Cmma`, `Bounce`, `Register`, `PlaneVec` and `Interleaved`
    ///     shared-memory loads;
    ///   - the `Register` zero-fill.
    ///
    ///   All other arms ignore it.
    ///
    /// # Panics
    /// - When zero-filling or writing back an `Mma` tile whose fragment is
    ///   `Lhs` or `Rhs`.
    /// - For any storage pair not listed above, e.g. `SharedMemory` →
    ///   `SharedMemory` or `None` → `SharedMemory`.
    ///
    /// NOTE(review): this is a `#[cube]` function matching on tile variants, so
    /// these panics presumably fire during kernel expansion rather than on the
    /// device. Confirm.
    pub fn copy_from<
        SE: Numeric,
        SS: Size,
        L: Numeric,
        R: Numeric,
        A: Numeric,
        SIO: SliceVisibility,
    >(
        &mut self,
        source: &Tile<SE, Sc, SIO>,
        #[comptime] ident: StageIdent,
    ) {
        match (source, self) {
            // ---- Load direction: shared memory (or nothing) -> `self` ----
            (Tile::SharedMemory(shared), Tile::Cmma(t)) => {
                let shared = shared.view::<SS>();
                cmma_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.matrix,
                    ident,
                    t.matrix_layout,
                );
            }
            (Tile::None, Tile::Cmma(t)) => {
                cmma_load_zeros::<N>(&mut t.matrix);
            }
            // Bounce tiles reuse the CMMA path on their inner `cmma` matrix.
            (Tile::SharedMemory(shared), Tile::Bounce(b)) => {
                let shared = shared.view::<SS>();
                cmma_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut b.cmma.matrix,
                    ident,
                    b.cmma.matrix_layout,
                );
            }
            (Tile::None, Tile::Bounce(b)) => {
                cmma_load_zeros::<N>(&mut b.cmma.matrix);
            }
            // Each fragment role calls its own loader. The loader receives the
            // two of `L`/`R`/`A` not named by that role:
            // Lhs -> (R, A), Rhs -> (L, A), Acc -> (L, R).
            (Tile::SharedMemory(shared), Tile::Mma(t)) => {
                let shared = shared.view::<SS>();
                match &mut t.fragment {
                    MmaFragment::Lhs(f) => mma_load_lhs_from_shared::<SE, SS, N, R, A, SIO>(
                        &shared,
                        f,
                        t.matrix_layout,
                        t.config,
                    ),
                    MmaFragment::Rhs(f) => mma_load_rhs_from_shared::<SE, SS, N, L, A, SIO>(
                        &shared,
                        f,
                        t.matrix_layout,
                        t.config,
                    ),
                    MmaFragment::Acc(f) => mma_load_acc_from_shared::<SE, SS, N, L, R, SIO>(
                        &shared,
                        f,
                        t.matrix_layout,
                        t.config,
                    ),
                }
            }
            // Only the accumulator role supports zero-initialization.
            (Tile::None, Tile::Mma(t)) => match &mut t.fragment {
                MmaFragment::Acc(f) => {
                    mma_load_acc_zeros::<SE, SS, N, L, R>(f, t.matrix_layout, t.config);
                }
                MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
                    panic!("Mma zero-load only supported for Acc role")
                }
            },
            (Tile::SharedMemory(shared), Tile::Register(t)) => {
                let shared = shared.view::<SS>();
                register_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.data,
                    t.matrix_layout,
                    t.config,
                    ident,
                );
            }
            (Tile::None, Tile::Register(t)) => {
                register_load_zeros::<N>(&mut t.data, t.config, ident);
            }
            (Tile::SharedMemory(shared), Tile::PlaneVec(t)) => {
                let shared = shared.view::<SS>();
                planevec_load_from_shared::<SE, SS, N, SIO>(&shared, &mut t.data, t.config, ident);
            }
            (Tile::None, Tile::PlaneVec(t)) => {
                planevec_load_zeros::<N>(&mut t.data, t.config);
            }
            (Tile::SharedMemory(shared), Tile::Interleaved(t)) => {
                let shared = shared.view::<SS>();
                interleaved_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.data,
                    t.config,
                    ident,
                );
            }
            (Tile::None, Tile::Interleaved(t)) => {
                interleaved_load_zeros::<N>(&mut t.data, t.config);
            }
            // ---- Write-back direction: `source` -> `self` (shared memory) ----
            // Here `self` is the shared memory, with element type `N`. `source`
            // is the tile being written out, with element type `SE`.
            (Tile::Cmma(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                cmma_write_to_shared::<N, SS, SE>(&mut shared, &t.matrix);
            }
            (Tile::Bounce(b), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                cmma_write_to_shared::<N, SS, SE>(&mut shared, &b.cmma.matrix);
            }
            // Only the accumulator role can be written back.
            (Tile::Mma(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                match &t.fragment {
                    MmaFragment::Acc(f) => {
                        mma_write_to_shared::<N, SS, SE, L, R>(&mut shared, f, t.config);
                    }
                    MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
                        panic!("Mma write_to_shared only supported for Acc role")
                    }
                }
            }
            (Tile::Register(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                register_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
            }
            // NOTE(review): generics are passed as <SE, N, SS> here, not
            // <N, SS, SE> as in the sibling writers. This presumably matches
            // `planevec_write_to_shared`'s own signature; confirm.
            (Tile::PlaneVec(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                planevec_write_to_shared::<SE, N, SS>(&mut shared, &t.data, t.config);
            }
            (Tile::Interleaved(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                interleaved_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
            }
            // Every other combination (shared->shared, none->shared,
            // tile->tile, ...) is rejected.
            _ => panic!("Unsupported storage pair for copy_from"),
        }
    }
}