cubecl_std/
fast_math.rs

1use cubecl_core::ir::features::TypeUsage;
2use cubecl_core::prelude::*;
3use cubecl_core::{self as cubecl, ir::ConstantValue};
4
5/// Create a fast-divmod object if supported, or a regular fallback if not.
6/// This precalculates certain values on the host, in exchange for making division and modulo
7/// operations on the GPU much faster. Only supports u32 right now to allow for a simpler algorithm.
8/// It's mostly used for indices regardless.
9///
10/// Implementation based on ONNX:
11/// <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h>
12#[derive(CubeType, Clone, Copy)]
13pub enum FastDivmod<I: FastDivmodInt> {
14    Fast {
15        divisor: I,
16        multiplier: I,
17        shift_right: u32,
18    },
19    Fallback {
20        divisor: I,
21    },
22}
23
24pub trait FastDivmodInt: Int + MulHi + ScalarArgSettings {
25    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize;
26}
27
// Signed integer types could potentially be supported as well, but that needs more handling.
29
30impl FastDivmodInt for u32 {
31    fn size<R: Runtime>(_launcher: &KernelLauncher<R>) -> usize {
32        size_of::<u32>()
33    }
34}
35
36impl FastDivmodInt for usize {
37    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
38        launcher.settings.address_type.unsigned_type().size()
39    }
40}
41
42impl<I: FastDivmodInt> FastDivmodArgs<I> {
43    pub fn new<R: Runtime>(client: &ComputeClient<R>, divisor: I) -> Self {
44        debug_assert!({
45            let divisor_value: ConstantValue = divisor.into();
46            let divisor_value = divisor_value.as_u64();
47            divisor_value != 0
48        });
49
50        if !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
51            return FastDivmodArgs::Fallback {
52                divisor: ScalarArg::new(divisor),
53            };
54        }
55
56        FastDivmodArgs::Fast {
57            divisor: ScalarArg::new(divisor),
58        }
59    }
60}
61
62#[cube]
63impl<I: FastDivmodInt> FastDivmod<I> {
64    pub fn div(&self, dividend: I) -> I {
65        match self {
66            FastDivmod::Fast {
67                multiplier,
68                shift_right,
69                ..
70            } => {
71                let t = I::mul_hi(dividend, *multiplier);
72                (t + dividend) >> I::cast_from(*shift_right)
73            }
74            FastDivmod::Fallback { divisor } => dividend / *divisor,
75        }
76    }
77
78    pub fn modulo(&self, dividend: I) -> I {
79        let q = self.div(dividend);
80        match self {
81            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
82            FastDivmod::Fallback { divisor } => dividend % *divisor,
83        }
84    }
85
86    pub fn div_mod(&self, dividend: I) -> (I, I) {
87        let q = self.div(dividend);
88        let r = match self {
89            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
90            FastDivmod::Fallback { divisor } => dividend - q * *divisor,
91        };
92
93        (q, r)
94    }
95}
96
/// Computes the `(shift, multiplier)` pair of the fast path for a 32-bit divisor.
///
/// `shift` is `ceil(log2(divisor))` and `multiplier` is
/// `floor(2^32 * (2^shift - divisor) / divisor) + 1`. With these values,
/// `(mul_hi(n, multiplier) + n) >> shift` equals `n / divisor` for every 32-bit `n`, provided
/// the sum and shift are evaluated without wrapping.
///
/// # Panics
///
/// Panics if `divisor` is zero.
fn find_params_u32(divisor: u32) -> (u32, u32) {
    assert!(divisor != 0, "FastDivmod divisor must be non-zero");

    // `next_power_of_two` overflows for divisors above 2^31: it panics in debug builds and wraps
    // to 0 in release builds. The checked variant reports that case as `None`, for which the
    // ceiling of log2 is the full bit width, so every build profile yields the same result.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u32::BITS, u32::trailing_zeros);

    let div_64 = u64::from(divisor);
    // `2^shift - divisor` is smaller than `divisor` (and thus below 2^32), so the product fits
    // in 64 bits and the resulting multiplier fits in 32 bits; the cast below is lossless.
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
    (shift, multiplier as u32)
}
103
/// Computes the `(shift, multiplier)` pair of the fast path for a 64-bit divisor.
///
/// `shift` is `ceil(log2(divisor))` and `multiplier` is
/// `floor(2^64 * (2^shift - divisor) / divisor) + 1`. With these values,
/// `(mul_hi(n, multiplier) + n) >> shift` equals `n / divisor` for every 64-bit `n`, provided
/// the sum and shift are evaluated without wrapping.
///
/// # Panics
///
/// Panics if `divisor` is zero.
fn find_params_u64(divisor: u64) -> (u32, u64) {
    assert!(divisor != 0, "FastDivmod divisor must be non-zero");

    // `next_power_of_two` overflows for divisors above 2^63: it panics in debug builds and wraps
    // to 0 in release builds. The checked variant reports that case as `None`, for which the
    // ceiling of log2 is the full bit width, so every build profile yields the same result.
    let shift = divisor
        .checked_next_power_of_two()
        .map_or(u64::BITS, u64::trailing_zeros);

    let div_128 = u128::from(divisor);
    // `2^shift - divisor` is smaller than `divisor` (and thus below 2^64), so the product fits
    // in 128 bits and the resulting multiplier fits in 64 bits; the cast below is lossless.
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - div_128)) / div_128 + 1;
    (shift, multiplier as u64)
}
110
111mod launch {
112    use super::*;
113
114    #[derive(Clone, Copy)]
115    pub enum FastDivmodArgs<I: FastDivmodInt = usize> {
116        Fast { divisor: ScalarArg<I> },
117        Fallback { divisor: ScalarArg<I> },
118    }
119
120    #[derive(Clone, PartialEq, Eq, Hash, Debug)]
121    pub enum FastDivmodCompilationArg<I: FastDivmodInt> {
122        Fast {
123            divisor: ScalarCompilationArg<I>,
124            multiplier: ScalarCompilationArg<I>,
125            shift_right: ScalarCompilationArg<u32>,
126        },
127        Fallback {
128            divisor: ScalarCompilationArg<I>,
129        },
130    }
131
132    impl<I: FastDivmodInt> CompilationArg for FastDivmodCompilationArg<I> {}
133
134    impl<I: FastDivmodInt, R: Runtime> ArgSettings<R> for FastDivmodArgs<I> {
135        fn register(&self, launcher: &mut KernelLauncher<R>) {
136            match self {
137                FastDivmodArgs::Fast { divisor } => {
138                    let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
139                        4 => {
140                            let divisor = divisor.elem.to_u32().unwrap();
141                            let (shift, multiplier) = find_params_u32(divisor);
142
143                            let shift = ScalarArg::new(shift);
144                            let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
145                            (shift, multiplier)
146                        }
147                        8 => {
148                            let divisor = divisor.elem.to_u64().unwrap();
149                            let (shift, multiplier) = find_params_u64(divisor);
150
151                            let shift = ScalarArg::new(shift);
152                            let multiplier = ScalarArg::new(I::from_int(multiplier as i64));
153                            (shift, multiplier)
154                        }
155                        _ => panic!("unsupported type size for FastDivmod"),
156                    };
157                    divisor.register(launcher);
158                    multiplier.register(launcher);
159                    shift_right.register(launcher);
160                }
161                FastDivmodArgs::Fallback { divisor } => {
162                    divisor.register(launcher);
163                }
164            }
165        }
166    }
167
168    impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
169        type RuntimeArg<'a, R: Runtime> = FastDivmodArgs<I>;
170        type CompilationArg = FastDivmodCompilationArg<I>;
171
172        fn compilation_arg<'a, R: Runtime>(
173            runtime_arg: &Self::RuntimeArg<'a, R>,
174        ) -> Self::CompilationArg {
175            match runtime_arg {
176                FastDivmodArgs::Fast { .. } => FastDivmodCompilationArg::Fast {
177                    divisor: ScalarCompilationArg::new(),
178                    multiplier: ScalarCompilationArg::new(),
179                    shift_right: ScalarCompilationArg::new(),
180                },
181                FastDivmodArgs::Fallback { .. } => FastDivmodCompilationArg::Fallback {
182                    divisor: ScalarCompilationArg::new(),
183                },
184            }
185        }
186
187        fn expand(
188            arg: &Self::CompilationArg,
189            builder: &mut cubecl::prelude::KernelBuilder,
190        ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
191            match arg {
192                FastDivmodCompilationArg::Fast {
193                    divisor,
194                    multiplier,
195                    shift_right,
196                } => FastDivmodExpand::Fast {
197                    divisor: I::expand(divisor, builder),
198                    multiplier: I::expand(multiplier, builder),
199                    shift_right: u32::expand(shift_right, builder),
200                },
201                FastDivmodCompilationArg::Fallback { divisor } => FastDivmodExpand::Fallback {
202                    divisor: I::expand(divisor, builder),
203                },
204            }
205        }
206
207        fn expand_output(
208            arg: &Self::CompilationArg,
209            builder: &mut KernelBuilder,
210        ) -> <Self as CubeType>::ExpandType {
211            Self::expand(arg, builder)
212        }
213    }
214}
215pub use launch::FastDivmodArgs;