1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
/// Divisor for repeated unsigned 32-bit division and modulo inside kernels.
///
/// The variant is selected on the host by `FastDivmodArgs::new`:
/// - `Fast`: a divisor that is not a power of two, stored together with a precomputed
///   `multiplier` and `shift_right`. The quotient is then derived from
///   `u32::mul_hi(dividend, multiplier)` plus cheap add/shift operations instead of an
///   integer division.
/// - `PowerOfTwo`: the quotient is `dividend >> shift` and the remainder is
///   `dividend & mask` (with `mask == divisor - 1`).
/// - `Fallback`: the divisor is used directly with the `/` and `%` operators.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub enum FastDivmod {
    Fast {
        divisor: u32,
        multiplier: u32,
        shift_right: u32,
    },
    PowerOfTwo {
        shift: u32,
        mask: u32,
    },
    Fallback {
        divisor: u32,
    },
}
26
27impl<R: Runtime> FastDivmodArgs<'_, R> {
28 pub fn new(client: &ComputeClient<R::Server, R::Channel>, divisor: u32) -> Self {
29 debug_assert!(divisor != 0);
30
31 if divisor.is_power_of_two() {
32 return FastDivmodArgs::PowerOfTwo {
33 shift: ScalarArg::new(divisor.trailing_zeros()),
34 mask: ScalarArg::new(divisor - 1),
35 };
36 }
37
38 if !u64::is_supported(client) {
39 return FastDivmodArgs::Fallback {
40 divisor: ScalarArg::new(divisor),
41 };
42 }
43
44 let div_64 = divisor as u64;
45 let shift = find_log2(divisor);
46 let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
47
48 FastDivmodArgs::Fast {
49 divisor: ScalarArg::new(divisor),
50 multiplier: ScalarArg::new(multiplier as u32),
51 shift_right: ScalarArg::new(shift as u32),
52 }
53 }
54}
55
56#[cube]
57impl FastDivmod {
58 pub fn div(&self, dividend: u32) -> u32 {
59 match self {
60 FastDivmod::Fast {
61 multiplier,
62 shift_right,
63 ..
64 } => {
65 let t = u32::mul_hi(dividend, *multiplier);
66 (t + dividend) >> shift_right
67 }
68 FastDivmod::PowerOfTwo { shift, .. } => dividend >> *shift,
69 FastDivmod::Fallback { divisor } => dividend / divisor,
70 }
71 }
72
73 pub fn modulo(&self, dividend: u32) -> u32 {
74 let q = self.div(dividend);
75 match self {
76 FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
77 FastDivmod::PowerOfTwo { mask, .. } => dividend & mask,
78 FastDivmod::Fallback { divisor } => dividend % divisor,
79 }
80 }
81
82 pub fn div_mod(&self, dividend: u32) -> (u32, u32) {
83 let q = self.div(dividend);
84 let r = match self {
85 FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
86 FastDivmod::Fallback { divisor } => dividend - q * divisor,
87 FastDivmod::PowerOfTwo { mask, .. } => dividend & *mask,
88 };
89
90 (q, r)
91 }
92}
93
/// Returns `ceil(log2(x))`: the smallest `i` such that `2^i >= x`.
///
/// Edge cases: `find_log2(0) == 0`, `find_log2(1) == 0`, and every `x > 2^31` yields `32`.
fn find_log2(x: u32) -> usize {
    // For x >= 1, the bit length of `x - 1` is exactly ceil(log2(x)); `saturating_sub`
    // maps x = 0 onto the same result as x = 1.
    (u32::BITS - x.saturating_sub(1).leading_zeros()) as usize
}