cubek_std/tile/data/
interleaved.rs1use cubecl::prelude::*;
2
3use crate::{
4 MatrixLayout, SwizzleModes, TileSize,
5 tile::{Tile, TileScope},
6};
7
/// Tile whose elements are held in a single [`Array`], together with the
/// compile-time information describing it.
///
/// Instances are built by the `interleaved_allocate_*` functions in this
/// file, which size `data` from the `config` they are given.
#[derive(CubeType)]
pub struct InterleavedTile<N: Numeric> {
    /// Backing storage for the tile's elements.
    pub data: Array<N>,
    /// Matrix layout associated with the tile (comptime value).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Configuration the tile was allocated with (comptime value).
    #[cube(comptime)]
    pub config: InterleavedMatmul,
}
16
/// Compile-time configuration shared by the interleaved tiles.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct InterleavedMatmul {
    /// Tile dimensions; read through its `m()`, `n()` and `k()` accessors.
    pub tile_size: TileSize,
    /// Divisor applied to the tile's `k` dimension when computing per-unit
    /// sizes (see `elements_per_unit_k`).
    pub plane_dim: u32,
    /// Swizzle modes carried with the configuration; not read in this file.
    pub swizzle_modes: SwizzleModes,
}
23
24impl InterleavedMatmul {
25 pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
26 Self {
27 tile_size,
28 plane_dim,
29 swizzle_modes,
30 }
31 }
32
33 pub fn elements_per_unit_m(&self) -> usize {
34 self.tile_size.m() as usize
35 }
36
37 pub fn elements_per_unit_n(&self) -> usize {
38 self.tile_size.n() as usize
39 }
40
41 pub fn local_tile_size(&self) -> TileSize {
42 TileSize {
43 m: self.tile_size.m(),
44 n: self.tile_size.n(),
45 k: self.tile_size.k(),
46 }
47 }
48
49 pub fn elements_per_unit_k(&self) -> usize {
50 let k = self.tile_size.k() as usize;
51 let plane_dim = self.plane_dim as usize;
52 assert!(
53 k.is_multiple_of(plane_dim),
54 "k must be divisible by plane_dim. Got k={:?}, plane_dim={:?}",
55 k,
56 plane_dim
57 );
58
59 k / plane_dim
60 }
61}
62
63#[cube]
64pub fn interleaved_allocate_lhs<L: Numeric, Sc: TileScope>(
65 #[comptime] layout: MatrixLayout,
66 #[comptime] config: InterleavedMatmul,
67) -> Tile<L, Sc, ReadWrite> {
68 let m = config.tile_size.m();
69 let k = config.tile_size.k();
70 let plane_dim = config.plane_dim;
71 Tile::new_Interleaved(InterleavedTile::<L> {
72 data: Array::new((m * (k / plane_dim)) as usize),
73 matrix_layout: layout,
74 config,
75 })
76}
77
78#[cube]
79pub fn interleaved_allocate_rhs<R: Numeric, Sc: TileScope>(
80 #[comptime] layout: MatrixLayout,
81 #[comptime] config: InterleavedMatmul,
82) -> Tile<R, Sc, ReadWrite> {
83 let n = config.tile_size.n();
84 let k = config.tile_size.k();
85 let plane_dim = config.plane_dim;
86 Tile::new_Interleaved(InterleavedTile::<R> {
87 data: Array::new(((k / plane_dim) * n) as usize),
88 matrix_layout: layout,
89 config,
90 })
91}
92
93#[cube]
94pub fn interleaved_allocate_acc<A: Numeric, Sc: TileScope>(
95 #[comptime] layout: MatrixLayout,
96 #[comptime] config: InterleavedMatmul,
97) -> Tile<A, Sc, ReadWrite> {
98 let m = config.tile_size.m();
99 let n = config.tile_size.n();
100 Tile::new_Interleaved(InterleavedTile::<A> {
101 data: Array::new((m * n) as usize),
102 matrix_layout: layout,
103 config,
104 })
105}