1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_runtime::TypeUsage;
4
#[derive(CubeType, CubeLaunch, Clone, Copy)]
/// Divides `u32` values by a divisor that is fixed at kernel launch.
///
/// * `Fast` replaces the hardware division with a multiply-high, an add and a right shift,
///   using the precomputed `multiplier` and `shift_right`
///   (`q = (mul_hi(n, multiplier) + n) >> shift_right`).
/// * `Fallback` keeps only the divisor and uses the hardware `/` and `%` operators.
///
/// Host-side launch arguments are built with `FastDivmodArgs::new`, which chooses the variant.
pub enum FastDivmod {
    Fast {
        // The divisor `d`; used to recover the remainder as `n - q * d`.
        divisor: u32,
        // Magic multiplier; `FastDivmodArgs::new` sets it to
        // `floor(2^32 * (2^shift_right - d) / d) + 1`.
        multiplier: u32,
        // Final right shift; `FastDivmodArgs::new` sets it to `ceil(log2(d))`.
        shift_right: u32,
    },
    Fallback {
        // The divisor `d`, used directly with `/` and `%`.
        divisor: u32,
    },
}
23
24impl<R: Runtime> Clone for FastDivmodArgs<'_, R> {
25 fn clone(&self) -> Self {
26 *self
27 }
28}
29impl<R: Runtime> Copy for FastDivmodArgs<'_, R> {}
30
31impl<R: Runtime> FastDivmodArgs<'_, R> {
32 pub fn new(client: &ComputeClient<R::Server>, divisor: u32) -> Self {
33 debug_assert!(divisor != 0);
34
35 if !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
36 return FastDivmodArgs::Fallback {
37 divisor: ScalarArg::new(divisor),
38 };
39 }
40
41 let div_64 = divisor as u64;
42 let shift = find_log2(divisor);
43 let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
44
45 FastDivmodArgs::Fast {
46 divisor: ScalarArg::new(divisor),
47 multiplier: ScalarArg::new(multiplier as u32),
48 shift_right: ScalarArg::new(shift as u32),
49 }
50 }
51}
52
53#[cube]
54impl FastDivmod {
55 pub fn div(&self, dividend: u32) -> u32 {
56 match self {
57 FastDivmod::Fast {
58 multiplier,
59 shift_right,
60 ..
61 } => {
62 let t = u32::mul_hi(dividend, *multiplier);
63 (t + dividend) >> shift_right
64 }
65 FastDivmod::Fallback { divisor } => dividend / divisor,
66 }
67 }
68
69 pub fn modulo(&self, dividend: u32) -> u32 {
70 let q = self.div(dividend);
71 match self {
72 FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
73 FastDivmod::Fallback { divisor } => dividend % divisor,
74 }
75 }
76
77 pub fn div_mod(&self, dividend: u32) -> (u32, u32) {
78 let q = self.div(dividend);
79 let r = match self {
80 FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
81 FastDivmod::Fallback { divisor } => dividend - q * divisor,
82 };
83
84 (q, r)
85 }
86}
87
/// Returns `ceil(log2(x))`: the smallest `k` such that `2^k >= x`.
///
/// Yields 0 for `x` of 0 or 1, and 32 for any `x` above `2^31`.
fn find_log2(x: u32) -> usize {
    if x <= 1 {
        0
    } else {
        // For x >= 2, the bit length of (x - 1) equals ceil(log2(x)).
        (u32::BITS - (x - 1).leading_zeros()) as usize
    }
}