use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_runtime::TypeUsage;
/// Divides `u32` values inside a kernel by a divisor that is fixed at launch time.
///
/// `Fast` carries the constants computed on the host by `FastDivmodArgs::new` so the
/// kernel can replace the division with a `mul_hi`, an addition and a right shift.
/// `Fallback` only carries the divisor and uses the native `/` and `%` operators.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub enum FastDivmod {
    // Multiply-and-shift division.
    Fast {
        // The original divisor; still needed to derive remainders from quotients.
        divisor: u32,
        // As built by `FastDivmodArgs::new`:
        // floor(2^32 * (2^shift_right - divisor) / divisor) + 1.
        multiplier: u32,
        // As built by `FastDivmodArgs::new`: ceil(log2(divisor)).
        shift_right: u32,
    },
    // Plain integer division, used when the fast path is not selected.
    Fallback {
        divisor: u32,
    },
}
// `FastDivmodArgs` is not declared in this file; presumably it is generated by
// `#[derive(CubeLaunch)]` on `FastDivmod`. `Copy` is implemented by hand here, and
// `Clone` is written in terms of it.
impl<R: Runtime> Copy for FastDivmodArgs<'_, R> {}

impl<R: Runtime> Clone for FastDivmodArgs<'_, R> {
    /// Returns a bitwise copy, consistent with the `Copy` impl.
    fn clone(&self) -> Self {
        *self
    }
}
impl<R: Runtime> FastDivmodArgs<'_, R> {
    /// Creates launch arguments for dividing kernel values by `divisor`.
    ///
    /// Emits `FastDivmod::Fast` with a precomputed multiplier and shift when the
    /// runtime reports `u64` arithmetic support and `divisor <= 2^31`. Otherwise it
    /// emits `FastDivmod::Fallback`, which uses plain `/` and `%` in the kernel.
    ///
    /// `divisor` must be non-zero; this is only checked in debug builds.
    ///
    /// NOTE(review): the fast kernel path evaluates `mul_hi(dividend, multiplier) +
    /// dividend` in 32 bits. That sum can exceed `u32::MAX` once `dividend >= 2^31`,
    /// so callers should keep dividends below 2^31 (confirm at call sites).
    pub fn new(client: &ComputeClient<R::Server>, divisor: u32) -> Self {
        debug_assert!(divisor != 0);

        // ceil(log2(divisor)); the kernel shifts a u32 right by this amount.
        let shift = find_log2(divisor);

        // Divisors above 2^31 would need `shift == 32`. Shifting a `u32` by its full
        // bit width is not portable: it is undefined in C++/CUDA and reduced modulo 32
        // in WGSL, so the fast path would return wrong quotients. Those divisors use
        // plain division, like runtimes without `u64` arithmetic support.
        if shift >= 32 || !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
            return FastDivmodArgs::Fallback {
                divisor: ScalarArg::new(divisor),
            };
        }

        // Granlund-Montgomery magic number:
        //   multiplier = floor(2^32 * (2^shift - divisor) / divisor) + 1
        // Since 2^(shift-1) < divisor <= 2^shift, we have (2^shift - divisor) < divisor.
        // The result therefore always fits in a u32 and the cast below is lossless.
        // With shift <= 31 the u64 product cannot overflow either.
        let div_64 = u64::from(divisor);
        let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;

        FastDivmodArgs::Fast {
            divisor: ScalarArg::new(divisor),
            multiplier: ScalarArg::new(multiplier as u32),
            shift_right: ScalarArg::new(shift as u32),
        }
    }
}
#[cube]
impl FastDivmod {
    /// Computes the quotient `dividend / divisor`.
    ///
    /// `Fast`: computes `t = u32::mul_hi(dividend, multiplier)` (presumably the upper
    /// 32 bits of the 64-bit product) and returns `(t + dividend) >> shift_right`.
    /// `Fallback`: uses the native `/` operator.
    ///
    /// NOTE(review): `multiplier` is a `u32`, so `t <= dividend`. The sum `t + dividend`
    /// therefore stays within `u32` whenever `dividend < 2^31`. For larger dividends it
    /// can exceed `u32::MAX` (e.g. divisor 3, dividend `u32::MAX`), and the fast path
    /// then yields a wrong quotient. Confirm that callers stay below 2^31.
    pub fn div(&self, dividend: u32) -> u32 {
        match self {
            FastDivmod::Fast {
                multiplier,
                shift_right,
                ..
            } => {
                let t = u32::mul_hi(dividend, *multiplier);
                (t + dividend) >> shift_right
            }
            FastDivmod::Fallback { divisor } => dividend / divisor,
        }
    }

    /// Computes the remainder `dividend % divisor`.
    ///
    /// `Fast`: derives it from the quotient returned by `div` as
    /// `dividend - q * divisor`. `Fallback`: uses the native `%` operator; the
    /// quotient `q` computed first is not used in that arm.
    pub fn modulo(&self, dividend: u32) -> u32 {
        let q = self.div(dividend);
        match self {
            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
            FastDivmod::Fallback { divisor } => dividend % divisor,
        }
    }

    /// Returns `(quotient, remainder)`.
    ///
    /// The quotient is computed once via `div`. For both variants the remainder is
    /// then derived as `dividend - q * divisor`, avoiding a second division.
    pub fn div_mod(&self, dividend: u32) -> (u32, u32) {
        let q = self.div(dividend);
        let r = match self {
            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
            FastDivmodFallbackPlaceholder => unreachable!(),
        };
        (q, r)
    }
}
/// Returns `ceil(log2(x))`: the smallest `i` with `2^i >= x`.
///
/// Yields `0` for both `x == 0` and `x == 1`, and `32` for every `x > 2^31`.
fn find_log2(x: u32) -> usize {
    // For x >= 2 the highest set bit of `x - 1` sits at position ceil(log2(x)) - 1.
    // Counting the significant bits of `x - 1` therefore gives the answer.
    // `saturating_sub` maps x == 0 to 0 bits, matching x == 1.
    let predecessor = x.saturating_sub(1);
    let significant_bits = u32::BITS - predecessor.leading_zeros();
    significant_bits as usize
}