// cubecl_std/fast_math.rs

use cubecl_core as cubecl;
use cubecl_core::prelude::*;

4/// Create a fast-divmod object if supported, or a regular fallback if not.
5/// This precalculates certain values on the host, in exchange for making division and modulo
6/// operations on the GPU much faster. Only supports u32 right now to allow for a simpler algorithm.
7/// It's mostly used for indices regardless.
8///
9/// Implementation based on ONNX:
10/// <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h>
11#[derive(CubeType, CubeLaunch, Clone, Copy)]
12pub enum FastDivmod {
13    Fast {
14        divisor: u32,
15        multiplier: u32,
16        shift_right: u32,
17    },
18    PowerOfTwo {
19        shift: u32,
20        mask: u32,
21    },
22    Fallback {
23        divisor: u32,
24    },
25}
26
27impl<R: Runtime> FastDivmodArgs<'_, R> {
28    pub fn new(client: &ComputeClient<R::Server, R::Channel>, divisor: u32) -> Self {
29        debug_assert!(divisor != 0);
30
31        if divisor.is_power_of_two() {
32            return FastDivmodArgs::PowerOfTwo {
33                shift: ScalarArg::new(divisor.trailing_zeros()),
34                mask: ScalarArg::new(divisor - 1),
35            };
36        }
37
38        if !u64::is_supported(client) {
39            return FastDivmodArgs::Fallback {
40                divisor: ScalarArg::new(divisor),
41            };
42        }
43
44        let div_64 = divisor as u64;
45        let shift = find_log2(divisor);
46        let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
47
48        FastDivmodArgs::Fast {
49            divisor: ScalarArg::new(divisor),
50            multiplier: ScalarArg::new(multiplier as u32),
51            shift_right: ScalarArg::new(shift as u32),
52        }
53    }
54}
55
56#[cube]
57impl FastDivmod {
58    pub fn div(&self, dividend: u32) -> u32 {
59        match self {
60            FastDivmod::Fast {
61                multiplier,
62                shift_right,
63                ..
64            } => {
65                let t = u32::mul_hi(dividend, *multiplier);
66                (t + dividend) >> shift_right
67            }
68            FastDivmod::PowerOfTwo { shift, .. } => dividend >> *shift,
69            FastDivmod::Fallback { divisor } => dividend / divisor,
70        }
71    }
72
73    pub fn modulo(&self, dividend: u32) -> u32 {
74        let q = self.div(dividend);
75        match self {
76            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
77            FastDivmod::PowerOfTwo { mask, .. } => dividend & mask,
78            FastDivmod::Fallback { divisor } => dividend % divisor,
79        }
80    }
81
82    pub fn div_mod(&self, dividend: u32) -> (u32, u32) {
83        let q = self.div(dividend);
84        let r = match self {
85            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
86            FastDivmod::Fallback { divisor } => dividend - q * divisor,
87            FastDivmod::PowerOfTwo { mask, .. } => dividend & *mask,
88        };
89
90        (q, r)
91    }
92}
93
/// Returns the smallest `i` such that `2^i >= x`, i.e. the ceiling of `log2(x)`.
///
/// Yields `0` for `x <= 1` and `32` for any `x` above `2^31`.
fn find_log2(x: u32) -> usize {
    // The bit length of `x - 1` is the ceiling of `log2(x)`. The saturating subtraction makes
    // `x == 0` behave like `x == 1`, which is what a linear search starting at 0 would return.
    (u32::BITS - x.saturating_sub(1).leading_zeros()) as usize
}