Skip to main content

cubek_std/tile/data/
interleaved.rs

1use cubecl::prelude::*;
2
3use crate::{
4    MatrixLayout, SwizzleModes, TileSize,
5    tile::{Tile, TileScope},
6};
7
8#[derive(CubeType)]
9pub struct InterleavedTile<N: Numeric> {
10    pub data: Array<N>,
11    #[cube(comptime)]
12    pub matrix_layout: MatrixLayout,
13    #[cube(comptime)]
14    pub config: InterleavedMatmul,
15}
16
17#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
18pub struct InterleavedMatmul {
19    pub tile_size: TileSize,
20    pub plane_dim: u32,
21    pub swizzle_modes: SwizzleModes,
22}
23
24impl InterleavedMatmul {
25    pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
26        Self {
27            tile_size,
28            plane_dim,
29            swizzle_modes,
30        }
31    }
32
33    pub fn elements_per_unit_m(&self) -> usize {
34        self.tile_size.m() as usize
35    }
36
37    pub fn elements_per_unit_n(&self) -> usize {
38        self.tile_size.n() as usize
39    }
40
41    pub fn local_tile_size(&self) -> TileSize {
42        TileSize {
43            m: self.tile_size.m(),
44            n: self.tile_size.n(),
45            k: self.tile_size.k(),
46        }
47    }
48
49    pub fn elements_per_unit_k(&self) -> usize {
50        let k = self.tile_size.k() as usize;
51        let plane_dim = self.plane_dim as usize;
52        assert!(
53            k.is_multiple_of(plane_dim),
54            "k must be divisible by plane_dim. Got k={:?}, plane_dim={:?}",
55            k,
56            plane_dim
57        );
58
59        k / plane_dim
60    }
61}
62
63#[cube]
64pub fn interleaved_allocate_lhs<L: Numeric, Sc: TileScope>(
65    #[comptime] layout: MatrixLayout,
66    #[comptime] config: InterleavedMatmul,
67) -> Tile<L, Sc, ReadWrite> {
68    let m = config.tile_size.m();
69    let k = config.tile_size.k();
70    let plane_dim = config.plane_dim;
71    Tile::new_Interleaved(InterleavedTile::<L> {
72        data: Array::new((m * (k / plane_dim)) as usize),
73        matrix_layout: layout,
74        config,
75    })
76}
77
78#[cube]
79pub fn interleaved_allocate_rhs<R: Numeric, Sc: TileScope>(
80    #[comptime] layout: MatrixLayout,
81    #[comptime] config: InterleavedMatmul,
82) -> Tile<R, Sc, ReadWrite> {
83    let n = config.tile_size.n();
84    let k = config.tile_size.k();
85    let plane_dim = config.plane_dim;
86    Tile::new_Interleaved(InterleavedTile::<R> {
87        data: Array::new(((k / plane_dim) * n) as usize),
88        matrix_layout: layout,
89        config,
90    })
91}
92
93#[cube]
94pub fn interleaved_allocate_acc<A: Numeric, Sc: TileScope>(
95    #[comptime] layout: MatrixLayout,
96    #[comptime] config: InterleavedMatmul,
97) -> Tile<A, Sc, ReadWrite> {
98    let m = config.tile_size.m();
99    let n = config.tile_size.n();
100    Tile::new_Interleaved(InterleavedTile::<A> {
101        data: Array::new((m * n) as usize),
102        matrix_layout: layout,
103        config,
104    })
105}