1use cubecl_core::ir::features::TypeUsage;
2use cubecl_core::prelude::*;
3use cubecl_core::{self as cubecl, ir::ConstantValue};
4
/// Divisor for repeated integer division / modulo inside kernels.
///
/// `Fast` carries precomputed parameters so the quotient is obtained as
/// `(mul_hi(n, multiplier) + n) >> shift_right` instead of a hardware division;
/// `Fallback` keeps only the divisor and uses the native `/` and `%` operators.
#[derive(CubeType, Clone, Copy)]
pub enum FastDivmod<I: FastDivmodInt> {
    /// Multiply/shift division; `multiplier` and `shift_right` are computed on the host
    /// from `divisor` when the launch argument is registered.
    Fast {
        /// The divisor itself, used to derive the remainder from the quotient.
        divisor: I,
        /// Magic multiplier whose high product with the dividend feeds the quotient.
        multiplier: I,
        /// Right shift applied to `mul_hi(n, multiplier) + n`.
        shift_right: u32,
    },
    /// Plain division with the native operators; chosen by `FastDivmodArgs::new` when the
    /// client does not report `u64` arithmetic support.
    Fallback {
        divisor: I,
    },
}
23
/// Integer types usable as the element type of [`FastDivmod`].
pub trait FastDivmodInt: Int + MulHi + ScalarArgSettings {
    /// Width in bytes this type has for the kernel being launched.
    ///
    /// Registration uses it to pick the 32-bit (4) or 64-bit (8) parameter computation;
    /// any other value is rejected. It may depend on launch settings (as it does for `usize`).
    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize;
}
27
28impl FastDivmodInt for u32 {
31 fn size<R: Runtime>(_launcher: &KernelLauncher<R>) -> usize {
32 size_of::<u32>()
33 }
34}
35
36impl FastDivmodInt for usize {
37 fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
38 launcher.settings.address_type.unsigned_type().size()
39 }
40}
41
42impl<I: FastDivmodInt> FastDivmodArgs<I> {
43 pub fn new<R: Runtime>(client: &ComputeClient<R>, divisor: I) -> Self {
44 debug_assert!({
45 let divisor_value: ConstantValue = divisor.into();
46 let divisor_value = divisor_value.as_u64();
47 divisor_value != 0
48 });
49
50 if !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
51 return FastDivmodArgs::Fallback {
52 divisor: ScalarArg::new(divisor),
53 };
54 }
55
56 FastDivmodArgs::Fast {
57 divisor: ScalarArg::new(divisor),
58 }
59 }
60}
61
62#[cube]
63impl<I: FastDivmodInt> FastDivmod<I> {
64 pub fn div(&self, dividend: I) -> I {
65 match self {
66 FastDivmod::Fast {
67 multiplier,
68 shift_right,
69 ..
70 } => {
71 let t = I::mul_hi(dividend, *multiplier);
72 (t + dividend) >> I::cast_from(*shift_right)
73 }
74 FastDivmod::Fallback { divisor } => dividend / *divisor,
75 }
76 }
77
78 pub fn modulo(&self, dividend: I) -> I {
79 let q = self.div(dividend);
80 match self {
81 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
82 FastDivmod::Fallback { divisor } => dividend % *divisor,
83 }
84 }
85
86 pub fn div_mod(&self, dividend: I) -> (I, I) {
87 let q = self.div(dividend);
88 let r = match self {
89 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
90 FastDivmod::Fallback { divisor } => dividend - q * *divisor,
91 };
92
93 (q, r)
94 }
95}
96
/// Computes `(shift, multiplier)` for 32-bit division by an invariant `divisor`.
///
/// With `shift = ceil(log2(divisor))` and
/// `multiplier = floor(2^32 * (2^shift - divisor) / divisor) + 1`, every `n < 2^32` satisfies
/// `n / divisor == (mul_hi(n, multiplier) + n) >> shift` when the sum is evaluated without
/// overflow. For divisors above `2^31`, `shift` is 32.
///
/// # Panics
/// Panics if `divisor` is zero.
fn find_params_u32(divisor: u32) -> (u32, u32) {
    assert_ne!(divisor, 0, "FastDivmod divisor must be non-zero");

    // ceil(log2(divisor)). `next_power_of_two` would overflow for divisors above 2^31
    // (debug panic / release wrap to 0); the checked form yields the correct 32 in all builds.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u32::BITS, u32::trailing_zeros);

    let div_64 = u64::from(divisor);
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
    // Fits in 32 bits: 2^shift < 2 * divisor, so the quotient above is below 2^32 - 1.
    (shift, multiplier as u32)
}
103
/// Computes `(shift, multiplier)` for 64-bit division by an invariant `divisor`.
///
/// With `shift = ceil(log2(divisor))` and
/// `multiplier = floor(2^64 * (2^shift - divisor) / divisor) + 1`, every `n < 2^64` satisfies
/// `n / divisor == (mul_hi(n, multiplier) + n) >> shift` when the sum is evaluated without
/// overflow. For divisors above `2^63`, `shift` is 64.
///
/// # Panics
/// Panics if `divisor` is zero.
fn find_params_u64(divisor: u64) -> (u32, u64) {
    assert_ne!(divisor, 0, "FastDivmod divisor must be non-zero");

    // ceil(log2(divisor)). `next_power_of_two` would overflow for divisors above 2^63
    // (debug panic / release wrap to 0); the checked form yields the correct 64 in all builds.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u64::BITS, u64::trailing_zeros);

    let div_128 = u128::from(divisor);
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - div_128)) / div_128 + 1;
    // Fits in 64 bits: 2^shift < 2 * divisor, so the quotient above is below 2^64 - 1.
    (shift, multiplier as u64)
}
110
111mod launch {
112 use super::*;
113
114 #[derive(Clone, Copy)]
115 pub enum FastDivmodArgs<I: FastDivmodInt = usize> {
116 Fast { divisor: ScalarArg<I> },
117 Fallback { divisor: ScalarArg<I> },
118 }
119
120 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
121 pub enum FastDivmodCompilationArg<I: FastDivmodInt> {
122 Fast {
123 divisor: ScalarCompilationArg<I>,
124 multiplier: ScalarCompilationArg<I>,
125 shift_right: ScalarCompilationArg<u32>,
126 },
127 Fallback {
128 divisor: ScalarCompilationArg<I>,
129 },
130 }
131
132 impl<I: FastDivmodInt> CompilationArg for FastDivmodCompilationArg<I> {}
133
134 impl<I: FastDivmodInt, R: Runtime> ArgSettings<R> for FastDivmodArgs<I> {
135 fn register(&self, launcher: &mut KernelLauncher<R>) {
136 match self {
137 FastDivmodArgs::Fast { divisor } => {
138 let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
139 4 => {
140 let divisor = divisor.elem.to_u32().unwrap();
141 let (shift, multiplier) = find_params_u32(divisor);
142
143 let shift = ScalarArg::new(shift);
144 let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
145 (shift, multiplier)
146 }
147 8 => {
148 let divisor = divisor.elem.to_u64().unwrap();
149 let (shift, multiplier) = find_params_u64(divisor);
150
151 let shift = ScalarArg::new(shift);
152 let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
153 (shift, multiplier)
154 }
155 _ => panic!("unsupported type size for FastDivmod"),
156 };
157 divisor.register(launcher);
158 multiplier.register(launcher);
159 shift_right.register(launcher);
160 }
161 FastDivmodArgs::Fallback { divisor } => {
162 divisor.register(launcher);
163 }
164 }
165 }
166 }
167
168 impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
169 type RuntimeArg<'a, R: Runtime> = FastDivmodArgs<I>;
170 type CompilationArg = FastDivmodCompilationArg<I>;
171
172 fn compilation_arg<'a, R: Runtime>(
173 runtime_arg: &Self::RuntimeArg<'a, R>,
174 ) -> Self::CompilationArg {
175 match runtime_arg {
176 FastDivmodArgs::Fast { .. } => FastDivmodCompilationArg::Fast {
177 divisor: ScalarCompilationArg::new(),
178 multiplier: ScalarCompilationArg::new(),
179 shift_right: ScalarCompilationArg::new(),
180 },
181 FastDivmodArgs::Fallback { .. } => FastDivmodCompilationArg::Fallback {
182 divisor: ScalarCompilationArg::new(),
183 },
184 }
185 }
186
187 fn expand(
188 arg: &Self::CompilationArg,
189 builder: &mut cubecl::prelude::KernelBuilder,
190 ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
191 match arg {
192 FastDivmodCompilationArg::Fast {
193 divisor,
194 multiplier,
195 shift_right,
196 } => FastDivmodExpand::Fast {
197 divisor: I::expand(divisor, builder),
198 multiplier: I::expand(multiplier, builder),
199 shift_right: u32::expand(shift_right, builder),
200 },
201 FastDivmodCompilationArg::Fallback { divisor } => FastDivmodExpand::Fallback {
202 divisor: I::expand(divisor, builder),
203 },
204 }
205 }
206
207 fn expand_output(
208 arg: &Self::CompilationArg,
209 builder: &mut KernelBuilder,
210 ) -> <Self as CubeType>::ExpandType {
211 Self::expand(arg, builder)
212 }
213 }
214}
215pub use launch::FastDivmodArgs;