1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
/// Divider whose divisor is fixed ahead of time.
///
/// The `Fast` variant carries a precomputed multiplier and shift so that a
/// division can be evaluated with a high multiplication, an addition and a
/// right shift. The `Fallback` variant only stores the divisor and relies on
/// the regular division and remainder operators.
#[derive(CubeType, Clone, Copy)]
pub enum FastDivmod<I: FastDivmodInt> {
    /// Multiply-and-shift representation of the divisor.
    Fast {
        /// The divisor itself; used to derive the remainder from the quotient.
        divisor: I,
        /// Precomputed multiplier fed to `mul_hi` together with the dividend.
        multiplier: I,
        /// Number of bits the intermediate sum is shifted right by.
        shift_right: u32,
    },
    /// Plain representation that divides with the ordinary operators.
    Fallback {
        /// The divisor.
        divisor: I,
    },
}
22
/// Integer types that can be used as the element type of [`FastDivmod`].
pub trait FastDivmodInt: Int + MulHi + ScalarArgSettings {
    /// Returns the size of the integer type for the given launcher.
    ///
    /// The launch code in this file treats the value as a width in bytes and
    /// only handles `4` and `8`.
    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize;
}
26
27impl FastDivmodInt for u32 {
30 fn size<R: Runtime>(_launcher: &KernelLauncher<R>) -> usize {
31 size_of::<u32>()
32 }
33}
34
35impl FastDivmodInt for usize {
36 fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
37 launcher.settings.address_type.unsigned_type().size()
38 }
39}
40
41#[cube]
42impl<I: FastDivmodInt> FastDivmod<I> {
43 pub fn div(&self, dividend: I) -> I {
44 match self {
45 FastDivmod::Fast {
46 multiplier,
47 shift_right,
48 ..
49 } => {
50 let t = I::mul_hi(dividend, *multiplier);
51 (t + dividend) >> I::cast_from(*shift_right)
52 }
53 FastDivmod::Fallback { divisor } => dividend / *divisor,
54 }
55 }
56
57 pub fn modulo(&self, dividend: I) -> I {
58 let q = self.div(dividend);
59 match self {
60 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
61 FastDivmod::Fallback { divisor } => dividend % *divisor,
62 }
63 }
64
65 pub fn div_mod(&self, dividend: I) -> (I, I) {
66 let q = self.div(dividend);
67 let r = match self {
68 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
69 FastDivmod::Fallback { divisor } => dividend - q * *divisor,
70 };
71
72 (q, r)
73 }
74}
75
/// Computes the `(shift, multiplier)` pair used by the fast 32-bit division.
///
/// For a non-zero `divisor`:
/// - `shift` is the exponent of the smallest power of two that is greater than
///   or equal to `divisor` (i.e. `ceil(log2(divisor))`),
/// - `multiplier` is `floor(2^32 * (2^shift - divisor) / divisor) + 1`.
///
/// A `divisor` of `0` yields `(0, 0)`.
///
/// Divisors above `2^31` have no next power of two representable in `u32`;
/// for those the shift is `32`. This avoids the arithmetic-overflow panic that
/// `next_power_of_two` raises in debug builds and matches the value obtained
/// in release builds.
fn find_params_u32(divisor: u32) -> (u32, u32) {
    if divisor == 0 {
        return (0, 0);
    }

    let wide_divisor = u64::from(divisor);
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u32::BITS, u32::trailing_zeros);

    // `2^shift - divisor` is strictly smaller than `divisor`, so the quotient
    // below is at most `2^32 - 2` and the final value always fits in a `u32`.
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - wide_divisor)) / wide_divisor + 1;
    (shift, multiplier as u32)
}
90
/// Computes the `(shift, multiplier)` pair used by the fast 64-bit division.
///
/// For a non-zero `divisor`:
/// - `shift` is the exponent of the smallest power of two that is greater than
///   or equal to `divisor` (i.e. `ceil(log2(divisor))`),
/// - `multiplier` is `floor(2^64 * (2^shift - divisor) / divisor) + 1`.
///
/// A `divisor` of `0` yields `(0, 0)`.
///
/// Divisors above `2^63` have no next power of two representable in `u64`;
/// for those the shift is `64`. This avoids the arithmetic-overflow panic that
/// `next_power_of_two` raises in debug builds and matches the value obtained
/// in release builds.
fn find_params_u64(divisor: u64) -> (u32, u64) {
    if divisor == 0 {
        return (0, 0);
    }

    let wide_divisor = u128::from(divisor);
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u64::BITS, u64::trailing_zeros);

    // `2^shift - divisor` is strictly smaller than `divisor`, so the quotient
    // below is at most `2^64 - 2` and the final value always fits in a `u64`.
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - wide_divisor)) / wide_divisor + 1;
    (shift, multiplier as u64)
}
100
101mod launch {
102 use cubecl_core::ir::UIntKind;
103
104 use super::*;
105
106 #[derive_cube_comptime]
107 pub enum FastDivmodCompilationArg {
108 Fast,
109 Fallback,
110 }
111
112 impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
113 type RuntimeArg<R: Runtime> = I;
114 type CompilationArg = FastDivmodCompilationArg;
115
116 fn register<R: Runtime>(
117 divisor: Self::RuntimeArg<R>,
118 launcher: &mut KernelLauncher<R>,
119 ) -> Self::CompilationArg {
120 let props = launcher.with_scope(|scope| scope.properties.clone().unwrap());
121 let fast = props.features.supports_type(UIntKind::U64);
122 match fast {
123 true => {
124 let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
125 4 => {
126 let divisor = divisor.to_u32().unwrap();
127 let (shift, multiplier) = find_params_u32(divisor);
128
129 let multiplier = I::from_int(multiplier as i64);
130 (shift, multiplier)
131 }
132 8 => {
133 let divisor = divisor.to_u64().unwrap();
134 let (shift, multiplier) = find_params_u64(divisor);
135
136 let multiplier = I::from_int(multiplier as i64);
137 (shift, multiplier)
138 }
139 _ => panic!("unsupported type size for FastDivmod"),
140 };
141 <I as LaunchArg>::register(divisor, launcher);
142 <I as LaunchArg>::register(multiplier, launcher);
143 <u32 as LaunchArg>::register(shift_right, launcher);
144 FastDivmodCompilationArg::Fast
145 }
146 false => {
147 <I as LaunchArg>::register(divisor, launcher);
148 FastDivmodCompilationArg::Fallback
149 }
150 }
151 }
152
153 fn expand(
154 arg: &Self::CompilationArg,
155 builder: &mut cubecl::prelude::KernelBuilder,
156 ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
157 match arg {
158 FastDivmodCompilationArg::Fast => FastDivmodExpand::Fast {
159 divisor: I::expand(&(), builder),
160 multiplier: I::expand(&(), builder),
161 shift_right: u32::expand(&(), builder),
162 },
163 FastDivmodCompilationArg::Fallback => FastDivmodExpand::Fallback {
164 divisor: I::expand(&(), builder),
165 },
166 }
167 }
168 }
169}