cubek_std/tile/data/
bounce.rs1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::{
5 Tile,
6 data::{
7 cmma::CmmaTile,
8 whitebox_fragment::{InnerLayout, WhiteboxFragment, WhiteboxFragmentLayout},
9 },
10 scope::{TileScope, assert_plane_scope},
11};
12
/// Compile-time configuration for building a [`BounceTile`].
///
/// Both `BounceTile::new` and `allocate_bounce_tile` take it as a `#[comptime]`
/// argument, so every field is fixed when the kernel is expanded.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct BounceConfig {
    /// Extents of one tile; a tile holds `tile_shape.0 * tile_shape.1` elements.
    /// Also forwarded to `WhiteboxFragmentLayout::new`.
    pub tile_shape: (u32, u32),
    /// Number of planes sharing the shared-memory buffer, which is sized to hold
    /// one tile per plane.
    pub num_planes: u32,
    /// Forwarded to `WhiteboxFragmentLayout::new`; presumably the number of
    /// units per plane (inferred from the name — TODO confirm).
    pub plane_dim: u32,
    /// Inner layout forwarded to `WhiteboxFragmentLayout::new`.
    pub inner_layout: InnerLayout,
}
27
/// A CMMA tile bundled with a shared-memory slice and a whitebox fragment.
///
/// Built by `BounceTile::new`, which carves `smem` out of a shared-memory buffer
/// holding one tile per plane. NOTE(review): the name suggests data is bounced
/// between `cmma` and `fragment` through `smem`; that logic is not visible here —
/// confirm against the callers.
#[derive(CubeType)]
pub struct BounceTile<N: Numeric> {
    /// The CMMA tile supplied by the caller.
    pub cmma: CmmaTile<N>,
    /// A `tile_shape.0 * tile_shape.1`-element window of shared memory, selected
    /// by `UNIT_POS_Y` when the tile is built.
    pub smem: SliceMut<N>,
    /// Fragment whose layout is derived from the tile shape, plane dimension and
    /// inner layout of the config.
    pub fragment: WhiteboxFragment<N>,
}
34
35#[cube]
36impl<N: Numeric> BounceTile<N> {
37 pub fn new(cmma: CmmaTile<N>, #[comptime] cfg: BounceConfig) -> BounceTile<N> {
38 let total_tile_size = comptime!((cfg.tile_shape.0 * cfg.tile_shape.1) as usize);
39 let smem_size = comptime!(total_tile_size * cfg.num_planes as usize);
40 let start = UNIT_POS_Y as usize * total_tile_size;
41 let end = start + total_tile_size;
42 let smem = SharedMemory::new(smem_size).slice_mut(start, end);
43
44 let layout = comptime!(WhiteboxFragmentLayout::new(
45 cfg.tile_shape,
46 cfg.plane_dim,
47 cfg.inner_layout
48 ));
49 let fragment = WhiteboxFragment::new(layout);
50
51 BounceTile::<N> {
52 cmma,
53 smem,
54 fragment,
55 }
56 }
57}
58
59#[cube]
60pub fn allocate_bounce_tile<E: Numeric, Sc: TileScope>(
63 cmma: CmmaTile<E>,
64 #[comptime] cfg: BounceConfig,
65) -> Tile<E, Sc, ReadWrite> {
66 comptime!(assert_plane_scope(Sc::KIND));
67 Tile::new_Bounce(BounceTile::<E>::new(cmma, cfg))
68}