Skip to main content

cubek_std/tile/data/
plane_vec.rs

1use cubecl::{define_size, prelude::*};
2
3use crate::{
4    MatrixLayout, SwizzleModes, TileSize,
5    tile::{Tile, TileScope},
6};
7
8// plane_vec_mat's fragment inner vector size (= reduce_vector_size). Bound at
9// allocate time via `scope.register_size::<NPlaneVec>(reduce_vector_size)`.
10// Decoupled from the outer enum `V` so the fragment is sized by the tile impl's
11// needs, not the stage's vector size.
12define_size!(pub NPlaneVec);
13
14#[derive(CubeType)]
15pub struct PlaneVecTile<N: Numeric> {
16    // Fragment inner size is `NPlaneVec` (= reduce_vector_size).
17    pub data: Array<Vector<N, NPlaneVec>>,
18    #[cube(comptime)]
19    pub matrix_layout: MatrixLayout,
20    #[cube(comptime)]
21    pub config: PlaneVecMatInnerProduct,
22}
23
24#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
25pub struct PlaneVecMatInnerProduct {
26    pub tile_size: TileSize,
27    pub plane_dim: u32,
28    pub swizzle_modes: SwizzleModes,
29    pub reduce_vector_size: u32,
30}
31
32impl PlaneVecMatInnerProduct {
33    pub fn new(
34        tile_size: TileSize,
35        plane_dim: u32,
36        swizzle_modes: SwizzleModes,
37        reduce_vector_size: u32,
38    ) -> Self {
39        Self {
40            tile_size,
41            plane_dim,
42            swizzle_modes,
43            reduce_vector_size,
44        }
45    }
46}
47
48// Binds the plane_vec_mat fragment's inner vector size (`NPlaneVec`) to the
49// `reduce_vector_size` chosen by the tile config at allocation time.
50#[cube]
51#[allow(unused_variables)]
52fn register_reduce_vector_size(#[comptime] reduce_vector_size: u32) {
53    intrinsic!(|scope| {
54        scope.register_size::<NPlaneVec>(reduce_vector_size as usize);
55    });
56}
57
58#[cube]
59pub fn planevec_allocate_lhs<L: Numeric, Sc: TileScope>(
60    #[comptime] layout: MatrixLayout,
61    #[comptime] config: PlaneVecMatInnerProduct,
62) -> Tile<L, Sc, ReadWrite> {
63    register_reduce_vector_size(config.reduce_vector_size);
64    Tile::new_PlaneVec(PlaneVecTile::<L> {
65        data: Array::new(1usize),
66        matrix_layout: layout,
67        config,
68    })
69}
70
71#[cube]
72pub fn planevec_allocate_rhs<R: Numeric, Sc: TileScope>(
73    #[comptime] layout: MatrixLayout,
74    #[comptime] config: PlaneVecMatInnerProduct,
75) -> Tile<R, Sc, ReadWrite> {
76    register_reduce_vector_size(config.reduce_vector_size);
77    Tile::new_PlaneVec(PlaneVecTile::<R> {
78        data: Array::new(config.tile_size.n() as usize),
79        matrix_layout: layout,
80        config,
81    })
82}
83
84#[cube]
85pub fn planevec_allocate_acc<A: Numeric, Sc: TileScope>(
86    #[comptime] layout: MatrixLayout,
87    #[comptime] config: PlaneVecMatInnerProduct,
88) -> Tile<A, Sc, ReadWrite> {
89    register_reduce_vector_size(config.reduce_vector_size);
90    Tile::new_PlaneVec(PlaneVecTile::<A> {
91        data: Array::new(config.tile_size.n() as usize),
92        matrix_layout: layout,
93        config,
94    })
95}