use cubecl_core::ir::features::TypeUsage;
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl, ir::ConstantValue};
/// Divisor for repeated integer division / modulo inside kernels.
///
/// `Fast` carries host-precomputed `multiplier` / `shift_right` parameters so a quotient can be
/// obtained with a high multiply, an add and a shift (see `div`). `Fallback` keeps only the
/// divisor and uses the native `/` and `%` operators.
#[derive(CubeType, Clone, Copy)]
pub enum FastDivmod<I: FastDivmodInt> {
    /// Multiply-shift path; `multiplier` and `shift_right` are derived from `divisor` when the
    /// launch argument is registered.
    Fast {
        /// The original divisor, used to reconstruct the remainder from the quotient.
        divisor: I,
        /// Magic multiplier passed to `mul_hi`.
        multiplier: I,
        /// Right shift applied after the multiply-add.
        shift_right: u32,
    },
    /// Plain division path, selected by `FastDivmodArgs::new` when the device does not report
    /// arithmetic support for `u64`.
    Fallback {
        divisor: I,
    },
}
/// Integer types that [`FastDivmod`] can operate on.
///
/// Implementors need regular integer arithmetic, a high-half multiply (`MulHi`) and the ability
/// to be passed to a kernel as a scalar argument.
pub trait FastDivmodInt
where
    Self: Int + MulHi + ScalarArgSettings,
{
    /// Size in bytes of this type for the kernel being configured by `kernel_launcher`.
    fn size<R: Runtime>(kernel_launcher: &KernelLauncher<R>) -> usize;
}
impl FastDivmodInt for u32 {
    /// `u32` has a fixed width, so the launch configuration is not consulted.
    fn size<R: Runtime>(_: &KernelLauncher<R>) -> usize {
        core::mem::size_of::<u32>()
    }
}
impl FastDivmodInt for usize {
    /// The kernel-side width of `usize` is taken from the launcher's configured address type.
    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
        let address_type = &launcher.settings.address_type;
        address_type.unsigned_type().size()
    }
}
impl<I: FastDivmodInt> FastDivmodArgs<I> {
    /// Creates the launch argument for a [`FastDivmod`] dividing by `divisor`.
    ///
    /// The multiply-shift `Fast` variant is chosen only when `client` reports arithmetic support
    /// for `u64`; otherwise the plain-division `Fallback` variant is returned. The fast-path
    /// parameters themselves are computed later, when the argument is registered.
    ///
    /// In debug builds, asserts that `divisor` is non-zero.
    ///
    /// NOTE(review): presumably `u64` support is required because `mul_hi` may be lowered
    /// through 64-bit arithmetic — confirm.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, divisor: I) -> Self {
        debug_assert!({
            let divisor_value: ConstantValue = divisor.into();
            let divisor_value = divisor_value.as_u64();
            divisor_value != 0
        });
        let has_u64_arithmetic = u64::supported_uses(client).contains(TypeUsage::Arithmetic);
        let divisor = ScalarArg::new(divisor);
        if has_u64_arithmetic {
            FastDivmodArgs::Fast { divisor }
        } else {
            FastDivmodArgs::Fallback { divisor }
        }
    }
}
#[cube]
impl<I: FastDivmodInt> FastDivmod<I> {
    /// Returns `dividend / divisor`.
    ///
    /// The `Fast` path computes `(mul_hi(dividend, multiplier) + dividend) >> shift_right`.
    /// NOTE(review): the intermediate `t + dividend` is computed in `I` and can wrap for
    /// dividends close to `I`'s maximum, which would give a wrong quotient; confirm callers
    /// keep dividends in range.
    pub fn div(&self, dividend: I) -> I {
        match self {
            FastDivmod::Fast {
                multiplier,
                shift_right,
                ..
            } => {
                // Presumably the high half of the full-width `dividend * multiplier` product.
                let t = I::mul_hi(dividend, *multiplier);
                (t + dividend) >> I::cast_from(*shift_right)
            }
            FastDivmod::Fallback { divisor } => dividend / *divisor,
        }
    }
    /// Returns `dividend % divisor`.
    pub fn modulo(&self, dividend: I) -> I {
        let q = self.div(dividend);
        match self {
            // Remainder reconstructed from the fast quotient instead of a native `%`.
            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
            // `q` is unused on this path; the native `%` is applied directly.
            FastDivmod::Fallback { divisor } => dividend % *divisor,
        }
    }
    /// Returns `(dividend / divisor, dividend % divisor)`, reusing the quotient for the
    /// remainder.
    pub fn div_mod(&self, dividend: I) -> (I, I) {
        let q = self.div(dividend);
        // Both arms compute the same expression; the match only extracts `divisor` from
        // whichever variant is active.
        let r = match self {
            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
            FastDivmod::Fallback { divisor } => dividend - q * *divisor,
        };
        (q, r)
    }
}
/// Computes the `(shift, multiplier)` pair for the 32-bit multiply-shift division path.
///
/// With `shift = ceil(log2(divisor))` and
/// `multiplier = floor(2^32 * (2^shift - divisor) / divisor) + 1`, the quotient `n / divisor`
/// equals `(mul_hi(n, multiplier) + n) >> shift` as long as that intermediate sum does not
/// overflow.
///
/// NOTE(review): divisors above 2^31 yield `shift == 32`; a full-width shift on the device may
/// not be well-defined on every target.
///
/// # Panics
///
/// Panics if `divisor` is zero.
fn find_params_u32(divisor: u32) -> (u32, u32) {
    assert!(divisor != 0, "FastDivmod divisor must be non-zero");
    let div_64 = u64::from(divisor);
    // ceil(log2(divisor)). `next_power_of_two` overflows for divisors above 2^31 (panicking in
    // debug builds, wrapping to 0 in release), so those divisors get the full bit width.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u32::BITS, u32::trailing_zeros);
    // `2^shift - divisor` is below 2^31, so the product fits in u64 and the resulting
    // multiplier always fits in 32 bits.
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
    (shift, multiplier as u32)
}
/// Computes the `(shift, multiplier)` pair for the 64-bit multiply-shift division path.
///
/// With `shift = ceil(log2(divisor))` and
/// `multiplier = floor(2^64 * (2^shift - divisor) / divisor) + 1`, the quotient `n / divisor`
/// equals `(mul_hi(n, multiplier) + n) >> shift` as long as that intermediate sum does not
/// overflow.
///
/// NOTE(review): divisors above 2^63 yield `shift == 64`; a full-width shift on the device may
/// not be well-defined on every target.
///
/// # Panics
///
/// Panics if `divisor` is zero.
fn find_params_u64(divisor: u64) -> (u32, u64) {
    assert!(divisor != 0, "FastDivmod divisor must be non-zero");
    let div_128 = u128::from(divisor);
    // ceil(log2(divisor)). `next_power_of_two` overflows for divisors above 2^63 (panicking in
    // debug builds, wrapping to 0 in release), so those divisors get the full bit width.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u64::BITS, u64::trailing_zeros);
    // `2^shift - divisor` is below 2^63, so the product fits in u128 and the resulting
    // multiplier always fits in 64 bits.
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - div_128)) / div_128 + 1;
    (shift, multiplier as u64)
}
/// Host-side launch plumbing for [`FastDivmod`]: runtime arguments, compilation arguments and
/// their expansion into kernel scalars.
mod launch {
    use super::*;

    /// Runtime (host-side) argument for a [`FastDivmod`] kernel parameter.
    ///
    /// Only the divisor is stored. For `Fast`, the multiplier and shift are derived from it in
    /// `register`.
    #[derive(Clone, Copy)]
    pub enum FastDivmodArgs<I: FastDivmodInt = usize> {
        Fast { divisor: ScalarArg<I> },
        Fallback { divisor: ScalarArg<I> },
    }

    /// Compile-time shape of a [`FastDivmod`] argument: which variant is used and therefore
    /// which scalars the kernel declares. Holds no scalar values.
    /// NOTE(review): presumably part of the kernel's compilation key (it derives `Hash`/`Eq`).
    #[derive(Clone, PartialEq, Eq, Hash, Debug)]
    pub enum FastDivmodCompilationArg<I: FastDivmodInt> {
        Fast {
            divisor: ScalarCompilationArg<I>,
            multiplier: ScalarCompilationArg<I>,
            shift_right: ScalarCompilationArg<u32>,
        },
        Fallback {
            divisor: ScalarCompilationArg<I>,
        },
    }

    impl<I: FastDivmodInt> CompilationArg for FastDivmodCompilationArg<I> {}

    impl<I: FastDivmodInt, R: Runtime> ArgSettings<R> for FastDivmodArgs<I> {
        /// Registers the scalar values with `launcher`.
        ///
        /// For `Fast`, computes `(shift, multiplier)` from the divisor using the 32- or 64-bit
        /// parameter routine, chosen by `I`'s size for this launch.
        ///
        /// # Panics
        ///
        /// Panics if `I`'s size is neither 4 nor 8 bytes, or if the divisor cannot be read as
        /// `u32`/`u64` for the selected width.
        fn register(&self, launcher: &mut KernelLauncher<R>) {
            match self {
                FastDivmodArgs::Fast { divisor } => {
                    let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
                        4 => {
                            let divisor = divisor.elem.to_u32().unwrap();
                            let (shift, multiplier) = find_params_u32(divisor);
                            let shift = ScalarArg::new(shift);
                            let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
                            (shift, multiplier)
                        }
                        8 => {
                            let divisor = divisor.elem.to_u64().unwrap();
                            let (shift, multiplier) = find_params_u64(divisor);
                            let shift = ScalarArg::new(shift);
                            // NOTE(review): multipliers >= 2^63 wrap to a negative `i64` here;
                            // this relies on `from_int` converting back with wrapping
                            // semantics — confirm.
                            let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
                            (shift, multiplier)
                        }
                        _ => panic!("unsupported type size for FastDivmod"),
                    };
                    // Registered in the same order `expand` declares them (divisor, multiplier,
                    // shift_right). NOTE(review): confirm whether the launcher relies on this
                    // order before changing either side.
                    divisor.register(launcher);
                    multiplier.register(launcher);
                    shift_right.register(launcher);
                }
                FastDivmodArgs::Fallback { divisor } => {
                    divisor.register(launcher);
                }
            }
        }
    }

    impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
        type RuntimeArg<'a, R: Runtime> = FastDivmodArgs<I>;
        type CompilationArg = FastDivmodCompilationArg<I>;

        /// Mirrors the runtime variant as a value-free compilation argument.
        fn compilation_arg<'a, R: Runtime>(
            runtime_arg: &Self::RuntimeArg<'a, R>,
        ) -> Self::CompilationArg {
            match runtime_arg {
                FastDivmodArgs::Fast { .. } => FastDivmodCompilationArg::Fast {
                    divisor: ScalarCompilationArg::new(),
                    multiplier: ScalarCompilationArg::new(),
                    shift_right: ScalarCompilationArg::new(),
                },
                FastDivmodArgs::Fallback { .. } => FastDivmodCompilationArg::Fallback {
                    divisor: ScalarCompilationArg::new(),
                },
            }
        }

        /// Declares the kernel-side scalars for the chosen variant and builds the matching
        /// expand value.
        fn expand(
            arg: &Self::CompilationArg,
            builder: &mut cubecl::prelude::KernelBuilder,
        ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
            match arg {
                FastDivmodCompilationArg::Fast {
                    divisor,
                    multiplier,
                    shift_right,
                } => FastDivmodExpand::Fast {
                    divisor: I::expand(divisor, builder),
                    multiplier: I::expand(multiplier, builder),
                    shift_right: u32::expand(shift_right, builder),
                },
                FastDivmodCompilationArg::Fallback { divisor } => FastDivmodExpand::Fallback {
                    divisor: I::expand(divisor, builder),
                },
            }
        }

        /// Output expansion is identical to input expansion for this read-only argument.
        fn expand_output(
            arg: &Self::CompilationArg,
            builder: &mut KernelBuilder,
        ) -> <Self as CubeType>::ExpandType {
            Self::expand(arg, builder)
        }
    }
}
pub use launch::FastDivmodArgs;