// num_modular/preinv.rs

1use crate::{DivExact, ModularUnaryOps};
2
/// Precomputed data for fast exact division and divisibility checks by a fixed odd divisor.
///
/// Holds the multiplicative inverse of the divisor modulo `2^BITS`, together with the largest
/// quotient that fits in the word type. That quotient is the bound used to decide divisibility.
/// The underlying trick is explained at <https://math.stackexchange.com/a/1251328>.
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct PreModInv<T> {
    /// Inverse of the divisor modulo `2^BITS`, i.e. `d.wrapping_mul(d_inv) == 1`.
    d_inv: T,
    /// Largest valid quotient, `T::MAX / d`.
    q_lim: T,
}
13
14macro_rules! impl_preinv_for_prim_int {
15    ($t:ident, $ns:ident) => {
16        mod $ns {
17            use super::*;
18            use crate::word::$t::*;
19
20            impl PreModInv<$t> {
21                /// Construct the preinv instance with raw values.
22                ///
23                /// This function can be used to initialize preinv in a constant context, the divisor d
24                /// is required only for verification of d_inv and q_lim.
25                #[inline]
26                pub const fn new(d_inv: $t, q_lim: $t) -> Self {
27                    Self { d_inv, q_lim }
28                }
29
30                // check if the divisor is consistent in debug mode
31                #[inline]
32                fn debug_check(&self, d: $t) {
33                    debug_assert!(d % 2 != 0, "only odd divisors are supported");
34                    debug_assert!(d.wrapping_mul(self.d_inv) == 1);
35                    debug_assert!(self.q_lim * d > (<$t>::MAX - d));
36                }
37            }
38
39            impl From<$t> for PreModInv<$t> {
40                #[inline]
41                fn from(v: $t) -> Self {
42                    use crate::word::$t::*;
43
44                    debug_assert!(v % 2 != 0, "only odd divisors are supported");
45                    let d_inv = extend(v).invm(&merge(0, 1)).unwrap() as $t;
46                    let q_lim = <$t>::MAX / v;
47                    Self { d_inv, q_lim }
48                }
49            }
50
51            impl DivExact<$t, PreModInv<$t>> for $t {
52                type Output = $t;
53                #[inline]
54                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self> {
55                    pre.debug_check(d);
56                    let q = self.wrapping_mul(pre.d_inv);
57                    if q <= pre.q_lim {
58                        Some(q)
59                    } else {
60                        None
61                    }
62                }
63            }
64
65            impl DivExact<$t, PreModInv<$t>> for DoubleWord {
66                type Output = DoubleWord;
67
68                #[inline]
69                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self::Output> {
70                    pre.debug_check(d);
71
72                    // this implementation comes from GNU factor,
73                    // see https://math.stackexchange.com/q/4436380/815652 for explanation
74
75                    let (n0, n1) = split(self);
76                    let q0 = n0.wrapping_mul(pre.d_inv);
77                    let nr0 = wmul(q0, d);
78                    let nr0 = split(nr0).1;
79                    if nr0 > n1 {
80                        return None;
81                    }
82                    let nr1 = n1 - nr0;
83                    let q1 = nr1.wrapping_mul(pre.d_inv);
84                    if q1 > pre.q_lim {
85                        return None;
86                    }
87                    Some(merge(q0, q1))
88                }
89            }
90        }
91    };
92}
93impl_preinv_for_prim_int!(u8, u8_impl);
94impl_preinv_for_prim_int!(u16, u16_impl);
95impl_preinv_for_prim_int!(u32, u32_impl);
96impl_preinv_for_prim_int!(u64, u64_impl);
97impl_preinv_for_prim_int!(usize, usize_impl);
98
// XXX: an unchecked `div_exact` could be provided by skipping the `q_lim` check;
//      revisit this once `exact_div` is either added to or removed from `core`.
//      https://github.com/rust-lang/rust/issues/85122
102
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    // Checks `div_exact` against plain `/` and `%` for the word type `$t` and, optionally,
    // its double-word dividend type `$dt`. Uniformly random dividends are almost never
    // divisible, so exact multiples of the divisor are also tested to cover the `Some` path.
    macro_rules! check_div_exact {
        ($t:ty) => {{
            let d = random::<$t>() | 1;
            let pre: PreModInv<$t> = d.into();

            let n: $t = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);

            // `random() / d <= MAX / d`, so the product below cannot overflow.
            let q = random::<$t>() / d;
            let n = q * d;
            assert_eq!(n.div_exact(d, &pre), Some(q), "{} / {}", n, d);
        }};
        ($t:ty, $dt:ty) => {{
            check_div_exact!($t);

            let d = random::<$t>() | 1;
            let pre: PreModInv<$t> = d.into();
            let dd = d as $dt;

            let n: $dt = random();
            let expect = if n % dd == 0 { Some(n / dd) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);

            // Same overflow argument as above, in the double-word type.
            let q = random::<$dt>() / dd;
            let n = q * dd;
            assert_eq!(n.div_exact(d, &pre), Some(q), "{} / {}", n, d);
        }};
    }

    #[test]
    // presumably guards against a name clash with an unstable `core` integer method; TODO confirm
    #[allow(unstable_name_collisions)]
    fn div_exact_test() {
        const N: u8 = 100;
        for _ in 0..N {
            check_div_exact!(u8, u16);
            check_div_exact!(u16, u32);
            check_div_exact!(u32, u64);
            check_div_exact!(u64, u128);
            check_div_exact!(usize);
        }
    }

    #[test]
    // presumably guards against a name clash with an unstable `core` integer method; TODO confirm
    #[allow(unstable_name_collisions)]
    fn div_exact_u8_exhaustive_test() {
        // Every odd u8 divisor against every u8 dividend.
        for d in (1..=u8::MAX).step_by(2) {
            let pre: PreModInv<u8> = d.into();
            for n in 0..=u8::MAX {
                let expect = if n % d == 0 { Some(n / d) } else { None };
                assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            }
        }

        // A few edge divisors against every double-word (u16) dividend.
        for &d in &[1u8, 3, 7, 85, 255] {
            let pre: PreModInv<u8> = d.into();
            let dd = d as u16;
            for n in 0..=u16::MAX {
                let expect = if n % dd == 0 { Some(n / dd) } else { None };
                assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            }
        }
    }
}