Skip to main content

cubek_std/tile/data/
bounce.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::{
5    Tile,
6    data::{
7        cmma::CmmaTile,
8        whitebox_fragment::{InnerLayout, WhiteboxFragment, WhiteboxFragmentLayout},
9    },
10    scope::{TileScope, assert_plane_scope},
11};
12
/// Comptime configuration for [`BounceTile`].
///
/// A bounce tile bundles an opaque cmma fragment together with a shared-memory
/// scratch slice and a [`WhiteboxFragment`] view, so row-wise operations can be
/// expressed as `copy_from` between the inner pieces. From the caller's point
/// of view it is a single [`Tile`] variant — only valid when the tile's
/// scope is `Plane`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct BounceConfig {
    /// Tile dimensions. Their product is the element count of each plane's
    /// shared-memory slice; the pair is also forwarded to
    /// `WhiteboxFragmentLayout::new`.
    /// NOTE(review): presumably `(rows, cols)` — confirm against
    /// `WhiteboxFragmentLayout`.
    pub tile_shape: (u32, u32),
    /// Number of tile-sized slices in the single shared-memory allocation made
    /// by `BounceTile::new` (the allocation is `tile size * num_planes`).
    pub num_planes: u32,
    /// Forwarded to `WhiteboxFragmentLayout::new`; presumably the number of
    /// units per plane — confirm.
    pub plane_dim: u32,
    /// Forwarded to `WhiteboxFragmentLayout::new` to select the whitebox
    /// fragment's inner layout.
    pub inner_layout: InnerLayout,
}
27
/// A cmma tile paired with a shared-memory scratch slice and a whitebox
/// fragment, so data can be bounced between the opaque cmma representation and
/// an explicit per-unit layout. Built by [`BounceTile::new`] from a
/// [`BounceConfig`].
#[derive(CubeType)]
pub struct BounceTile<N: Numeric> {
    /// The wrapped cmma tile, stored exactly as passed to [`BounceTile::new`].
    pub cmma: CmmaTile<N>,
    /// Tile-sized window into a shared-memory buffer sized for all planes;
    /// the window is selected by `UNIT_POS_Y` (presumably the plane index).
    pub smem: SliceMut<N>,
    /// Whitebox fragment built from the configured `WhiteboxFragmentLayout`.
    pub fragment: WhiteboxFragment<N>,
}
34
35#[cube]
36impl<N: Numeric> BounceTile<N> {
37    pub fn new(cmma: CmmaTile<N>, #[comptime] cfg: BounceConfig) -> BounceTile<N> {
38        let total_tile_size = comptime!((cfg.tile_shape.0 * cfg.tile_shape.1) as usize);
39        let smem_size = comptime!(total_tile_size * cfg.num_planes as usize);
40        let start = UNIT_POS_Y as usize * total_tile_size;
41        let end = start + total_tile_size;
42        let smem = SharedMemory::new(smem_size).slice_mut(start, end);
43
44        let layout = comptime!(WhiteboxFragmentLayout::new(
45            cfg.tile_shape,
46            cfg.plane_dim,
47            cfg.inner_layout
48        ));
49        let fragment = WhiteboxFragment::new(layout);
50
51        BounceTile::<N> {
52            cmma,
53            smem,
54            fragment,
55        }
56    }
57}
58
59#[cube]
60/// Wraps a freshly built `CmmaTile` in a `Tile::Bounce`. Panics at expansion
61/// time unless `Sc = Plane`.
62pub fn allocate_bounce_tile<E: Numeric, Sc: TileScope>(
63    cmma: CmmaTile<E>,
64    #[comptime] cfg: BounceConfig,
65) -> Tile<E, Sc, ReadWrite> {
66    comptime!(assert_plane_scope(Sc::KIND));
67    Tile::new_Bounce(BounceTile::<E>::new(cmma, cfg))
68}