use cubecl::prelude::*;
use crate::{
StageIdent,
tile::data::{NPlaneVec, PlaneVecMatInnerProduct, StridedTile},
};
/// Adds plane-reduced element-wise products of `lhs` and `rhs` into `acc`.
///
/// For every index `i` in `0..config.tile_size.n()` and every vector element `e` in
/// `0..NPlaneVec::value()`, `lhs[0][e]` and `rhs[i][e]` are cast to `A`, multiplied,
/// passed through `plane_sum`, and the result is added to element `e` of `acc[i]`.
///
/// Only entry `0` of `lhs` is ever read. Both loops carry `#[unroll]`.
#[cube]
pub fn planevec_execute<L: Numeric, R: Numeric, A: Numeric>(
    lhs: &Array<Vector<L, NPlaneVec>>,
    rhs: &Array<Vector<R, NPlaneVec>>,
    acc: &mut Array<Vector<A, NPlaneVec>>,
    #[comptime] config: PlaneVecMatInnerProduct,
) {
    let tile_n = config.tile_size.n();

    #[unroll]
    for slot in 0..tile_n as usize {
        // Update a local copy of the accumulator vector, then store it back once.
        let mut running = acc[slot];

        #[unroll]
        for elem in 0..NPlaneVec::value() {
            let rhs_value = A::cast_from(rhs[slot][elem]);
            let lhs_value = A::cast_from(lhs[0usize][elem]);
            let product = lhs_value * rhs_value;
            running[elem] += plane_sum(product);
        }

        acc[slot] = running;
    }
}
/// Fills `arr` with vectors read from `shared.container`, as selected by `ident`.
///
/// * `StageIdent::Lhs`: one vector is read at `shared.stage_offset(UNIT_POS_X)` and
///   stored in `arr[0]`.
/// * `StageIdent::Rhs` / `StageIdent::Acc`: for each `i` in `0..config.tile_size.n()`,
///   the vector at `shared.stage_offset(UNIT_POS_X + i * shared.stride)` is stored in
///   `arr[i]`.
///
/// Every vector is converted with `Vector::cast_from` before being stored.
///
/// # Panics
///
/// Panics with `"Invalid ident for PlaneVec load"` for any other `ident`.
/// NOTE(review): `ident` is `#[comptime]`, so this presumably fires while the kernel is
/// being expanded rather than on the device — confirm.
#[cube]
pub fn planevec_load_from_shared<E: Numeric, ES: Size, N: Numeric, IO: SliceVisibility>(
    shared: &StridedTile<E, ES, IO>,
    arr: &mut Array<Vector<N, NPlaneVec>>,
    #[comptime] config: PlaneVecMatInnerProduct,
    #[comptime] ident: StageIdent,
) {
    match ident {
        StageIdent::Lhs => {
            // Single vector, addressed by this unit's X position only.
            let src = shared.stage_offset(UNIT_POS_X);
            arr[0usize] = Vector::cast_from(shared.container[src as usize]);
        }
        StageIdent::Rhs | StageIdent::Acc => {
            let tile_n = config.tile_size.n();

            #[unroll]
            for slot in 0..tile_n {
                // Each successive slot is `shared.stride` further along.
                let position = UNIT_POS_X + slot * shared.stride;
                let src = shared.stage_offset(position);
                arr[slot as usize] = Vector::cast_from(shared.container[src as usize]);
            }
        }
        _ => panic!("Invalid ident for PlaneVec load"),
    }
}
/// Sets the first `config.tile_size.n()` entries of `arr` to zero.
///
/// A scalar `N` zero is created once with `N::from_int(0)`, and each entry is assigned
/// `Vector::cast_from` of that scalar. The loop carries `#[unroll]`.
#[cube]
pub fn planevec_load_zeros<N: Numeric>(
    arr: &mut Array<Vector<N, NPlaneVec>>,
    #[comptime] config: PlaneVecMatInnerProduct,
) {
    let zero_scalar = N::from_int(0);
    let tile_n = config.tile_size.n();

    #[unroll]
    for slot in 0..tile_n as usize {
        arr[slot] = Vector::cast_from(zero_scalar);
    }
}
/// Reduces each vector in `arr` to a scalar and writes the scalars, packed into
/// `Vector<E, ES>` values, to `shared.container`.
///
/// Only the unit with `UNIT_POS_X == 0` does any work. With `w` being
/// `shared.container.vector_size()`, it produces `config.tile_size.n() / w` output
/// vectors. Element `k` of output vector `j` is the sum of the first
/// `config.reduce_vector_size` elements of `arr[j * w + k]`, cast to `E`. Output vector
/// `j` is stored at `shared.stage_offset(j)`.
///
/// NOTE(review): the division is an integer division, so any remainder of
/// `n / w` is not written — presumably `n` is a multiple of `w`; confirm with callers.
#[cube]
pub fn planevec_write_to_shared<A: Numeric, E: Numeric, ES: Size>(
    shared: &mut StridedTile<E, ES, ReadWrite>,
    arr: &Array<Vector<A, NPlaneVec>>,
    #[comptime] config: PlaneVecMatInnerProduct,
) {
    if UNIT_POS_X == 0 {
        let reduce_len = config.reduce_vector_size as usize;
        let tile_n = config.tile_size.n();
        let out_width = shared.container.vector_size().comptime();
        let out_count = tile_n as usize / out_width;

        #[unroll]
        for out_idx in 0..out_count {
            let mut packed = Vector::<E, ES>::empty();

            #[unroll]
            for lane in 0..out_width {
                let slot = out_idx * out_width + lane;
                let partials = arr[slot];

                // Sum the leading `reduce_len` elements of this accumulator vector.
                // This loop has no `#[unroll]` attribute, unlike the enclosing ones.
                let mut total = A::from_int(0);
                for i in 0..reduce_len {
                    total += partials[i];
                }

                packed[lane] = E::cast_from(total);
            }

            let dst = shared.stage_offset(out_idx as u32);
            shared.container[dst as usize] = packed;
        }
    }
}