// cubek_std/tile/compute/matmul/plane_vec.rs
use cubecl::prelude::*;
2
3use crate::{
4 StageIdent,
5 tile::data::{NPlaneVec, PlaneVecMatInnerProduct, StridedTile},
6};
7
8#[cube]
9pub fn planevec_execute<L: Numeric, R: Numeric, A: Numeric>(
10 lhs: &Array<Vector<L, NPlaneVec>>,
11 rhs: &Array<Vector<R, NPlaneVec>>,
12 acc: &mut Array<Vector<A, NPlaneVec>>,
13 #[comptime] config: PlaneVecMatInnerProduct,
14) {
15 let n = config.tile_size.n();
16 #[unroll]
17 for n_idx in 0..n as usize {
18 let mut acc_vec = acc[n_idx];
19 #[unroll]
20 for vi in 0..NPlaneVec::value() {
21 let lhs_elem = A::cast_from(lhs[0usize][vi]);
22 let rhs_elem = A::cast_from(rhs[n_idx][vi]);
23 acc_vec[vi] += plane_sum(lhs_elem * rhs_elem);
24 }
25 acc[n_idx] = acc_vec;
26 }
27}
28
29#[cube]
30pub fn planevec_load_from_shared<E: Numeric, ES: Size, N: Numeric, IO: SliceVisibility>(
31 shared: &StridedTile<E, ES, IO>,
32 arr: &mut Array<Vector<N, NPlaneVec>>,
33 #[comptime] config: PlaneVecMatInnerProduct,
34 #[comptime] ident: StageIdent,
35) {
36 match ident {
37 StageIdent::Lhs => {
38 let offset = shared.stage_offset(UNIT_POS_X);
39 arr[0usize] = Vector::cast_from(shared.container[offset as usize]);
40 }
41 StageIdent::Rhs | StageIdent::Acc => {
42 let n = config.tile_size.n();
43 #[unroll]
44 for n_idx in 0..n {
45 let offset = shared.stage_offset(UNIT_POS_X + n_idx * shared.stride);
46 arr[n_idx as usize] = Vector::cast_from(shared.container[offset as usize]);
47 }
48 }
49 _ => panic!("Invalid ident for PlaneVec load"),
50 }
51}
52
53#[cube]
54pub fn planevec_load_zeros<N: Numeric>(
55 arr: &mut Array<Vector<N, NPlaneVec>>,
56 #[comptime] config: PlaneVecMatInnerProduct,
57) {
58 let n = config.tile_size.n();
59 let zero = N::from_int(0);
60 #[unroll]
61 for n_idx in 0..n as usize {
62 arr[n_idx] = Vector::cast_from(zero);
63 }
64}
65
66#[cube]
67pub fn planevec_write_to_shared<A: Numeric, E: Numeric, ES: Size>(
68 shared: &mut StridedTile<E, ES, ReadWrite>,
69 arr: &Array<Vector<A, NPlaneVec>>,
70 #[comptime] config: PlaneVecMatInnerProduct,
71) {
72 if UNIT_POS_X == 0 {
73 let out_vector_size = shared.container.vector_size().comptime();
74 let n = config.tile_size.n();
75 let total_out_vectors = n as usize / out_vector_size;
76 let reduce_vec = config.reduce_vector_size as usize;
77
78 #[unroll]
79 for out_vector_iter in 0..total_out_vectors {
80 let mut out_vector = Vector::<E, ES>::empty();
81 #[unroll]
82 for within_vector in 0..out_vector_size {
83 let n_idx = out_vector_iter * out_vector_size + within_vector;
84 let acc_vec = arr[n_idx];
85 let mut sum = A::from_int(0);
86 for i in 0..reduce_vec {
87 sum += acc_vec[i];
88 }
89 out_vector[within_vector] = E::cast_from(sum);
90 }
91 let offset = shared.stage_offset(out_vector_iter as u32);
92 shared.container[offset as usize] = out_vector;
93 }
94 }
95}