cubecl_std/
fast_math.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_runtime::TypeUsage;
4
/// Create a fast-divmod object if supported, or a regular fallback if not.
/// This precalculates certain values on the host, in exchange for making division and modulo
/// operations on the GPU much faster. Only supports `u32` right now to allow for a simpler
/// algorithm; it is mostly used for indices anyway.
///
/// Implementation based on ONNX:
/// <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h>
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub enum FastDivmod {
    /// Division is replaced by a high multiply, an add and a right shift, using constants
    /// precomputed on the host.
    Fast {
        /// The divisor itself; still needed to derive the remainder from the quotient.
        divisor: u32,
        /// Precomputed "magic" multiplier for this divisor.
        multiplier: u32,
        /// Precomputed shift amount: the smallest `s` with `2^s >= divisor` (capped at 32).
        shift_right: u32,
    },
    /// Plain hardware division/modulo, selected on the host when the client does not report
    /// `u64` arithmetic support.
    Fallback {
        /// The divisor used for the regular `/` and `%` operations.
        divisor: u32,
    },
}
23
// Manual `Clone` for the launch-argument type (presumably generated by the `CubeLaunch`
// derive on `FastDivmod` -- confirm). Written by hand so that no `R: Clone` bound is required.
impl<R: Runtime> Clone for FastDivmodArgs<'_, R> {
    fn clone(&self) -> Self {
        // The type is `Copy`, so cloning is a plain bitwise copy.
        *self
    }
}
// Marker impl: the launch arguments can be copied freely, for any runtime `R`.
impl<R: Runtime> Copy for FastDivmodArgs<'_, R> {}
30
31impl<R: Runtime> FastDivmodArgs<'_, R> {
32    pub fn new(client: &ComputeClient<R::Server>, divisor: u32) -> Self {
33        debug_assert!(divisor != 0);
34
35        if !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
36            return FastDivmodArgs::Fallback {
37                divisor: ScalarArg::new(divisor),
38            };
39        }
40
41        let div_64 = divisor as u64;
42        let shift = find_log2(divisor);
43        let multiplier = ((1u64 << 32) * ((1u64 << shift) - div_64)) / div_64 + 1;
44
45        FastDivmodArgs::Fast {
46            divisor: ScalarArg::new(divisor),
47            multiplier: ScalarArg::new(multiplier as u32),
48            shift_right: ScalarArg::new(shift as u32),
49        }
50    }
51}
52
// Device-side operations. The `match self` in each method selects between the precomputed
// fast path and the plain-division fallback.
#[cube]
impl FastDivmod {
    /// Returns the quotient `dividend / divisor`.
    ///
    /// The `Fast` variant computes it as `(mul_hi(dividend, multiplier) + dividend) >> shift_right`;
    /// the `Fallback` variant uses a regular division.
    pub fn div(&self, dividend: u32) -> u32 {
        match self {
            FastDivmod::Fast {
                multiplier,
                shift_right,
                ..
            } => {
                // `mul_hi` presumably yields the upper 32 bits of the 64-bit product -- confirm
                // against the cubecl documentation.
                let t = u32::mul_hi(dividend, *multiplier);
                // NOTE(review): `t + dividend` is a 32-bit add and `shift_right` can be 32 for
                // divisors above 2^31. The referenced ONNX code limits itself to the signed
                // 31-bit range -- confirm that callers never pass values where this wraps.
                (t + dividend) >> shift_right
            }
            FastDivmod::Fallback { divisor } => dividend / divisor,
        }
    }

    /// Returns the remainder `dividend % divisor`.
    ///
    /// The `Fast` variant derives it from the quotient as `dividend - q * divisor`; the
    /// `Fallback` variant uses the regular `%` operator.
    pub fn modulo(&self, dividend: u32) -> u32 {
        // Computed up front for both variants; only the `Fast` arm reads it.
        let q = self.div(dividend);
        match self {
            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
            FastDivmod::Fallback { divisor } => dividend % divisor,
        }
    }

    /// Returns `(quotient, remainder)` for `dividend` and the stored divisor.
    ///
    /// The quotient comes from [`Self::div`]; the remainder is `dividend - quotient * divisor`
    /// for both variants.
    pub fn div_mod(&self, dividend: u32) -> (u32, u32) {
        let q = self.div(dividend);
        let r = match self {
            FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
            FastDivmod::Fallback { divisor } => dividend - q * divisor,
        };

        (q, r)
    }
}
87
/// Returns the smallest exponent `s` in `0..=32` such that `2^s >= x`
/// (i.e. `ceil(log2(x))` for `x >= 1`, and `0` for `x == 0`).
fn find_log2(x: u32) -> usize {
    // The answer equals the number of significant bits in `x - 1`:
    // every value in `(2^(s-1), 2^s]` maps to `s`. `saturating_sub` keeps `x == 0` at 0.
    let significant_bits = u32::BITS - x.saturating_sub(1).leading_zeros();
    significant_bits as usize
}