tfhe_ntt/
fastdiv.rs

1use crate::u256;
2
/// Returns the high 64 bits of the 96-bit product `lowbits * d`, i.e.
/// `floor(lowbits * d / 2^64)`, truncated to `u32`.
///
/// The result always fits in 32 bits because `lowbits < 2^64` and `d < 2^32`.
#[inline(always)]
pub(crate) const fn mul128_u32(lowbits: u64, d: u32) -> u32 {
    let wide_product = (lowbits as u128) * (d as u128);
    let high_word = wide_product >> 64;
    high_word as u32
}
7
/// Returns `floor(lowbits * d / 2^128)`: the bits of the 192-bit product
/// `lowbits * d` that lie above bit 128.
///
/// `lowbits` is split into 64-bit halves so every partial product fits in a
/// `u128`. The low half only contributes its carry into bit 64. The high half's
/// product plus that carry cannot overflow, because
/// `(2^64 - 1)^2 + (2^64 - 1) < 2^128`.
#[inline(always)]
pub(crate) const fn mul128_u64(lowbits: u128, d: u64) -> u64 {
    const LOW_MASK: u128 = u64::MAX as u128;
    let wide_d = d as u128;
    let lo = lowbits & LOW_MASK;
    let hi = lowbits >> 64;
    let carry = (lo * wide_d) >> 64;
    ((hi * wide_d + carry) >> 64) as u64
}
16
17#[inline(always)]
18pub(crate) const fn mul256_u128(lowbits: u256, d: u128) -> u128 {
19    lowbits.mul_u256_u128(d).1
20}
21
22#[inline(always)]
23pub(crate) const fn mul256_u64(lowbits: u256, d: u64) -> u64 {
24    lowbits.mul_u256_u64(d).1
25}
26
/// Divisor representing a 32-bit denominator.
///
/// Stores the divisor together with two precomputed reciprocals, so that
/// quotients and remainders are computed with multiplications and shifts
/// instead of a hardware divide. Build it with [`Div32::new`]. The fields are
/// public, so a hand-built value must keep both reciprocals consistent with
/// `divisor`.
#[derive(Copy, Clone, Debug)]
pub struct Div32 {
    /// `u128::MAX / divisor + 1`, used when the numerator is a `u64`.
    pub double_reciprocal: u128,
    /// `u64::MAX / divisor + 1`, used when the numerator is a `u32`.
    pub single_reciprocal: u64,
    /// The denominator itself.
    pub divisor: u32,
}
34
/// Divisor representing a 64-bit denominator.
///
/// Stores the divisor together with two precomputed reciprocals, so that
/// quotients and remainders are computed with multiplications and shifts
/// instead of a hardware divide. Build it with [`Div64::new`]. The fields are
/// public, so a hand-built value must keep both reciprocals consistent with
/// `divisor`.
#[derive(Copy, Clone, Debug)]
pub struct Div64 {
    /// `u256::MAX / divisor + 1`, used when the numerator is a `u128`.
    pub double_reciprocal: u256,
    /// `u128::MAX / divisor + 1`, used when the numerator is a `u64`.
    pub single_reciprocal: u128,
    /// The denominator itself.
    pub divisor: u64,
}
42
43impl Div32 {
44    /// Returns the division structure holding the given divisor.
45    ///
46    /// # Panics
47    /// Panics if the divisor is zero or one.
48    pub const fn new(divisor: u32) -> Self {
49        assert!(divisor > 1);
50        let single_reciprocal = (u64::MAX / divisor as u64) + 1;
51        let double_reciprocal = (u128::MAX / divisor as u128) + 1;
52
53        Self {
54            double_reciprocal,
55            single_reciprocal,
56            divisor,
57        }
58    }
59
60    /// Returns the quotient of the division of `n` by `d`.
61    #[inline(always)]
62    pub const fn div(n: u32, d: Self) -> u32 {
63        mul128_u32(d.single_reciprocal, n)
64    }
65
66    /// Returns the remainder of the division of `n` by `d`.
67    #[inline(always)]
68    pub const fn rem(n: u32, d: Self) -> u32 {
69        let low_bits = d.single_reciprocal.wrapping_mul(n as u64);
70        mul128_u32(low_bits, d.divisor)
71    }
72
73    /// Returns the quotient of the division of `n` by `d`.
74    #[inline(always)]
75    pub const fn div_u64(n: u64, d: Self) -> u64 {
76        mul128_u64(d.double_reciprocal, n)
77    }
78
79    /// Returns the remainder of the division of `n` by `d`.
80    #[inline(always)]
81    pub const fn rem_u64(n: u64, d: Self) -> u32 {
82        let low_bits = d.double_reciprocal.wrapping_mul(n as u128);
83        mul128_u64(low_bits, d.divisor as u64) as u32
84    }
85
86    /// Returns the internal divisor as an integer.
87    #[inline(always)]
88    pub const fn divisor(&self) -> u32 {
89        self.divisor
90    }
91}
92
93impl Div64 {
94    /// Returns the division structure holding the given divisor.
95    ///
96    /// # Panics
97    /// Panics if the divisor is zero or one.
98    pub const fn new(divisor: u64) -> Self {
99        assert!(divisor > 1);
100        let single_reciprocal = ((u128::MAX) / divisor as u128) + 1;
101        let double_reciprocal = u256::MAX
102            .div_rem_u256_u64(divisor)
103            .0
104            .overflowing_add(u256 {
105                x0: 1,
106                x1: 0,
107                x2: 0,
108                x3: 0,
109            })
110            .0;
111
112        Self {
113            double_reciprocal,
114            single_reciprocal,
115            divisor,
116        }
117    }
118
119    /// Returns the quotient of the division of `n` by `d`.
120    #[inline(always)]
121    pub const fn div(n: u64, d: Self) -> u64 {
122        mul128_u64(d.single_reciprocal, n)
123    }
124
125    /// Returns the remainder of the division of `n` by `d`.
126    #[inline(always)]
127    pub const fn rem(n: u64, d: Self) -> u64 {
128        let low_bits = d.single_reciprocal.wrapping_mul(n as u128);
129        mul128_u64(low_bits, d.divisor)
130    }
131
132    /// Returns the quotient of the division of `n` by `d`.
133    #[inline(always)]
134    pub const fn div_u128(n: u128, d: Self) -> u128 {
135        mul256_u128(d.double_reciprocal, n)
136    }
137
138    /// Returns the remainder of the division of `n` by `d`.
139    #[inline(always)]
140    pub const fn rem_u128(n: u128, d: Self) -> u64 {
141        let low_bits = d.double_reciprocal.wrapping_mul_u256_u128(n);
142        mul256_u64(low_bits, d.divisor)
143    }
144
145    /// Returns the internal divisor as an integer.
146    #[inline(always)]
147    pub const fn divisor(&self) -> u64 {
148        self.divisor
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Checks every `Div64` operation against native `/` and `%` for one
    /// divisor, on a fixed set of boundary numerators.
    fn check_div64(divisor: u64) {
        let div = Div64::new(divisor);
        assert_eq!(div.divisor(), divisor);

        let narrow = [
            0,
            1,
            divisor - 1,
            divisor,
            divisor.wrapping_add(1),
            u64::MAX - 1,
            u64::MAX,
        ];
        for m in narrow {
            assert_eq!(Div64::div(m, div), m / divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(Div64::rem(m, div), m % divisor, "divisor = {}, m = {}", divisor, m);
        }

        let wide_divisor = divisor as u128;
        let wide = [
            0,
            1,
            wide_divisor - 1,
            wide_divisor,
            wide_divisor + 1,
            u64::MAX as u128 + 1,
            u128::MAX - 1,
            u128::MAX,
        ];
        for n in wide {
            assert_eq!(
                Div64::div_u128(n, div),
                n / wide_divisor,
                "divisor = {}, n = {}",
                divisor,
                n
            );
            assert_eq!(
                Div64::rem_u128(n, div) as u128,
                n % wide_divisor,
                "divisor = {}, n = {}",
                divisor,
                n
            );
        }
    }

    /// Checks every `Div32` operation against native `/` and `%` for one
    /// divisor, on a fixed set of boundary numerators.
    fn check_div32(divisor: u32) {
        let div = Div32::new(divisor);
        assert_eq!(div.divisor(), divisor);

        let narrow = [
            0,
            1,
            divisor - 1,
            divisor,
            divisor.wrapping_add(1),
            u32::MAX - 1,
            u32::MAX,
        ];
        for m in narrow {
            assert_eq!(Div32::div(m, div), m / divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(Div32::rem(m, div), m % divisor, "divisor = {}, m = {}", divisor, m);
        }

        let wide_divisor = divisor as u64;
        let wide = [
            0,
            1,
            wide_divisor - 1,
            wide_divisor,
            wide_divisor + 1,
            u32::MAX as u64 + 1,
            u64::MAX - 1,
            u64::MAX,
        ];
        for n in wide {
            assert_eq!(
                Div32::div_u64(n, div),
                n / wide_divisor,
                "divisor = {}, n = {}",
                divisor,
                n
            );
            assert_eq!(
                Div32::rem_u64(n, div) as u64,
                n % wide_divisor,
                "divisor = {}, n = {}",
                divisor,
                n
            );
        }
    }

    #[test]
    fn test_div64_boundaries() {
        // Smallest legal divisor, small values, powers of two (where
        // `MAX / d + 1` is exact), and values next to the type's maximum.
        // Uniform random sampling almost never reaches these.
        let divisors = [
            2u64,
            3,
            5,
            7,
            10,
            1u64 << 32,
            (1u64 << 32) + 1,
            1u64 << 63,
            u64::MAX - 1,
            u64::MAX,
        ];
        for divisor in divisors {
            check_div64(divisor);
        }
    }

    #[test]
    fn test_div32_boundaries() {
        let divisors = [
            2u32,
            3,
            5,
            7,
            10,
            1u32 << 16,
            1u32 << 31,
            (1u32 << 31) + 1,
            u32::MAX - 1,
            u32::MAX,
        ];
        for divisor in divisors {
            check_div32(divisor);
        }
    }

    #[test]
    #[should_panic]
    fn test_div64_rejects_zero() {
        let _ = Div64::new(0);
    }

    #[test]
    #[should_panic]
    fn test_div64_rejects_one() {
        let _ = Div64::new(1);
    }

    #[test]
    #[should_panic]
    fn test_div32_rejects_zero() {
        let _ = Div32::new(0);
    }

    #[test]
    #[should_panic]
    fn test_div32_rejects_one() {
        let _ = Div32::new(1);
    }

    #[test]
    fn test_div64() {
        for _ in 0..1000 {
            let divisor = loop {
                let d = random();
                if d > 1 {
                    break d;
                }
            };

            let div = Div64::new(divisor);
            let n = random();
            let m = random();
            // Inputs are unseeded, so report them on failure to allow reproduction.
            assert_eq!(Div64::div(m, div), m / divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(Div64::rem(m, div), m % divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(
                Div64::div_u128(n, div),
                n / divisor as u128,
                "divisor = {}, n = {}",
                divisor,
                n
            );
            assert_eq!(
                Div64::rem_u128(n, div) as u128,
                n % divisor as u128,
                "divisor = {}, n = {}",
                divisor,
                n
            );
        }
    }

    #[test]
    fn test_div32() {
        for _ in 0..1000 {
            let divisor = loop {
                let d = random();
                if d > 1 {
                    break d;
                }
            };

            let div = Div32::new(divisor);
            let n = random();
            let m = random();
            // Inputs are unseeded, so report them on failure to allow reproduction.
            assert_eq!(Div32::div(m, div), m / divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(Div32::rem(m, div), m % divisor, "divisor = {}, m = {}", divisor, m);
            assert_eq!(
                Div32::div_u64(n, div),
                n / divisor as u64,
                "divisor = {}, n = {}",
                divisor,
                n
            );
            assert_eq!(
                Div32::rem_u64(n, div) as u64,
                n % divisor as u64,
                "divisor = {}, n = {}",
                divisor,
                n
            );
        }
    }
}