1use crate::{DivExact, ModularUnaryOps};
2
/// Pre-computed data for exact division by a fixed odd divisor.
///
/// The struct caches two values derived from a divisor `d`, so that testing whether a
/// number is a multiple of `d` (and obtaining the quotient) needs only a wrapping
/// multiplication and a comparison instead of a hardware division.
#[derive(Debug, Clone, Copy)]
pub struct PreModInv<T> {
    /// Multiplicative inverse of the divisor under wrapping arithmetic of `T`,
    /// i.e. `d.wrapping_mul(d_inv) == 1`.
    d_inv: T,

    /// Largest quotient an exact division can produce: `T::MAX / d`.
    q_lim: T,
}
12
// Generates, for one primitive word type `$t`, the `PreModInv<$t>` constructors and the
// `DivExact` implementations for both `$t` and its double-width companion type.
//
// Everything is emitted into a private module named `$ns`, so every instantiation has its
// own glob import of the matching `crate::word::$t` helpers.
macro_rules! impl_preinv_for_prim_int {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;

            impl PreModInv<$t> {
                /// Builds an instance directly from an already known inverse and limit.
                ///
                /// As a `const fn` this is usable in constant contexts. Nothing is validated
                /// here: the two values are only cross-checked against the divisor, in debug
                /// builds, when the instance is later passed to `div_exact`.
                #[inline]
                pub const fn new(d_inv: $t, q_lim: $t) -> Self {
                    Self { d_inv, q_lim }
                }

                // Debug-build sanity check that this instance was computed for divisor `d`.
                #[inline]
                fn debug_check(&self, d: $t) {
                    debug_assert!(d % 2 != 0, "only odd divisors are supported");
                    // `d_inv` has to invert `d` under wrapping multiplication.
                    debug_assert!(d.wrapping_mul(self.d_inv) == 1);
                    // `q_lim` has to be the last quotient whose product with `d` still fits.
                    debug_assert!(self.q_lim * d > (<$t>::MAX - d));
                }
            }

            impl From<$t> for PreModInv<$t> {
                /// Pre-computes the inverse and the quotient limit for the odd divisor `v`.
                ///
                /// # Panics
                ///
                /// Debug builds assert that `v` is odd. The `unwrap` panics whenever `invm`
                /// reports that no inverse exists.
                #[inline]
                fn from(v: $t) -> Self {
                    debug_assert!(v % 2 != 0, "only odd divisors are supported");

                    // NOTE(review): `merge(0, 1)` is assumed to build the double word with
                    // low = 0 and high = 1, so the inverse is taken modulo 2^BITS; this
                    // matches the wrapping check in `debug_check`.
                    Self {
                        d_inv: extend(v).invm(&merge(0, 1)).unwrap() as $t,
                        q_lim: <$t>::MAX / v,
                    }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for $t {
                type Output = $t;

                /// Returns `Some(self / d)` when `d` divides `self`, otherwise `None`.
                ///
                /// `pre` must have been computed for the same divisor `d`; this is only
                /// verified in debug builds.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self> {
                    pre.debug_check(d);

                    // For a multiple of `d` the wrapping product is the true quotient and
                    // therefore cannot exceed `q_lim`; anything larger is rejected.
                    let candidate = self.wrapping_mul(pre.d_inv);
                    if candidate > pre.q_lim {
                        return None;
                    }
                    Some(candidate)
                }
            }

            impl DivExact<$t, PreModInv<$t>> for DoubleWord {
                type Output = DoubleWord;

                /// Returns `Some(self / d)` for a double-width dividend when `d` divides it,
                /// otherwise `None`.
                ///
                /// `pre` must have been computed for the same divisor `d`; this is only
                /// verified in debug builds.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self::Output> {
                    pre.debug_check(d);

                    let (lo, hi) = split(self);

                    // Low word of the candidate quotient.
                    let q_lo = lo.wrapping_mul(pre.d_inv);

                    // The second half of `q_lo * d` is what the low quotient word
                    // contributes to the upper half of the dividend. It must be removed
                    // from `hi`, and the division cannot be exact if it does not fit.
                    let carry = split(wmul(q_lo, d)).1;
                    let rest = hi.checked_sub(carry)?;

                    // The remaining upper part is handled like a single-word division.
                    let q_hi = rest.wrapping_mul(pre.d_inv);
                    if q_hi > pre.q_lim {
                        return None;
                    }
                    Some(merge(q_lo, q_hi))
                }
            }
        }
    };
}
// Instantiate the pre-computed-inverse support for each listed word type. The second
// argument names the private module that receives the generated impls.
impl_preinv_for_prim_int!(u8, u8_impl);
impl_preinv_for_prim_int!(u16, u16_impl);
impl_preinv_for_prim_int!(u32, u32_impl);
impl_preinv_for_prim_int!(u64, u64_impl);
impl_preinv_for_prim_int!(usize, usize_impl);
97
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Randomised cross-check of `div_exact` against the plain `%` and `/` operators.
    ///
    /// Each round draws a random odd divisor for every word width and tests it against a
    /// random dividend of the same width and a random dividend of twice that width.
    #[test]
    // NOTE(review): presumably needed because the trait method `div_exact` is called with
    // method syntax and may share its name with an unstable std method - confirm.
    #[allow(unstable_name_collisions)]
    fn div_exact_test() {
        // Runs the single-word and the double-word check for one `(word, double word)`
        // pair, drawing the divisor first and then the two dividends.
        macro_rules! check_width {
            ($word:ty, $wide:ty) => {{
                // Forcing the lowest bit makes the random divisor odd.
                let d = random::<$word>() | 1;
                let pre = PreModInv::<$word>::from(d);

                let n: $word = random();
                let expect = if n % d == 0 { Some(n / d) } else { None };
                assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);

                let n: $wide = random();
                let d_wide = <$wide>::from(d);
                let expect = if n % d_wide == 0 {
                    Some(n / d_wide)
                } else {
                    None
                };
                assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d);
            }};
        }

        const ROUNDS: u8 = 100;
        for _ in 0..ROUNDS {
            check_width!(u8, u16);
            check_width!(u16, u32);
            check_width!(u32, u64);
            check_width!(u64, u128);
        }
    }
}