Skip to main content

cubecl_std/
fast_math.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4/// Create a fast-divmod object if supported, or a regular fallback if not.
5/// This precalculates certain values on the host, in exchange for making division and modulo
6/// operations on the GPU much faster. Only supports u32 right now to allow for a simpler algorithm.
7/// It's mostly used for indices regardless.
8///
9/// Implementation based on ONNX:
10/// <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h>
11#[derive(CubeType, Clone, Copy)]
12pub enum FastDivmod<I: FastDivmodInt> {
13    Fast {
14        divisor: I,
15        multiplier: I,
16        shift_right: u32,
17    },
18    Fallback {
19        divisor: I,
20    },
21}
22
23pub trait FastDivmodInt: Int + MulHi + ScalarArgSettings {
24    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize;
25}
26
// Only unsigned integers are supported. Signed integers could be added, but would need extra handling.
28
29impl FastDivmodInt for u32 {
30    fn size<R: Runtime>(_launcher: &KernelLauncher<R>) -> usize {
31        size_of::<u32>()
32    }
33}
34
35impl FastDivmodInt for usize {
36    fn size<R: Runtime>(launcher: &KernelLauncher<R>) -> usize {
37        launcher.settings.address_type.unsigned_type().size()
38    }
39}
40
41#[cube]
42impl<I: FastDivmodInt> FastDivmod<I> {
43    pub fn div(&self, dividend: I) -> I {
44        match self {
45            FastDivmod::Fast {
46                multiplier,
47                shift_right,
48                ..
49            } => {
50                let t = I::mul_hi(dividend, *multiplier);
51                (t + dividend) >> I::cast_from(*shift_right)
52            }
53            FastDivmod::Fallback { divisor } => dividend / *divisor,
54        }
55    }
56
57    pub fn modulo(&self, dividend: I) -> I {
58        let q = self.div(dividend);
59        match self {
60            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
61            FastDivmod::Fallback { divisor } => dividend % *divisor,
62        }
63    }
64
65    pub fn div_mod(&self, dividend: I) -> (I, I) {
66        let q = self.div(dividend);
67        let r = match self {
68            FastDivmod::Fast { divisor, .. } => dividend - q * *divisor,
69            FastDivmod::Fallback { divisor } => dividend - q * *divisor,
70        };
71
72        (q, r)
73    }
74}
75
/// Computes the `(shift, multiplier)` pair for dividing 32-bit values by `divisor`.
///
/// The pair is computed as follows:
/// - `shift = ceil(log2(divisor))`;
/// - `multiplier = floor(2^32 * (2^shift - divisor) / divisor) + 1`.
///
/// Then `n / divisor == (mul_hi(n, multiplier) + n) >> shift` whenever that sum does not
/// overflow 32 bits, which holds for every `n < 2^31`. For divisors above `2^31` the shift is
/// 32, the full bit width.
fn find_params_u32(divisor: u32) -> (u32, u32) {
    // A zero divisor arises only when a tensor has a zero-sized dimension
    // (e.g. Brush's `sh_coeffs_rest` of shape [N, 0, 3] at SH degree 0).
    // Such tensors cause 0 workgroups to be dispatched, so these params are
    // never read by any kernel thread — return a dummy pair instead of
    // panicking during kernel launch preparation.
    if divisor == 0 {
        return (0, 0);
    }
    // ceil(log2(divisor)). Unlike `divisor.next_power_of_two().trailing_zeros()`, this cannot
    // overflow (a panic in debug builds) for divisors above 2^31. It yields 32 there, the same
    // value the wrapping release-mode computation produced.
    let shift = u32::BITS - (divisor - 1).leading_zeros();
    let divisor = u64::from(divisor);
    // `2^shift - divisor < divisor`, so the product stays below 2^63 and the multiplier always
    // fits in 32 bits.
    let multiplier = ((1u64 << 32) * ((1u64 << shift) - divisor)) / divisor + 1;
    (shift, multiplier as u32)
}
90
/// Computes the `(shift, multiplier)` pair for dividing 64-bit values by `divisor`.
///
/// The pair is computed as follows:
/// - `shift = ceil(log2(divisor))`;
/// - `multiplier = floor(2^64 * (2^shift - divisor) / divisor) + 1`.
///
/// Then `n / divisor == (mul_hi(n, multiplier) + n) >> shift` whenever that sum does not
/// overflow 64 bits, which holds for every `n < 2^63`. For divisors above `2^63` the shift is
/// 64, the full bit width.
fn find_params_u64(divisor: u64) -> (u32, u64) {
    // A zero divisor only occurs for zero-sized dimensions, where no kernel thread reads these
    // params; return a dummy pair instead of panicking.
    if divisor == 0 {
        return (0, 0);
    }
    // ceil(log2(divisor)) without the debug-mode overflow panic that
    // `divisor.next_power_of_two()` hits for divisors above 2^63.
    let shift = u64::BITS - (divisor - 1).leading_zeros();
    let divisor = u128::from(divisor);
    // `2^shift - divisor < divisor`, so the product stays below 2^127 and the multiplier always
    // fits in 64 bits.
    let multiplier = ((1u128 << 64) * ((1u128 << shift) - divisor)) / divisor + 1;
    (shift, multiplier as u64)
}
100
101mod launch {
102    use cubecl_core::ir::UIntKind;
103
104    use super::*;
105
106    #[derive_cube_comptime]
107    pub enum FastDivmodCompilationArg {
108        Fast,
109        Fallback,
110    }
111
112    impl<I: FastDivmodInt> LaunchArg for FastDivmod<I> {
113        type RuntimeArg<R: Runtime> = I;
114        type CompilationArg = FastDivmodCompilationArg;
115
116        fn register<R: Runtime>(
117            divisor: Self::RuntimeArg<R>,
118            launcher: &mut KernelLauncher<R>,
119        ) -> Self::CompilationArg {
120            let props = launcher.with_scope(|scope| scope.properties.clone().unwrap());
121            let fast = props.features.supports_type(UIntKind::U64);
122            match fast {
123                true => {
124                    let (shift_right, multiplier) = match <I as FastDivmodInt>::size(launcher) {
125                        4 => {
126                            let divisor = divisor.to_u32().unwrap();
127                            let (shift, multiplier) = find_params_u32(divisor);
128
129                            let multiplier = I::from_int(multiplier as i64);
130                            (shift, multiplier)
131                        }
132                        8 => {
133                            let divisor = divisor.to_u64().unwrap();
134                            let (shift, multiplier) = find_params_u64(divisor);
135
136                            let multiplier = I::from_int(multiplier as i64);
137                            (shift, multiplier)
138                        }
139                        _ => panic!("unsupported type size for FastDivmod"),
140                    };
141                    <I as LaunchArg>::register(divisor, launcher);
142                    <I as LaunchArg>::register(multiplier, launcher);
143                    <u32 as LaunchArg>::register(shift_right, launcher);
144                    FastDivmodCompilationArg::Fast
145                }
146                false => {
147                    <I as LaunchArg>::register(divisor, launcher);
148                    FastDivmodCompilationArg::Fallback
149                }
150            }
151        }
152
153        fn expand(
154            arg: &Self::CompilationArg,
155            builder: &mut cubecl::prelude::KernelBuilder,
156        ) -> <Self as cubecl::prelude::CubeType>::ExpandType {
157            match arg {
158                FastDivmodCompilationArg::Fast => FastDivmodExpand::Fast {
159                    divisor: I::expand(&(), builder),
160                    multiplier: I::expand(&(), builder),
161                    shift_right: u32::expand(&(), builder),
162                },
163                FastDivmodCompilationArg::Fallback => FastDivmodExpand::Fallback {
164                    divisor: I::expand(&(), builder),
165                },
166            }
167        }
168    }
169}