Skip to main content

cubek_std/tile/data/
register.rs

1use cubecl::prelude::*;
2
3use crate::{
4    MatrixLayout, SwizzleModes, TileSize,
5    tile::{Tile, TileScope},
6};
7
8#[derive(CubeType)]
9pub struct RegisterTile<N: Numeric> {
10    pub data: Array<N>,
11    #[cube(comptime)]
12    pub matrix_layout: MatrixLayout,
13    #[cube(comptime)]
14    pub config: RegisterMatmul,
15}
16
/// Execution mode for the [`RegisterMatmul`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ProductType {
    /// Computes the Tile Matmul as m*n inner products of length k.
    ///
    /// Needs Lhs to be row major and Rhs to be col major.
    /// If not the case, the tile will be transposed during load.
    Inner,
    /// Computes the Tile Matmul as the sum of k outer products of size m*n.
    ///
    /// Needs Lhs to be col major and Rhs to be row major.
    /// If not the case, the tile will be transposed during load.
    Outer,
}
31
32impl ProductType {
33    pub fn from_layouts(
34        lhs_layout: MatrixLayout,
35        rhs_layout: MatrixLayout,
36        tile_size: TileSize,
37    ) -> Self {
38        let lhs_preferred = match lhs_layout {
39            MatrixLayout::RowMajor => ProductType::Inner,
40            MatrixLayout::ColMajor => ProductType::Outer,
41        };
42        let rhs_preferred = match rhs_layout {
43            MatrixLayout::RowMajor => ProductType::Outer,
44            MatrixLayout::ColMajor => ProductType::Inner,
45        };
46
47        if lhs_preferred == rhs_preferred {
48            lhs_preferred
49        } else if tile_size.m() == 1 {
50            rhs_preferred
51        } else if tile_size.n() == 1 {
52            lhs_preferred
53        } else {
54            // No better solution
55            ProductType::Outer
56        }
57    }
58}
59
60#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
61pub struct RegisterMatmul {
62    pub tile_size: TileSize,
63    pub plane_dim: u32,
64    pub swizzle_modes: SwizzleModes,
65    pub product_type: ProductType,
66}
67
68impl RegisterMatmul {
69    pub fn new(
70        lhs_layout: MatrixLayout,
71        rhs_layout: MatrixLayout,
72        tile_size: TileSize,
73        plane_dim: u32,
74        swizzle_modes: SwizzleModes,
75    ) -> Self {
76        Self {
77            tile_size,
78            plane_dim,
79            swizzle_modes,
80            product_type: ProductType::from_layouts(lhs_layout, rhs_layout, tile_size),
81        }
82    }
83}
84
85#[cube]
86pub fn register_allocate_lhs<L: Numeric, Sc: TileScope>(
87    #[comptime] layout: MatrixLayout,
88    #[comptime] config: RegisterMatmul,
89) -> Tile<L, Sc, ReadWrite> {
90    Tile::new_Register(RegisterTile::<L> {
91        data: Array::new((config.tile_size.m() * config.tile_size.k()) as usize),
92        matrix_layout: layout,
93        config,
94    })
95}
96
97#[cube]
98pub fn register_allocate_rhs<R: Numeric, Sc: TileScope>(
99    #[comptime] layout: MatrixLayout,
100    #[comptime] config: RegisterMatmul,
101) -> Tile<R, Sc, ReadWrite> {
102    Tile::new_Register(RegisterTile::<R> {
103        data: Array::new((config.tile_size.n() * config.tile_size.k()) as usize),
104        matrix_layout: layout,
105        config,
106    })
107}
108
109#[cube]
110pub fn register_allocate_acc<A: Numeric, Sc: TileScope>(
111    #[comptime] layout: MatrixLayout,
112    #[comptime] config: RegisterMatmul,
113) -> Tile<A, Sc, ReadWrite> {
114    Tile::new_Register(RegisterTile::<A> {
115        data: Array::new((config.tile_size.m() * config.tile_size.n()) as usize),
116        matrix_layout: layout,
117        config,
118    })
119}