cubek_std/tile/data/
plane_vec.rs1use cubecl::{define_size, prelude::*};
2
3use crate::{
4 MatrixLayout, SwizzleModes, TileSize,
5 tile::{Tile, TileScope},
6};
7
// Size marker used as the vector-width parameter of the `Vector<N, NPlaneVec>`
// elements stored in `PlaneVecTile`. Its concrete value is supplied through
// `scope.register_size::<NPlaneVec>(..)` in `register_reduce_vector_size`.
define_size!(pub NPlaneVec);
13
/// Tile whose storage is an array of `NPlaneVec`-wide vectors of `N`.
///
/// The allocators in this file create it with a single vector (lhs) or with
/// `config.tile_size.n()` vectors (rhs and accumulator).
#[derive(CubeType)]
pub struct PlaneVecTile<N: Numeric> {
    /// Backing storage: an array of `Vector<N, NPlaneVec>` values.
    pub data: Array<Vector<N, NPlaneVec>>,
    /// Matrix layout recorded when the tile was allocated (comptime value).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Configuration the tile was allocated with (comptime value).
    #[cube(comptime)]
    pub config: PlaneVecMatInnerProduct,
}
23
24#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
25pub struct PlaneVecMatInnerProduct {
26 pub tile_size: TileSize,
27 pub plane_dim: u32,
28 pub swizzle_modes: SwizzleModes,
29 pub reduce_vector_size: u32,
30}
31
32impl PlaneVecMatInnerProduct {
33 pub fn new(
34 tile_size: TileSize,
35 plane_dim: u32,
36 swizzle_modes: SwizzleModes,
37 reduce_vector_size: u32,
38 ) -> Self {
39 Self {
40 tile_size,
41 plane_dim,
42 swizzle_modes,
43 reduce_vector_size,
44 }
45 }
46}
47
48#[cube]
51#[allow(unused_variables)]
52fn register_reduce_vector_size(#[comptime] reduce_vector_size: u32) {
53 intrinsic!(|scope| {
54 scope.register_size::<NPlaneVec>(reduce_vector_size as usize);
55 });
56}
57
58#[cube]
59pub fn planevec_allocate_lhs<L: Numeric, Sc: TileScope>(
60 #[comptime] layout: MatrixLayout,
61 #[comptime] config: PlaneVecMatInnerProduct,
62) -> Tile<L, Sc, ReadWrite> {
63 register_reduce_vector_size(config.reduce_vector_size);
64 Tile::new_PlaneVec(PlaneVecTile::<L> {
65 data: Array::new(1usize),
66 matrix_layout: layout,
67 config,
68 })
69}
70
71#[cube]
72pub fn planevec_allocate_rhs<R: Numeric, Sc: TileScope>(
73 #[comptime] layout: MatrixLayout,
74 #[comptime] config: PlaneVecMatInnerProduct,
75) -> Tile<R, Sc, ReadWrite> {
76 register_reduce_vector_size(config.reduce_vector_size);
77 Tile::new_PlaneVec(PlaneVecTile::<R> {
78 data: Array::new(config.tile_size.n() as usize),
79 matrix_layout: layout,
80 config,
81 })
82}
83
84#[cube]
85pub fn planevec_allocate_acc<A: Numeric, Sc: TileScope>(
86 #[comptime] layout: MatrixLayout,
87 #[comptime] config: PlaneVecMatInnerProduct,
88) -> Tile<A, Sc, ReadWrite> {
89 register_reduce_vector_size(config.reduce_vector_size);
90 Tile::new_PlaneVec(PlaneVecTile::<A> {
91 data: Array::new(config.tile_size.n() as usize),
92 matrix_layout: layout,
93 config,
94 })
95}