mod2k/
fast.rs

//! Arithmetic modulo `2^8 - 1`, `2^16 - 1`, `2^32 - 1`, and `2^64 - 1`.
//!
//! Most combinations of operations are compiled as efficiently as possible. A notable exception is
//! comparison with a constant: prefer `x.is::<C>()` over `x == FastK::new(C)` (where `FastK` is one
//! of the types below), because `==` has to compute a modular difference instead of comparing bits.
5
6use super::Mod;
7use core::hint::select_unpredictable;
8use core::ops::{Add, Mul, Neg, Shl, Shr, Sub};
9
10macro_rules! define_type {
11    (
12        #[$meta:meta]
13        $ty:ident as $native:ident,
14        test in $test_mod:ident,
15        carmichael = $carmichael:literal,
16        factorization = $factorization:expr,
17        inv_strategy = {$($inv_strategy:tt)*}
18    ) => {
19        // The `value` field stores some value equivalent to `x` modulo `2^k - 1`: specifically, `0`
20        // can be represented as either `0` or `2^k - 1`.
21        crate::macros::define_type_basics! {
22            #[$meta]
23            $ty as $native,
24            shr = true
25        }
26
27        impl $ty {
28            const CARMICHAEL: u64 = $carmichael;
29
30            // Calculate `x / 2^64 mod MODULUS`.
31            #[allow(unused, reason = "unused for Fast64 only")]
32            fn redc64(x: u64) -> Self {
33                Self::new((x % (Self::MODULUS as u64)) as $native)
34            }
35        }
36
37        impl crate::traits::sealed::Sealed for $ty {}
38        impl Mod for $ty {
39            type Native = $native;
40            const MODULUS: $native = $native::MAX;
41            const ZERO: Self = Self { value: 0 };
42            const ONE: Self = Self { value: 1 };
43
44            #[inline]
45            fn new(x: $native) -> Self {
46                Self { value: x }
47            }
48
49            #[inline]
50            unsafe fn new_unchecked(x: $native) -> Self {
51                debug_assert!(x < Self::MODULUS);
52                Self { value: x }
53            }
54
55            #[inline]
56            fn remainder(self) -> $native {
57                select_unpredictable(self.value == $native::MAX, 0, self.value)
58            }
59
60            #[inline]
61            fn to_raw(self) -> $native {
62                self.value
63            }
64
65            #[inline]
66            fn is<const C: u64>(self) -> bool {
67                const {
68                    assert!(C < Self::MODULUS as u64, "constant out of bounds");
69                }
70                if C == 0 {
71                    self.is_zero()
72                } else {
73                    self.value == C as $native
74                }
75            }
76
77            #[inline]
78            fn is_zero(&self) -> bool {
79                self.value == 0 || self.value == $native::MAX
80            }
81
82            fn pow(self, n: u64) -> Self {
83                if n == 0 {
84                    return Self::ONE;
85                }
86                // The existence of non-square-free Fermat numbers is an open problem, which means
87                // we can assume `2^k - 1` is square-free for all reasonable data types. A property
88                // of the Carmichael function guarantees
89                //     a^(n + lambda(m)) = a^n  (mod m)
90                // for *all* `a`, even those not coprime with `m`, as long as `n` >= the largest
91                // exponent in factorization (i.e. 1), which almost always allows us to take `n`
92                // modulo `lambda(m)`.
93                let new_n = if !Self::CARMICHAEL.is_power_of_two() && n <= Self::CARMICHAEL {
94                    // Branching to avoid modulo is only useful for non-power-of-two moduli. LLVM
95                    // can't infer that it's a no-op, so we enable it conditionally by hand.
96                    n
97                } else {
98                    (n - 1) % Self::CARMICHAEL + 1
99                };
100                self.pow_internal(new_n, Self::ONE)
101            }
102
103            fn is_invertible(&self) -> bool {
104                $factorization.iter().all(|p| self.value % *p != 0)
105            }
106
107            crate::xgcd::define_inverse!(prime = false, $($inv_strategy)*);
108        }
109
110        impl Add for $ty {
111            type Output = Self;
112
113            #[inline]
114            fn add(self, other: Self) -> Self {
115                let (sum, carry) = self.value.overflowing_add(other.value);
116                Self::new(sum.wrapping_add(carry as $native))
117            }
118        }
119
120        impl Sub for $ty {
121            type Output = Self;
122
123            #[inline]
124            fn sub(self, other: Self) -> Self {
125                let (diff, borrow) = self.value.overflowing_sub(other.value);
126                Self::new(diff.wrapping_sub(borrow as $native))
127            }
128        }
129
130        impl Mul for $ty {
131            type Output = Self;
132
133            #[inline]
134            #[allow(clippy::suspicious_arithmetic_impl, reason = "2^k mod (2^k - 1) = 1")]
135            fn mul(self, other: Self) -> Self {
136                let (low, high) = self.value.carrying_mul(other.value, 0);
137                Self::new(low) + Self::new(high)
138            }
139        }
140
141        impl Neg for $ty {
142            type Output = Self;
143
144            #[inline]
145            fn neg(self) -> Self {
146                Self::new(!self.value)
147            }
148        }
149
150        impl Shl<i64> for $ty {
151            type Output = Self;
152
153            #[inline]
154            fn shl(self, n: i64) -> Self {
155                Self::new(self.value.rotate_left(n as u32))
156            }
157        }
158
159        impl Shl<u64> for $ty {
160            type Output = Self;
161
162            #[inline]
163            fn shl(self, n: u64) -> Self {
164                Self::new(self.value.rotate_left(n as u32))
165            }
166        }
167
168        impl Shr<i64> for $ty {
169            type Output = Self;
170
171            #[inline]
172            fn shr(self, n: i64) -> Self {
173                Self::new(self.value.rotate_right(n as u32))
174            }
175        }
176
177        impl Shr<u64> for $ty {
178            type Output = Self;
179
180            #[inline]
181            fn shr(self, n: u64) -> Self {
182                Self::new(self.value.rotate_right(n as u32))
183            }
184        }
185
186        impl PartialEq for $ty {
187            #[inline]
188            fn eq(&self, other: &$ty) -> bool {
189                let (diff, borrow) = self.value.overflowing_sub(other.value);
190                let diff = diff.wrapping_sub(borrow as $native);
191                // Optimize comparison against a constant. This still produces suboptimal results
192                // (`sub + sbb + sete` instead of `cmp + sete`) [1], but it's better than nothing.
193                // [1]: https://github.com/llvm/llvm-project/issues/171676
194                if other.value != 0 {
195                    // SAFETY: If no overflow happened, `diff < self.value` and thus `diff < MAX`.
196                    // If overflow happened, initially `diff != 0`, so subtracting 1 cannot give
197                    // `diff == MAX`.
198                    unsafe {
199                        core::hint::assert_unchecked(diff != $native::MAX);
200                    }
201                }
202                diff == 0 || diff == $native::MAX
203            }
204        }
205
206        #[cfg(test)]
207        mod $test_mod {
208            use super::{Mod, $ty};
209
210            crate::macros::test_ty!($ty as $native, shr = true);
211            crate::macros::test_exact_raw!($ty as $native);
212        }
213    };
214}
215
216define_type! {
217    /// Arithmetic modulo `2^8 - 1 = 3 * 5 * 17`.
218    Fast8 as u8,
219    test in test8,
220    carmichael = 16,
221    factorization = [3, 5, 17],
222    inv_strategy = {long = false}
223}
224
225define_type! {
226    /// Arithmetic modulo `2^16 - 1 = 3 * 5 * 17 * 257`.
227    Fast16 as u16,
228    test in test16,
229    carmichael = 256,
230    factorization = [3, 5, 17, 257],
231    inv_strategy = {long = false}
232}
233
234define_type! {
235    /// Arithmetic modulo `2^32 - 1 = 3 * 5 * 17 * 257 * 65537`.
236    Fast32 as u32,
237    test in test32,
238    carmichael = 65536,
239    factorization = [3, 5, 17, 257, 65537],
240    inv_strategy = {long = false}
241}
242
243define_type! {
244    /// Arithmetic modulo `2^64 - 1 = 3 * 5 * 17 * 257 * 641 * 65537 * 6700417`.
245    Fast64 as u64,
246    test in test64,
247    carmichael = 17153064960,
248    factorization = [3, 5, 17, 257, 641, 65537, 6700417],
249    inv_strategy = {builtin}
250}
251
// Carrier for compile-fail doctests of `Mod::is::<C>`. The function body is empty and the item only
// exists under `cfg(doctest)`; the snippets in its doc comment are the actual tests.
#[cfg(doctest)]
#[allow(dead_code, reason = "ad-hoc compile-fail test")]
/// Each snippet below must fail to compile. `is::<C>` asserts `C < MODULUS` in a `const` block,
/// `MODULUS` is the all-ones value of the native type, and each snippet passes exactly that value.
///
/// ```compile_fail
/// use mod2k::Mod;
/// mod2k::Fast8::ZERO.is::<{ u8::MAX as u64 }>();
/// ```
///
/// ```compile_fail
/// use mod2k::Mod;
/// mod2k::Fast16::ZERO.is::<{ u16::MAX as u64 }>();
/// ```
///
/// ```compile_fail
/// use mod2k::Mod;
/// mod2k::Fast32::ZERO.is::<{ u32::MAX as u64 }>();
/// ```
///
/// ```compile_fail
/// use mod2k::Mod;
/// mod2k::Fast64::ZERO.is::<{ u64::MAX }>();
/// ```
fn test_is() {}
273fn test_is() {}