cubek_std/tile/data/
register.rs1use cubecl::prelude::*;
2
3use crate::{
4 MatrixLayout, SwizzleModes, TileSize,
5 tile::{Tile, TileScope},
6};
7
/// Tile whose elements live in an [`Array`] sized by the `register_allocate_*`
/// functions, together with the comptime metadata needed to interpret them.
#[derive(CubeType)]
pub struct RegisterTile<N: Numeric> {
    /// Element storage. The allocators size it from `config.tile_size`
    /// (`m * k` for lhs, `n * k` for rhs, `m * n` for the accumulator).
    pub data: Array<N>,
    /// Comptime layout tag supplied at allocation; presumably describes how
    /// `data` is indexed (row- vs column-major) — confirm against the tile ops.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime matmul configuration this tile was allocated for.
    #[cube(comptime)]
    pub config: RegisterMatmul,
}
16
/// Formulation of the tile matrix product used by a [`RegisterMatmul`].
///
/// Chosen from the operand layouts by [`ProductType::from_layouts`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ProductType {
    /// Inner-product formulation (presumably one dot product along `k` per output
    /// element). Preferred by `from_layouts` for a row-major lhs or a col-major rhs.
    Inner,
    /// Outer-product formulation (presumably the output is accumulated from one
    /// rank-1 update per `k`). Preferred by `from_layouts` for a col-major lhs or a
    /// row-major rhs, and used as the fallback when the operands disagree and
    /// neither `m` nor `n` is 1.
    Outer,
}
31
32impl ProductType {
33 pub fn from_layouts(
34 lhs_layout: MatrixLayout,
35 rhs_layout: MatrixLayout,
36 tile_size: TileSize,
37 ) -> Self {
38 let lhs_preferred = match lhs_layout {
39 MatrixLayout::RowMajor => ProductType::Inner,
40 MatrixLayout::ColMajor => ProductType::Outer,
41 };
42 let rhs_preferred = match rhs_layout {
43 MatrixLayout::RowMajor => ProductType::Outer,
44 MatrixLayout::ColMajor => ProductType::Inner,
45 };
46
47 if lhs_preferred == rhs_preferred {
48 lhs_preferred
49 } else if tile_size.m() == 1 {
50 rhs_preferred
51 } else if tile_size.n() == 1 {
52 lhs_preferred
53 } else {
54 ProductType::Outer
56 }
57 }
58}
59
/// Comptime configuration shared by the register-tile allocators and the
/// [`RegisterTile`]s they produce.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct RegisterMatmul {
    /// Tile dimensions; `m`, `n` and `k` size the lhs/rhs/accumulator arrays.
    pub tile_size: TileSize,
    /// Plane size as supplied by the caller (presumably units per plane/warp —
    /// not used in this file).
    pub plane_dim: u32,
    /// Swizzle modes supplied by the caller; carried through, not used in this file.
    pub swizzle_modes: SwizzleModes,
    /// Product formulation, derived from the operand layouts in [`RegisterMatmul::new`].
    pub product_type: ProductType,
}
67
68impl RegisterMatmul {
69 pub fn new(
70 lhs_layout: MatrixLayout,
71 rhs_layout: MatrixLayout,
72 tile_size: TileSize,
73 plane_dim: u32,
74 swizzle_modes: SwizzleModes,
75 ) -> Self {
76 Self {
77 tile_size,
78 plane_dim,
79 swizzle_modes,
80 product_type: ProductType::from_layouts(lhs_layout, rhs_layout, tile_size),
81 }
82 }
83}
84
/// Allocates the lhs operand tile: an [`Array`] of `m * k` elements wrapped in a
/// read-write [`Tile`] via `Tile::new_Register`.
///
/// `layout` and `config` are comptime and stored on the resulting [`RegisterTile`].
#[cube]
pub fn register_allocate_lhs<L: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: RegisterMatmul,
) -> Tile<L, Sc, ReadWrite> {
    // lhs is m x k, so the array holds m * k elements (size known at comptime).
    Tile::new_Register(RegisterTile::<L> {
        data: Array::new((config.tile_size.m() * config.tile_size.k()) as usize),
        matrix_layout: layout,
        config,
    })
}
96
/// Allocates the rhs operand tile: an [`Array`] of `n * k` elements wrapped in a
/// read-write [`Tile`] via `Tile::new_Register`.
///
/// `layout` and `config` are comptime and stored on the resulting [`RegisterTile`].
#[cube]
pub fn register_allocate_rhs<R: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: RegisterMatmul,
) -> Tile<R, Sc, ReadWrite> {
    // rhs is k x n, so the array holds n * k elements (size known at comptime).
    Tile::new_Register(RegisterTile::<R> {
        data: Array::new((config.tile_size.n() * config.tile_size.k()) as usize),
        matrix_layout: layout,
        config,
    })
}
108
/// Allocates the accumulator tile: an [`Array`] of `m * n` elements wrapped in a
/// read-write [`Tile`] via `Tile::new_Register`.
///
/// `layout` and `config` are comptime and stored on the resulting [`RegisterTile`].
/// The array is not explicitly zeroed here; presumably callers initialize it
/// before accumulating — confirm.
#[cube]
pub fn register_allocate_acc<A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: RegisterMatmul,
) -> Tile<A, Sc, ReadWrite> {
    // The accumulator is m x n, so the array holds m * n elements (size known at comptime).
    Tile::new_Register(RegisterTile::<A> {
        data: Array::new((config.tile_size.m() * config.tile_size.n()) as usize),
        matrix_layout: layout,
        config,
    })
}