1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
/// A divisor prepared for repeated division and modulo by the same value.
///
/// The variant is selected when the value is registered as a launch argument:
/// `Fast` is used when the device reports support for `UIntKind::U64`,
/// `Fallback` otherwise. In the `Fast` case the `multiplier` and `shift_right`
/// fields are computed on the host from the divisor, so that `div` can use a
/// high multiply, an add and a right shift instead of the `/` operator.
#[derive(CubeType, Clone, Copy)]
pub enum FastDivmod<I: FastDivmodInt> {
    // Multiply-and-shift path.
    Fast {
        // The original divisor; still needed to derive the remainder from the quotient.
        divisor: I,
        // Precomputed magic multiplier fed to `mul_hi` together with the dividend.
        multiplier: I,
        // Right-shift amount applied after the multiply-high and add.
        shift_right: u32,
    },
    // Plain path: relies on the regular `/` and `%` operators.
    Fallback {
        divisor: I,
    },
}
22
/// Integer types that can be used as the operand type of [`FastDivmod`].
///
/// Besides the `Int` and `MulHi` operations needed by the division itself, the
/// type must be usable as a scalar launch argument (`ScalarArgSettings`).
pub trait FastDivmodInt: Int + MulHi + ScalarArgSettings {
    /// Returns the size of this integer type as seen by `launcher`.
    ///
    /// The launch registration of `FastDivmod` matches the result against `4`
    /// and `8` to choose between the 32-bit and 64-bit parameter computation,
    /// so the value is presumably a size in bytes.
    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize;
}
26
27impl FastDivmodInt for u32 {
30 fn size<R: Runtime>(_launcher: &KernelLauncher<R>) -> usize {
31 size_of::<u32>()
32 }
33}
34
35impl FastDivmodInt for usize {
36 fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
37 launcher.settings.address_type.unsigned_type().size()
38 }
39}
40
41#[cube]
42impl<I: FastDivmodInt> FastDivmod<I> {
43 pub fn div(&self, dividend: I) -> I {
44 match self {
45 FastDivmod::Fast {
46 multiplier,
47 shift_right,
48 ..
49 } => {
50 let t = I::mul_hi(dividend, *multiplier);
51 (t + dividend) >> I::cast_from(*shift_right)
52 }
53 FastDivmod::Fallback { divisor } => dividend / *divisor,
54 }
55 }
56
57 pub fn modulo(&self, dividend: I) -> I {
58 let q = self.div(dividend);
59 match self {
60 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
61 FastDivmod::Fallback { divisor } => dividend % *divisor,
62 }
63 }
64
65 pub fn div_mod(&self, dividend: I) -> (I, I) {
66 let q = self.div(dividend);
67 let r = match self {
68 FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
69 FastDivmod::Fallback { divisor } => dividend - q * *divisor,
70 };
71
72 (q, r)
73 }
74}
75
/// Computes the host-side constants used by the fast path of `FastDivmod::div`
/// for 32-bit operands.
///
/// Returns `(shift, multiplier)` where `shift` is `ceil(log2(divisor))` and
/// `multiplier` is `floor(2^32 * (2^shift - divisor) / divisor) + 1`. With these
/// values, `(mul_hi(n, multiplier) + n) >> shift` equals `n / divisor` as long
/// as the addition does not overflow.
///
/// # Panics
///
/// Panics if `divisor` is zero or greater than `2^31`. Zero has no reciprocal,
/// and a larger divisor would need a shift of 32, which a 32-bit shift cannot
/// express. Previously these cases surfaced as an opaque divide-by-zero panic,
/// or as an overflow panic in `next_power_of_two` only when overflow checks
/// were enabled (and an unusable shift of 32 otherwise).
fn find_params_u32(divisor: u32) -> (u32, u32) {
    const MAX_DIVISOR: u32 = 1 << (u32::BITS - 1);
    assert!(
        divisor != 0 && divisor <= MAX_DIVISOR,
        "FastDivmod divisor must be in 1..=2^31, got {}",
        divisor
    );

    let div_64 = u64::from(divisor);
    // Cannot overflow: the assertion above bounds the divisor by 2^31.
    let shift = divisor.next_power_of_two().trailing_zeros();
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
    // Lossless: `2^shift < 2 * divisor`, so the quotient is below 2^32 - 1 and
    // the incremented value still fits in 32 bits.
    (shift, multiplier as u32)
}
82
/// Computes the host-side constants used by the fast path of `FastDivmod::div`
/// for 64-bit operands.
///
/// Returns `(shift, multiplier)` where `shift` is `ceil(log2(divisor))` and
/// `multiplier` is `floor(2^64 * (2^shift - divisor) / divisor) + 1`. With these
/// values, `(mul_hi(n, multiplier) + n) >> shift` equals `n / divisor` as long
/// as the addition does not overflow.
///
/// # Panics
///
/// Panics if `divisor` is zero or greater than `2^63`. Zero has no reciprocal,
/// and a larger divisor would need a shift of 64, which a 64-bit shift cannot
/// express. Previously these cases surfaced as an opaque divide-by-zero panic,
/// or as an overflow panic in `next_power_of_two` only when overflow checks
/// were enabled (and an unusable shift of 64 otherwise).
fn find_params_u64(divisor: u64) -> (u32, u64) {
    const MAX_DIVISOR: u64 = 1 << (u64::BITS - 1);
    assert!(
        divisor != 0 && divisor <= MAX_DIVISOR,
        "FastDivmod divisor must be in 1..=2^63, got {}",
        divisor
    );

    let div_128 = u128::from(divisor);
    // Cannot overflow: the assertion above bounds the divisor by 2^63.
    let shift = divisor.next_power_of_two().trailing_zeros();
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - div_128)) / div_128 + 1;
    // Lossless: `2^shift < 2 * divisor`, so the quotient is below 2^64 - 1 and
    // the incremented value still fits in 64 bits.
    (shift, multiplier as u64)
}
89
90mod launch {
91 use cubecl_core::ir::UIntKind;
92
93 use super::*;
94
95 #[derive_cube_comptime]
96 pub enum FastDivmodCompilationArg {
97 Fast,
98 Fallback,
99 }
100
101 impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
102 type RuntimeArg<R: Runtime> = I;
103 type CompilationArg = FastDivmodCompilationArg;
104
105 fn register<R: Runtime>(
106 divisor: Self::RuntimeArg<R>,
107 launcher: &mut KernelLauncher<R>,
108 ) -> Self::CompilationArg {
109 let props = launcher.with_scope(|scope| scope.properties.clone().unwrap());
110 let fast = props.features.supports_type(UIntKind::U64);
111 match fast {
112 true => {
113 let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
114 4 => {
115 let divisor = divisor.to_u32().unwrap();
116 let (shift, multiplier) = find_params_u32(divisor);
117
118 let multiplier = I::from_int(multiplier as i64);
119 (shift, multiplier)
120 }
121 8 => {
122 let divisor = divisor.to_u64().unwrap();
123 let (shift, multiplier) = find_params_u64(divisor);
124
125 let multiplier = I::from_int(multiplier as i64);
126 (shift, multiplier)
127 }
128 _ => panic!("unsupported type size for FastDivmod"),
129 };
130 <I as LaunchArg>::register(divisor, launcher);
131 <I as LaunchArg>::register(multiplier, launcher);
132 <u32 as LaunchArg>::register(shift_right, launcher);
133 FastDivmodCompilationArg::Fast
134 }
135 false => {
136 <I as LaunchArg>::register(divisor, launcher);
137 FastDivmodCompilationArg::Fallback
138 }
139 }
140 }
141
142 fn expand(
143 arg: &Self::CompilationArg,
144 builder: &mut cubecl::prelude::KernelBuilder,
145 ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
146 match arg {
147 FastDivmodCompilationArg::Fast => FastDivmodExpand::Fast {
148 divisor: I::expand(&(), builder),
149 multiplier: I::expand(&(), builder),
150 shift_right: u32::expand(&(), builder),
151 },
152 FastDivmodCompilationArg::Fallback => FastDivmodExpand::Fallback {
153 divisor: I::expand(&(), builder),
154 },
155 }
156 }
157 }
158}