1use crate::{DivExact, ModularUnaryOps};
2
/// Pre-computed data for fast exact division and divisibility checks by a
/// fixed odd divisor.
///
/// It stores the inverse of the divisor modulo `2^BITS` together with the
/// largest possible single-word quotient (`MAX / divisor`). A word `n` is a
/// multiple of the divisor exactly when `n.wrapping_mul(d_inv)` does not
/// exceed that limit.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct PreModInv<T> {
    /// Inverse of the divisor modulo `2^BITS`.
    d_inv: T,
    /// Largest quotient that fits in one word, i.e. `MAX / divisor`.
    q_lim: T,
}
13
/// Generates the `PreModInv` constructors and the `DivExact` implementations
/// (single-word and double-word dividends) for the primitive word type `$t`.
///
/// Each expansion lives in its own module `$ns`. This keeps the glob import of
/// `crate::word::$t` (which supplies `DoubleWord`, `split`, `merge`, `wmul`,
/// `extend`) for one word type from colliding with the imports for another.
macro_rules! impl_preinv_for_prim_int {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;

            impl PreModInv<$t> {
                /// Construct the pre-computed inverse from raw values.
                ///
                /// `d_inv` must be the inverse of the (odd) divisor modulo `2^BITS`,
                /// and `q_lim` must equal `MAX / divisor` for this word type. This is
                /// a `const fn`, so it can build a `PreModInv` in a constant context.
                /// The values are only verified (in debug builds) when the instance
                /// is used together with its divisor.
                #[inline]
                pub const fn new(d_inv: $t, q_lim: $t) -> Self {
                    Self { d_inv, q_lim }
                }

                /// In debug builds, verify that this instance was computed for divisor `d`.
                #[inline]
                fn debug_check(&self, d: $t) {
                    debug_assert!(d % 2 != 0, "only odd divisors are supported");
                    debug_assert!(
                        d.wrapping_mul(self.d_inv) == 1,
                        "d_inv is not the inverse of the divisor"
                    );
                    // Compare with the exact value instead of computing `q_lim * d`.
                    // That product overflows, and panics with an unrelated message,
                    // when `new` was given a `q_lim` that is too large. `d` is known
                    // to be odd (hence non-zero) at this point.
                    debug_assert!(
                        self.q_lim == <$t>::MAX / d,
                        "q_lim is not MAX / divisor"
                    );
                }
            }

            impl From<$t> for PreModInv<$t> {
                /// Pre-compute the inverse and quotient limit for the odd divisor `v`.
                ///
                /// # Panics
                ///
                /// Panics in debug builds if `v` is even. In release builds, an even
                /// `v` is rejected by the `expect` below. This assumes `invm` reports
                /// a non-existent inverse as `None`; no inverse modulo a power of two
                /// exists for even numbers.
                #[inline]
                fn from(v: $t) -> Self {
                    debug_assert!(v % 2 != 0, "only odd divisors are supported");
                    // Inverse of v modulo 2^BITS, computed in the double-width word type.
                    let d_inv = extend(v)
                        .invm(&merge(0, 1))
                        .expect("only odd divisors are supported") as $t;
                    let q_lim = <$t>::MAX / v;
                    Self { d_inv, q_lim }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for $t {
                type Output = $t;

                /// Return `Some(self / d)` if `d` divides `self`, otherwise `None`.
                ///
                /// For odd `d`, multiplying by `d_inv` is a bijection modulo `2^BITS`.
                /// It maps the exact multiples `0, d, 2d, ...` onto `0, 1, 2, ...`, so
                /// every non-multiple lands above `MAX / d`.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self> {
                    pre.debug_check(d);
                    let q = self.wrapping_mul(pre.d_inv);
                    if q <= pre.q_lim {
                        Some(q)
                    } else {
                        None
                    }
                }
            }

            impl DivExact<$t, PreModInv<$t>> for DoubleWord {
                type Output = DoubleWord;

                /// Return `Some(self / d)` if the single-word divisor `d` divides the
                /// double-word `self`, otherwise `None`.
                #[inline]
                fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option<Self::Output> {
                    pre.debug_check(d);

                    // Write self = n1 * B + n0 and the quotient q = q1 * B + q0, with
                    // B = 2^BITS. If d divides self, then q0 * d = n0 (mod B), so the
                    // low quotient word is forced to be n0 * d_inv (mod B).
                    let (n0, n1) = split(self);
                    let q0 = n0.wrapping_mul(pre.d_inv);

                    // q0 * d = nr0 * B + n0. Subtracting it from self leaves
                    // (n1 - nr0) * B, which must equal q1 * d * B. If nr0 > n1, then
                    // q0 * d > self and no quotient exists.
                    let nr0 = split(wmul(q0, d)).1;
                    if nr0 > n1 {
                        return None;
                    }

                    // The remaining high word must itself be an exact multiple of d.
                    let q1 = (n1 - nr0).wrapping_mul(pre.d_inv);
                    if q1 > pre.q_lim {
                        return None;
                    }
                    Some(merge(q0, q1))
                }
            }
        }
    };
}
93impl_preinv_for_prim_int!(u8, u8_impl);
94impl_preinv_for_prim_int!(u16, u16_impl);
95impl_preinv_for_prim_int!(u32, u32_impl);
96impl_preinv_for_prim_int!(u64, u64_impl);
97impl_preinv_for_prim_int!(usize, usize_impl);
98
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    // Check single-word `div_exact` for the odd divisor `$d` of type `$w`, using
    // `$pre` computed from it. Three dividends are tried: a random one, a
    // guaranteed multiple of `$d` (so the `Some` branch is always exercised),
    // and the value just above that multiple (never divisible when `$d > 1`).
    macro_rules! check_word {
        ($w:ty, $d:ident, $pre:ident) => {{
            let n: $w = random();
            let expect = if n % $d == 0 { Some(n / $d) } else { None };
            assert_eq!(n.div_exact($d, &$pre), expect, "{} / {}", n, $d);

            // k < MAX / d, so d * k <= MAX - d and d * k + 1 cannot overflow.
            let k = random::<$w>() % (<$w>::MAX / $d);
            let n = $d * k;
            assert_eq!(n.div_exact($d, &$pre), Some(k), "{} / {}", n, $d);
            if $d > 1 {
                assert_eq!((n + 1).div_exact($d, &$pre), None, "{} / {}", n + 1, $d);
            }
        }};
    }

    // Same as `check_word`, but with double-word dividends of type `$dw`
    // divided by the single-word divisor `$d`.
    macro_rules! check_double_word {
        ($dw:ty, $d:ident, $pre:ident) => {{
            let dd = $d as $dw;
            let n: $dw = random();
            let expect = if n % dd == 0 { Some(n / dd) } else { None };
            assert_eq!(n.div_exact($d, &$pre), expect, "{} / {}", n, $d);

            // k < MAX / d, so d * k <= MAX - d and d * k + 1 cannot overflow.
            let k = random::<$dw>() % (<$dw>::MAX / dd);
            let n = dd * k;
            assert_eq!(n.div_exact($d, &$pre), Some(k), "{} / {}", n, $d);
            if $d > 1 {
                assert_eq!((n + 1).div_exact($d, &$pre), None, "{} / {}", n + 1, $d);
            }
        }};
    }

    #[test]
    #[allow(unstable_name_collisions)] // std has an unstable integer method with the same name
    fn div_exact_test() {
        const N: u8 = 100;
        for _ in 0..N {
            let d = random::<u8>() | 1;
            let pre: PreModInv<_> = d.into();
            check_word!(u8, d, pre);
            check_double_word!(u16, d, pre);

            let d = random::<u16>() | 1;
            let pre: PreModInv<_> = d.into();
            check_word!(u16, d, pre);
            check_double_word!(u32, d, pre);

            let d = random::<u32>() | 1;
            let pre: PreModInv<_> = d.into();
            check_word!(u32, d, pre);
            check_double_word!(u64, d, pre);

            let d = random::<u64>() | 1;
            let pre: PreModInv<_> = d.into();
            check_word!(u64, d, pre);
            check_double_word!(u128, d, pre);

            let d = random::<usize>() | 1;
            let pre: PreModInv<_> = d.into();
            check_word!(usize, d, pre);
        }
    }
}