Skip to main content

cubek_std/tile/compute/matmul/
plane_vec.rs

1use cubecl::prelude::*;
2
3use crate::{
4    StageIdent,
5    tile::data::{NPlaneVec, PlaneVecMatInnerProduct, StridedTile},
6};
7
8#[cube]
9pub fn planevec_execute<L: Numeric, R: Numeric, A: Numeric>(
10    lhs: &Array<Vector<L, NPlaneVec>>,
11    rhs: &Array<Vector<R, NPlaneVec>>,
12    acc: &mut Array<Vector<A, NPlaneVec>>,
13    #[comptime] config: PlaneVecMatInnerProduct,
14) {
15    let n = config.tile_size.n();
16    #[unroll]
17    for n_idx in 0..n as usize {
18        let mut acc_vec = acc[n_idx];
19        #[unroll]
20        for vi in 0..NPlaneVec::value() {
21            let lhs_elem = A::cast_from(lhs[0usize][vi]);
22            let rhs_elem = A::cast_from(rhs[n_idx][vi]);
23            acc_vec[vi] += plane_sum(lhs_elem * rhs_elem);
24        }
25        acc[n_idx] = acc_vec;
26    }
27}
28
29#[cube]
30pub fn planevec_load_from_shared<E: Numeric, ES: Size, N: Numeric, IO: SliceVisibility>(
31    shared: &StridedTile<E, ES, IO>,
32    arr: &mut Array<Vector<N, NPlaneVec>>,
33    #[comptime] config: PlaneVecMatInnerProduct,
34    #[comptime] ident: StageIdent,
35) {
36    match ident {
37        StageIdent::Lhs => {
38            let offset = shared.stage_offset(UNIT_POS_X);
39            arr[0usize] = Vector::cast_from(shared.container[offset as usize]);
40        }
41        StageIdent::Rhs | StageIdent::Acc => {
42            let n = config.tile_size.n();
43            #[unroll]
44            for n_idx in 0..n {
45                let offset = shared.stage_offset(UNIT_POS_X + n_idx * shared.stride);
46                arr[n_idx as usize] = Vector::cast_from(shared.container[offset as usize]);
47            }
48        }
49        _ => panic!("Invalid ident for PlaneVec load"),
50    }
51}
52
53#[cube]
54pub fn planevec_load_zeros<N: Numeric>(
55    arr: &mut Array<Vector<N, NPlaneVec>>,
56    #[comptime] config: PlaneVecMatInnerProduct,
57) {
58    let n = config.tile_size.n();
59    let zero = N::from_int(0);
60    #[unroll]
61    for n_idx in 0..n as usize {
62        arr[n_idx] = Vector::cast_from(zero);
63    }
64}
65
66#[cube]
67pub fn planevec_write_to_shared<A: Numeric, E: Numeric, ES: Size>(
68    shared: &mut StridedTile<E, ES, ReadWrite>,
69    arr: &Array<Vector<A, NPlaneVec>>,
70    #[comptime] config: PlaneVecMatInnerProduct,
71) {
72    if UNIT_POS_X == 0 {
73        let out_vector_size = shared.container.vector_size().comptime();
74        let n = config.tile_size.n();
75        let total_out_vectors = n as usize / out_vector_size;
76        let reduce_vec = config.reduce_vector_size as usize;
77
78        #[unroll]
79        for out_vector_iter in 0..total_out_vectors {
80            let mut out_vector = Vector::<E, ES>::empty();
81            #[unroll]
82            for within_vector in 0..out_vector_size {
83                let n_idx = out_vector_iter * out_vector_size + within_vector;
84                let acc_vec = arr[n_idx];
85                let mut sum = A::from_int(0);
86                for i in 0..reduce_vec {
87                    sum += acc_vec[i];
88                }
89                out_vector[within_vector] = E::cast_from(sum);
90            }
91            let offset = shared.stage_offset(out_vector_iter as u32);
92            shared.container[offset as usize] = out_vector;
93        }
94    }
95}