Skip to main content

num_modular/
preinv.rs

1use crate::{DivExact, ModularUnaryOps};
2
/// Precomputed data for fast exact division / divisibility checks by a fixed odd divisor.
///
/// Stores the inverse of the divisor modulo 2^BITS together with the largest quotient
/// (`MAX / d`) that an exact division by that divisor can produce.
/// See <https://math.stackexchange.com/a/1251328> for an explanation of the trick.
#[derive(Clone, Copy, Debug)]
pub struct PreModInv<T> {
    /// Inverse of the divisor modulo 2^BITS.
    d_inv: T,
    /// Largest valid quotient, i.e. `MAX / d`.
    q_lim: T,
}
12
macro_rules! impl_preinv_for_prim_int {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;

            impl PreModInv<$t> {
                /// Construct the preinv instance from precomputed raw values.
                ///
                /// Being a `const fn`, this can initialize a [`PreModInv`] in a constant
                /// context. No validation happens here: `d_inv` must be the inverse of the
                /// (odd) divisor `d` modulo 2^BITS, and `q_lim` must equal `MAX / d`. The
                /// divisor itself is only supplied later (to `div_exact`), where these
                /// values are checked against it in debug builds.
                #[inline]
                pub const fn new(d_inv: $t, q_lim: $t) -> Self {
                    Self { d_inv, q_lim }
                }

                // In debug builds, check that the stored values are consistent with `d`.
                // `q_lim` is compared against `MAX / d` instead of computing `q_lim * d`,
                // so an inconsistent `q_lim` (e.g. passed to `new`) fails this assertion
                // rather than triggering an arithmetic-overflow panic.
                #[inline]
                fn debug_check(&self, d: $t) {
                    debug_assert!(d % 2 != 0, "only odd divisors are supported");
                    debug_assert!(d.wrapping_mul(self.d_inv) == 1);
                    debug_assert!(self.q_lim == <$t>::MAX / d);
                }
            }

            impl From<$t> for PreModInv<$t> {
                /// Precompute the inverse and the quotient limit for the divisor `v`.
                ///
                /// # Panics
                /// Panics if `v` is even: it has no inverse modulo 2^BITS. Debug builds
                /// report this through a `debug_assert!` first.
                #[inline]
                fn from(v: $t) -> Self {
                    debug_assert!(v % 2 != 0, "only odd divisors are supported");
                    // The modulus 2^BITS = merge(0, 1) does not fit in `$t`, so the
                    // inverse is computed in the double-width word and then truncated.
                    let d_inv = extend(v)
                        .invm(&merge(0, 1))
                        .expect("only odd divisors are supported") as $t;
                    let q_lim = <$t>::MAX / v;
                    Self { d_inv, q_lim }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for $t {
                type Output = $t;

                /// Return `Some(self / d)` if `d` divides `self`, otherwise `None`.
                ///
                /// `pre` must have been built for `d`; this is asserted in debug builds.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self> {
                    pre.debug_check(d);
                    // Multiplying by d^-1 (mod 2^BITS) maps each multiple k * d onto k,
                    // which is at most MAX / d. Any other value lands above that bound.
                    let q = self.wrapping_mul(pre.d_inv);
                    if q <= pre.q_lim {
                        Some(q)
                    } else {
                        None
                    }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for DoubleWord {
                type Output = DoubleWord;

                /// Return `Some(self / d)` if the double-width `self` is divisible by the
                /// single-width `d`, otherwise `None`.
                ///
                /// `pre` must have been built for `d`; this is asserted in debug builds.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self::Output> {
                    pre.debug_check(d);

                    // this implementation comes from GNU factor,
                    // see https://math.stackexchange.com/q/4436380/815652 for explanation

                    // self = n1 * 2^BITS + n0. The low word of any exact quotient
                    // must be q0 = n0 * d^-1 (mod 2^BITS).
                    let (n0, n1) = split(self);
                    let q0 = n0.wrapping_mul(pre.d_inv);
                    // q0 * d has low word n0, so self - q0 * d = (n1 - nr0) * 2^BITS.
                    // If nr0 > n1 then self < q0 * d, and no exact quotient exists.
                    let nr0 = split(wmul(q0, d)).1;
                    if nr0 > n1 {
                        return None;
                    }
                    // The remaining high part must itself be an exact multiple of d.
                    let nr1 = n1 - nr0;
                    let q1 = nr1.wrapping_mul(pre.d_inv);
                    if q1 > pre.q_lim {
                        return None;
                    }
                    Some(merge(q0, q1))
                }
            }
        }
    };
}
92impl_preinv_for_prim_int!(u8, u8_impl);
93impl_preinv_for_prim_int!(u16, u16_impl);
94impl_preinv_for_prim_int!(u32, u32_impl);
95impl_preinv_for_prim_int!(u64, u64_impl);
96impl_preinv_for_prim_int!(usize, usize_impl);
97
// XXX: an unchecked `div_exact` (one that skips the `q_lim` check) could be added;
//      revisit this once `exact_div` is either stabilized in or removed from core,
//      see https://github.com/rust-lang/rust/issues/85122
101
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Check `div_exact` by a random odd divisor of type `$t`, for dividends of both
    /// `$t` and its double-width type `$dt`, against plain `%` and `/`.
    ///
    /// A random dividend is almost never divisible by a random divisor, so each case
    /// also checks an exact multiple `q * d` to exercise the `Some` branch.
    macro_rules! check_div_exact {
        ($t:ty, $dt:ty) => {{
            let d = random::<$t>() | 1;
            let pre: PreModInv<$t> = d.into();

            // single word: random dividend
            let n: $t = random();
            let expect = if n % d == 0 { Some(n / d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);

            // single word: exact multiple (q <= MAX / d, so q * d cannot overflow)
            let q = random::<$t>() / d;
            let n = q * d;
            assert_eq!(n.div_exact(d, &pre), Some(q), "{} / {}", n, d);

            // double word: random dividend
            let wide_d = d as $dt;
            let n: $dt = random();
            let expect = if n % wide_d == 0 { Some(n / wide_d) } else { None };
            assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);

            // double word: exact multiple
            let q = random::<$dt>() / wide_d;
            let n = q * wide_d;
            assert_eq!(n.div_exact(d, &pre), Some(q), "{} / {}", n, d);
        }};
    }

    // NOTE(review): the `allow` below presumably silences a name collision between
    // `DivExact::div_exact` and an unstable inherent integer method (rust-lang/rust#85122).
    #[test]
    #[allow(unstable_name_collisions)]
    fn div_exact_test() {
        const N: u8 = 100;
        for _ in 0..N {
            check_div_exact!(u8, u16);
            check_div_exact!(u16, u32);
            check_div_exact!(u32, u64);
            check_div_exact!(u64, u128);
        }
    }

    #[test]
    #[allow(unstable_name_collisions)]
    fn div_exact_edge_cases() {
        // d = 1: every dividend is an exact multiple
        let pre: PreModInv<u8> = 1u8.into();
        assert_eq!(u8::MAX.div_exact(1u8, &pre), Some(u8::MAX));
        assert_eq!(u16::MAX.div_exact(1u8, &pre), Some(u16::MAX));

        // small divisor, including zero and values straddling the word boundary
        let pre: PreModInv<u8> = 3u8.into();
        assert_eq!(0u8.div_exact(3u8, &pre), Some(0));
        assert_eq!(255u8.div_exact(3u8, &pre), Some(85));
        assert_eq!(254u8.div_exact(3u8, &pre), None);
        assert_eq!(900u16.div_exact(3u8, &pre), Some(300));
        assert_eq!(901u16.div_exact(3u8, &pre), None);

        // largest odd divisor of the word type
        let pre: PreModInv<u8> = u8::MAX.into();
        assert_eq!(255u8.div_exact(255u8, &pre), Some(1));
        assert_eq!(254u8.div_exact(255u8, &pre), None);
        assert_eq!(u16::MAX.div_exact(255u8, &pre), Some(257));

        // double word whose quotient fills the whole low word
        let pre: PreModInv<u64> = 3u64.into();
        assert_eq!(u64::MAX.div_exact(3u64, &pre), Some(u64::MAX / 3));
        assert_eq!(
            ((u64::MAX as u128) * 3).div_exact(3u64, &pre),
            Some(u64::MAX as u128)
        );
        assert_eq!(((u64::MAX as u128) * 3 + 1).div_exact(3u64, &pre), None);
    }
}