use cubecl;
use cubecl::prelude::*;
use crate::tile::ops::broadcast_reducer::{local_row_max, local_row_sum};
use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
use crate::tile::variants::BounceTile;
use crate::tile::{Plane, Tile, TileExpand};
// Row-wise softmax helpers and zero-fill for plane-level, read-write tiles.
//
// Each method dispatches on the tile variant:
// - `Unit` and `Register` tiles are handled inline. `t.data` is walked as a
//   row-major `m x n` buffer, so element `(r, c)` lives at index `r * n + c`.
// - `Local` and `Bounce` tiles delegate to the `local_*` reducers or to methods
//   on the underlying local tile. `Bounce` always operates on its `local` copy.
// - Any variant not listed in a method's `match` falls into the `_` arm and panics.
#[cube]
impl<E: Float> Tile<E, Plane, ReadWrite> {
    /// Computes, for every row `r`, `acc[r] = max(base[r], max_c self[r][c])`.
    ///
    /// `acc` is first overwritten with a copy of `base`, so any previous contents
    /// of `acc` are discarded. `base` is the starting maximum for each row
    /// (presumably a running max carried over from earlier tiles — NOTE(review):
    /// confirm with callers).
    ///
    /// # Panics
    /// Panics for any variant other than `Unit`, `Local`, `Bounce` or `Register`.
    pub fn row_max(&self, acc: &mut RowWise<E>, base: &RowWise<E>) {
        match self {
            Tile::Unit(t) => {
                acc.copy_from(base);
                let m = comptime!(t.layout.num_rows);
                let n = comptime!(t.layout.num_cols);
                for r in 0..m as usize {
                    // Row-major layout: row `r` starts at `r * n`.
                    let row_offset = r as u32 * n;
                    // Start from the lowest finite value so any element wins the max.
                    let mut val = E::min_value();
                    for c in 0..n {
                        val = max(val, t.data[(row_offset + c) as usize]);
                    }
                    // Fold this tile's row max into the base value copied above.
                    acc.vals[r] = max(acc.vals[r], val);
                }
            }
            Tile::Local(t) => {
                // Presumably the same contract as the inline paths — NOTE(review): confirm.
                local_row_max::<E>(acc, base, t);
            }
            Tile::Bounce(b) => {
                // Reduces the bounce tile's local copy, not its cmma fragment.
                // The caller presumably syncs it first (e.g. via `cmma_to_local`).
                local_row_max::<E>(acc, base, &b.local);
            }
            Tile::Register(t) => {
                // Same algorithm as `Unit`; dimensions come from the tile config.
                acc.copy_from(base);
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                for r in 0..m as usize {
                    let row_offset = r as u32 * n;
                    let mut val = E::min_value();
                    for c in 0..n {
                        val = max(val, t.data[(row_offset + c) as usize]);
                    }
                    acc.vals[r] = max(acc.vals[r], val);
                }
            }
            _ => panic!("row_max: unsupported tile variant"),
        }
    }
    /// Computes, for every row `r`, `acc[r] = sum_c self[r][c]`.
    ///
    /// In the inline paths, `acc` is reset to zero before summing. The result
    /// therefore does NOT include `acc`'s previous contents, and unlike
    /// `row_max` there is no `base` argument.
    ///
    /// # Panics
    /// Panics for any variant other than `Unit`, `Local`, `Bounce` or `Register`.
    pub fn row_sum(&self, acc: &mut RowWise<E>) {
        match self {
            Tile::Unit(t) => {
                acc.fill(E::from_int(0));
                let m = comptime!(t.layout.num_rows);
                let n = comptime!(t.layout.num_cols);
                for r in 0..m as usize {
                    // Row-major layout: row `r` starts at `r * n`.
                    let row_offset = r as u32 * n;
                    let mut val = E::from_int(0);
                    for c in 0..n {
                        val += t.data[(row_offset + c) as usize];
                    }
                    acc.vals[r] += val;
                }
            }
            Tile::Local(t) => {
                // Presumably zeroes `acc` like the inline paths — NOTE(review): confirm.
                local_row_sum::<E>(acc, t);
            }
            Tile::Bounce(b) => {
                // Sums the bounce tile's local copy, not its cmma fragment.
                local_row_sum::<E>(acc, &b.local);
            }
            Tile::Register(t) => {
                // Same algorithm as `Unit`; dimensions come from the tile config.
                acc.fill(E::from_int(0));
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                for r in 0..m as usize {
                    let row_offset = r as u32 * n;
                    let mut val = E::from_int(0);
                    for c in 0..n {
                        val += t.data[(row_offset + c) as usize];
                    }
                    acc.vals[r] += val;
                }
            }
            _ => panic!("row_sum: unsupported tile variant"),
        }
    }
    /// Replaces every element `x` of row `r` with `exp(x - rowwise[r])`.
    ///
    /// In the `Register` path, `rowwise[r]` is first clamped from below at
    /// `LOGIT_MASKED` (via `clamp_min`). The result for a row is multiplied by
    /// zero whenever `rowwise[r] < LOGIT_MASKED`, which treats that row as
    /// fully masked. `Unit`, `Local` and `Bounce` delegate to the underlying
    /// tile's `exp_diff` (presumably the same contract — NOTE(review): confirm).
    ///
    /// # Panics
    /// Panics for any variant other than `Unit`, `Local`, `Bounce` or `Register`.
    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
        match self {
            Tile::Unit(t) => t.exp_diff(rowwise),
            Tile::Local(t) => t.exp_diff(rowwise),
            Tile::Bounce(b) => b.local.exp_diff(rowwise),
            Tile::Register(t) => {
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                let threshold = E::new(LOGIT_MASKED);
                for r in 0..m as usize {
                    let row_offset = r as u32 * n;
                    let val = rowwise.vals[r];
                    // Presumably keeps the exponent argument bounded when the
                    // row's reference value is a mask sentinel.
                    let safe_val = clamp_min(val, threshold);
                    // 1 for rows at/above the threshold, 0 for fully-masked rows.
                    let not_masked = E::cast_from(val >= threshold);
                    for c in 0..n {
                        let idx = (row_offset + c) as usize;
                        t.data[idx] = not_masked * (t.data[idx] - safe_val).exp();
                    }
                }
            }
            _ => panic!("exp_diff: unsupported tile variant"),
        }
    }
    /// Multiplies every element of row `r` by `scale[r]`.
    ///
    /// `Unit`, `Local` and `Bounce` delegate to the underlying tile's
    /// `rowwise_scale`; `Register` is handled inline.
    ///
    /// # Panics
    /// Panics for any variant other than `Unit`, `Local`, `Bounce` or `Register`.
    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
        match self {
            Tile::Unit(t) => t.rowwise_scale(scale),
            Tile::Local(t) => t.rowwise_scale(scale),
            Tile::Bounce(b) => b.local.rowwise_scale(scale),
            Tile::Register(t) => {
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                for r in 0..m as usize {
                    let row_offset = r as u32 * n;
                    for c in 0..n {
                        let idx = (row_offset + c) as usize;
                        t.data[idx] = t.data[idx] * scale.vals[r];
                    }
                }
            }
            _ => panic!("rowwise_scale: unsupported tile variant"),
        }
    }
    /// Scales every element by `scale` and pushes masked positions to the bottom
    /// of the float range.
    ///
    /// In the `Register` path each element becomes `x * scale`. If
    /// `mask.should_mask((r, c))` returns true for the tile-local coordinates
    /// `(r, c)`, `E::min_value()` is then added.
    ///
    /// NOTE(review): masked entries are offset by `E::min_value()` rather than
    /// overwritten. For narrow float types the sum may therefore round to
    /// `min_value` or overflow to negative infinity. Confirm this matches the
    /// delegated variants.
    ///
    /// # Panics
    /// Panics for any variant other than `Unit`, `Local`, `Bounce` or `Register`.
    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
        match self {
            Tile::Unit(t) => t.scale_and_mask::<M>(scale, mask),
            Tile::Local(t) => t.scale_and_mask::<M>(scale, mask),
            Tile::Bounce(b) => b.local.scale_and_mask::<M>(scale, mask),
            Tile::Register(t) => {
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                // Loop indices stay `u32` here because they are also passed to the mask.
                for r in 0..m {
                    let row_offset = r * n;
                    for c in 0..n {
                        let idx = (row_offset + c) as usize;
                        // Branch-free masking: the bool is cast to 0/1 and scales the offset.
                        t.data[idx] = t.data[idx] * scale
                            + E::cast_from(mask.should_mask((r, c))) * E::min_value();
                    }
                }
            }
            _ => panic!("scale_and_mask: unsupported tile variant"),
        }
    }
    /// Sets every element of the tile to zero.
    ///
    /// For `Bounce` tiles only the cmma fragment is filled; `b.local` is left
    /// unchanged. NOTE(review): confirm callers convert with `cmma_to_local`
    /// before reading the local copy. `Cmma` tiles are supported here, unlike
    /// the row-wise ops above.
    ///
    /// # Panics
    /// Panics for any variant not listed below.
    pub fn fill_zero(&mut self) {
        match self {
            Tile::Register(t) => {
                let m = comptime!(t.config.tile_size.m());
                let n = comptime!(t.config.tile_size.n());
                for i in 0..m * n {
                    t.data[i as usize] = E::from_int(0);
                }
            }
            Tile::Unit(t) => t.zero(),
            Tile::Local(t) => t.zero(),
            Tile::Bounce(b) => {
                cubecl::cmma::fill(&b.cmma.matrix, E::from_int(0));
            }
            Tile::Cmma(t) => {
                cubecl::cmma::fill(&t.matrix, E::from_int(0));
            }
            _ => panic!("fill_zero: unsupported tile variant"),
        }
    }
}
/// Copies a bounce tile's cmma fragment into its local tile.
///
/// The fragment is staged through `b.smem` (presumably shared memory, given the
/// cube-wide barriers). It is written in row-major order with a row stride of
/// `tile_size.n()`, then read back into `b.local`. A `sync_cube()` follows each
/// half of the round trip:
/// - the first makes the staged values visible before any unit reads them;
/// - the second keeps a later write to `b.smem` from racing with those reads.
#[cube]
pub(crate) fn cmma_to_local<E: Float>(b: &mut BounceTile<E>) {
    let row_stride = comptime!(b.cmma.tile_size.n());

    // Stage 1: fragment -> staging buffer.
    cubecl::cmma::store(
        &mut b.smem,
        &b.cmma.matrix,
        row_stride,
        cubecl::cmma::MatrixLayout::RowMajor,
    );
    sync_cube();

    // Stage 2: staging buffer -> local tile.
    let staged = b.smem.to_slice();
    b.local.load_from_slice(&staged);
    sync_cube();
}
/// Copies a bounce tile's local tile back into its cmma fragment.
///
/// This is the inverse of `cmma_to_local`. The steps are:
/// 1. `b.local` is written to the staging buffer `b.smem`.
/// 2. A cube-wide barrier makes those writes visible to all units.
/// 3. The fragment is loaded from `b.smem` as row-major data with row stride
///    `tile_size.n()`.
/// 4. A trailing barrier ensures every unit has finished reading `b.smem`
///    before any later operation writes to it again.
///
/// The trailing barrier mirrors the one at the end of `cmma_to_local`.
#[cube]
pub(crate) fn local_to_cmma<E: Float>(b: &mut BounceTile<E>) {
    let stride = comptime!(b.cmma.tile_size.n());
    b.local.store_to(&mut b.smem);
    // All units' local writes must land before the cooperative fragment load.
    sync_cube();
    cubecl::cmma::load_with_layout(
        &b.cmma.matrix,
        &b.smem.to_slice(),
        stride,
        cubecl::cmma::MatrixLayout::RowMajor,
    );
    // Without this barrier, a following write to `b.smem` (e.g. the next
    // `local_to_cmma` / `cmma_to_local` on this tile) could overwrite staged
    // values that other units are still loading.
    sync_cube();
}