use cubecl;
use cubecl::prelude::*;
use crate::tile::ops::FULLY_MASKED_ROW_THRESHOLD;
/// A group of `num_rows` values of type `E`, stored in `vals`.
///
/// Every method defined on this type in this file walks `vals` over the index
/// range `0..num_rows` (or, for `replace_at`, touches a single index).
///
// NOTE(review): presumably each entry corresponds to one row of a tile that the
// current unit contributes to — confirm against the callers.
#[derive(CubeType)]
pub struct RowWise<E: Numeric> {
    /// The stored values; methods in this file index it from `0` up to
    /// (but excluding) `num_rows`.
    pub vals: Array<E>,
    /// Number of entries of `vals` that the methods iterate over. Marked
    /// `comptime`, so it is a compile-time value rather than a runtime one.
    #[cube(comptime)]
    pub num_rows: usize,
}
#[cube]
impl<E: Numeric> RowWise<E> {
    /// Builds a `RowWise` with `num_rows` entries, every one set to `val`.
    pub fn new_filled(#[comptime] num_rows: usize, val: E) -> RowWise<E> {
        let mut filled = Array::new(num_rows);
        for row in 0..num_rows {
            filled[row] = val;
        }
        RowWise::<E> {
            vals: filled,
            num_rows,
        }
    }

    /// Overwrites every entry in `0..num_rows` with `val`.
    pub fn fill(&mut self, val: E) {
        for row in 0..self.num_rows {
            self.vals[row] = val;
        }
    }

    /// Builds a `RowWise` with `num_rows` entries, all set to `E::min_value()`.
    pub fn new_min_value(#[comptime] num_rows: usize) -> RowWise<E> {
        let lowest = E::min_value();
        Self::new_filled(num_rows, lowest)
    }

    /// Builds a `RowWise` with `num_rows` entries, all set to `E::from_int(0)`.
    pub fn new_zero(#[comptime] num_rows: usize) -> RowWise<E> {
        let zero = E::from_int(0);
        Self::new_filled(num_rows, zero)
    }

    /// Copies the entries of `other` into `self`, index by index.
    ///
    /// The loop runs over `self.num_rows`; `other` is indexed with the same
    /// indices and its own `num_rows` is not consulted here.
    pub fn copy_from(&mut self, other: &RowWise<E>) {
        for row in 0..self.num_rows {
            let incoming = other.vals[row];
            self.vals[row] = incoming;
        }
    }

    /// Returns a new `RowWise` holding `self[i] + other[i]` for each index.
    ///
    /// The result has `self.num_rows` entries; `other.num_rows` is not
    /// consulted here.
    pub fn add(&self, other: &RowWise<E>) -> RowWise<E> {
        let num_rows = self.num_rows;
        let mut vals = Array::new(num_rows);
        for row in 0..num_rows {
            vals[row] = self.vals[row] + other.vals[row];
        }
        RowWise::<E> { vals, num_rows }
    }

    /// Adds `other[i]` to `self[i]` for each index in `0..self.num_rows`.
    pub fn add_inplace(&mut self, other: &RowWise<E>) {
        for row in 0..self.num_rows {
            let addend = other.vals[row];
            self.vals[row] += addend;
        }
    }

    /// Returns a new `RowWise` holding `self[i] * other[i]` for each index.
    ///
    /// The result has `self.num_rows` entries; `other.num_rows` is not
    /// consulted here.
    pub fn mul(&self, other: &RowWise<E>) -> RowWise<E> {
        let num_rows = self.num_rows;
        let mut vals = Array::new(num_rows);
        for row in 0..num_rows {
            vals[row] = self.vals[row] * other.vals[row];
        }
        RowWise::<E> { vals, num_rows }
    }

    /// Multiplies `self[i]` by `other[i]` for each index in `0..self.num_rows`.
    pub fn mul_inplace(&mut self, other: &RowWise<E>) {
        for row in 0..self.num_rows {
            let factor = other.vals[row];
            self.vals[row] *= factor;
        }
    }

    /// Replaces `self[i]` with `max(self[i], other[i])` for each index in
    /// `0..self.num_rows`.
    pub fn max_inplace(&mut self, other: &RowWise<E>) {
        for row in 0..self.num_rows {
            let current = self.vals[row];
            let candidate = other.vals[row];
            self.vals[row] = max(current, candidate);
        }
    }

    /// Overwrites the entry at index `i` with `new_val`.
    ///
    /// This method itself does not compare `i` against `num_rows`.
    pub fn replace_at(&mut self, i: usize, new_val: E) {
        self.vals[i] = new_val;
    }

    /// Returns a new `RowWise<E2>` with the same `num_rows`, where every entry
    /// is `E2::cast_from` applied to the matching entry of `row_wise`.
    pub fn cast_from<E2: Float>(row_wise: &RowWise<E>) -> RowWise<E2> {
        let num_rows = row_wise.num_rows;
        let mut converted = Array::new(num_rows);
        for row in 0..num_rows {
            converted[row] = E2::cast_from(row_wise.vals[row]);
        }
        RowWise::<E2> {
            vals: converted,
            num_rows,
        }
    }
}
#[cube]
impl<E: Float> RowWise<E> {
    /// Returns a new `RowWise` holding `exp(self[i] - other[i])` for each
    /// index in `0..self.num_rows`.
    pub fn exp_diff(&self, other: &RowWise<E>) -> RowWise<E> {
        let num_rows = self.num_rows;
        let mut vals = Array::new(num_rows);
        for row in 0..num_rows {
            let delta = self.vals[row] - other.vals[row];
            vals[row] = delta.exp();
        }
        RowWise::<E> { vals, num_rows }
    }

    /// Replaces each entry with a guarded reciprocal.
    ///
    /// For every index, the entry is clamped from below to
    /// `FULLY_MASKED_ROW_THRESHOLD` before taking the reciprocal, and the
    /// reciprocal is then multiplied by the cast of the comparison
    /// `entry >= threshold`.
    // NOTE(review): presumably the cast yields 1 for `true` and 0 for `false`,
    // so entries below the threshold end up as 0 instead of a huge reciprocal
    // — confirm against the `cast_from` semantics for booleans.
    pub fn recip_inplace(&mut self) {
        for row in 0..self.num_rows {
            let threshold = E::new(FULLY_MASKED_ROW_THRESHOLD);
            let current = self.vals[row];
            // Comparison result converted to `E`, used below as a multiplier.
            let keep = E::cast_from(current >= threshold);
            // Clamping keeps the value passed to `recip` at or above the threshold.
            let clamped = clamp_min(current, threshold);
            let inverse = clamped.recip();
            self.vals[row] = keep * inverse;
        }
    }
}