lib_modulo/residue32any.rs
/// An arbitrary modulus in the range `[2, 2^32)`; even values are allowed.
///
/// # Fast modular multiplication
///
/// Multiplication is carried out with [Barrett multiplication]. The method works for
/// every modulus (even ones included) and puts no restriction on the operands, but it is
/// generally slower than [`Residue32`](crate::Residue32).
///
/// In contrast to the Montgomery or Plantard methods, values stay in their ordinary
/// integer representation, so no conversion step is needed.
///
/// [Barrett multiplication]: https://doi.org/10.1007/3-540-47721-7_24
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32Any;
///
/// let modulus = Modulus32Any::new(14);
/// assert_eq!(modulus.mul(3, 5), 1)
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Modulus32Any {
    // The modulus itself, stored as 64 bits for Lemire's remainder algorithm
    // (N = 64, L = 32).
    n: u64,
    // Magic number shared by Barrett multiplication and Lemire's remainder algorithm:
    // ceil(2^64 / n).
    magic: u64,
}
30
31impl Modulus32Any {
32 /// Creates new instance with the given modulus.
33 ///
34 /// # Panics
35 ///
36 /// Modulus `n` should be greater than 1.
37 ///
38 /// # Example
39 ///
40 /// ```
41 /// use lib_modulo::Modulus32Any;
42 ///
43 /// // even number is available
44 /// let modulus = Modulus32Any::new(2);
45 /// // odd number is also available
46 /// let modulus = Modulus32Any::new(3);
47 /// // division by 0 is undefined
48 /// assert!(std::panic::catch_unwind(|| Modulus32Any::new(0)).is_err());
49 /// // meaningless division by 1 is NOT available for performance
50 /// assert!(std::panic::catch_unwind(|| Modulus32Any::new(1)).is_err());
51 /// ```
52 #[inline(always)]
53 pub const fn new(n: u32) -> Self {
54 assert!(n > 1);
55
56 let n = n as u64;
57 let magic = (u64::MAX / n).wrapping_add(1);
58
59 Self { n, magic }
60 }
61
62 /// Returns the modulus.
63 #[inline(always)]
64 pub const fn modulus(&self) -> u32 {
65 self.n as u32
66 }
67
68 /// Calculates residue of `x` modulo `self`.
69 #[inline(always)]
70 pub const fn residue32(&self, x: u32) -> u32 {
71 let lo = self.magic.wrapping_mul(x as u64);
72 ((lo as u128 * self.n as u128) >> 64) as u32
73 }
74
75 /// Calculates residue of `x` modulo `self`.
76 #[inline(always)]
77 pub const fn residue64(&self, x: u64) -> u64 {
78 let quot = ((x as u128 * self.magic as u128) >> 64) as u64;
79 let (rem, b) = x.overflowing_sub(quot * self.n);
80 if b {
81 rem.wrapping_add(self.n)
82 } else {
83 rem
84 }
85 }
86
87 /// Checks if `x` is divisible by `self`.
88 #[inline(always)]
89 pub const fn can_divide(&self, x: u32) -> bool {
90 // since `self.n` is not 1, `self.magic` never overflow
91 self.magic.wrapping_mul(x as u64) < self.magic
92 }
93
94 /// Performs modular multiplication.
95 ///
96 /// # Example
97 ///
98 /// ```
99 /// use lib_modulo::Modulus32Any;
100 ///
101 /// // even number is available
102 /// let modulus = Modulus32Any::new(1 << 8);
103 /// assert_eq!(modulus.mul(u32::MAX, u32::MAX), 1);
104 /// ```
105 #[inline(always)]
106 pub const fn mul(&self, a: u32, b: u32) -> u32 {
107 self.residue64(a as u64 * b as u64) as u32
108 }
109
110 /// Raises `x` to the power of `exp`, using exponentiation by squaring.
111 ///
112 /// # Time Complexity
113 ///
114 /// *O*(log `x`)
115 ///
116 /// # Example
117 ///
118 /// ```
119 /// use lib_modulo::Modulus32Any;
120 ///
121 /// let modulus = Modulus32Any::new(123_456);
122 ///
123 /// assert_eq!(modulus.pow(123_456 * 100 + 1, 1000), 1)
124 /// ```
125 #[inline(always)]
126 pub const fn pow(&self, mut x: u32, mut exp: u32) -> u32 {
127 let mut res = 1;
128 while exp > 0 {
129 if exp & 1 == 1 {
130 res = self.mul(res, x);
131 }
132 exp >>= 1;
133 x = self.mul(x, x);
134 }
135 res
136 }
137
138 /// Calculates the modular inverse of `x`, using extended gcd algorithm.
139 ///
140 /// Modular inverse can be defined if and only if `x` and the modulus is coprime.
141 ///
142 /// - `Ok(_)` : the modular inverse.
143 /// - `Err(_)`: the GCD of `x` and the `modulus`, where `gcd(0, a)` is defined to be `a`.
144 ///
145 /// # Time complexity
146 ///
147 /// *O*(log `x`)
148 ///
149 /// # Example
150 ///
151 /// ```
152 /// use lib_modulo::Modulus32Any;
153 ///
154 /// let modulus = Modulus32Any::new(3 * 5);
155 /// for (a, inv_a) in [(1, 1), (2, 8), (4, 4), (7, 13), (11, 11), (14, 14)] {
156 /// assert_eq!(modulus.mul(a, inv_a), 1);
157 /// assert_eq!(modulus.inv(a), Ok(inv_a));
158 /// }
159 /// for a in [3, 6, 9, 12] {
160 /// assert_eq!(modulus.inv(a), Err(3));
161 /// }
162 /// for a in [5, 10] {
163 /// assert_eq!(modulus.inv(a), Err(5));
164 /// }
165 /// // gcd(0, a) is defined t be `a`
166 /// assert_eq!(modulus.inv(0), Err(15));
167 /// assert_eq!(modulus.inv(15 * 99), Err(15));
168 /// ```
169 pub fn inv(&self, x: u32) -> Result<u32, u32> {
170 // invariant: a x0 = x, b x0 = y (mod [y])
171 let mut x = self.residue32(x) as i64;
172 let mut y = self.n as i64;
173 let [mut a, mut b] = [1, 0];
174
175 while x > 0 {
176 let (div, rem) = (y / x, y % x);
177
178 (x, y) = (rem, x);
179 (a, b) = (b - div * a, a);
180 }
181
182 // y = gcd(x0, y0) > 0
183 if y != 1 {
184 return Err(y as u32);
185 }
186 if b.is_negative() {
187 b += self.n as i64;
188 }
189
190 Ok(b as u32)
191 }
192}
193
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{random_iter, rng};

    use super::Modulus32Any;

    // Property test: `mul` agrees with the widened built-in `%` for any modulus and operands.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in 2..=u32::MAX, a: u32, b: u32) {
            let m = Modulus32Any::new(n);
            let expected = (u64::from(a) * u64::from(b) % u64::from(n)) as u32;
            assert_eq!(m.mul(a, b), expected, "{:?}", m);
        }
    }

    // Every modulus below 2^8, each checked against random operands.
    #[test]
    fn mul_small() {
        let mut source = rng();
        for n in 2..(1 << 8) {
            let m = Modulus32Any::new(n);
            for _ in 0..(1 << 12) {
                let a: u32 = source.random();
                let b: u32 = source.random();
                let expected = u64::from(a) * u64::from(b) % u64::from(n);
                assert_eq!(u64::from(m.mul(a, b)), expected);
            }
        }
    }

    // Property test: `residue32` agrees with the built-in `%`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue32(n in 2..=u32::MAX, a: u32) {
            let expected = a % n;
            let m = Modulus32Any::new(n);
            assert_eq!(m.residue32(a), expected);
        }
    }

    // Every modulus below 2^8, each checked against random 32-bit values.
    #[test]
    fn residue32_small() {
        for n in 2..(1 << 8) {
            let m = Modulus32Any::new(n);
            for a in random_iter().take(1 << 12) {
                assert_eq!(m.residue32(a), a % n)
            }
        }
    }

    // Property test: `residue64` agrees with the built-in `%` on 64-bit inputs.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue64(n in 2..=u32::MAX, a: u64) {
            let expected = a % u64::from(n);
            let m = Modulus32Any::new(n);
            assert_eq!(m.residue64(a), expected);
        }
    }

    // Every modulus below 2^8, each checked against random 64-bit values.
    #[test]
    fn residue64_small() {
        for n in 2..(1 << 8) {
            let m = Modulus32Any::new(n);
            for a in random_iter().take(1 << 12) {
                assert_eq!(m.residue64(a), a % n as u64)
            }
        }
    }

    // Reference gcd (binary / Stein's algorithm) used to validate the `Err` value of `inv`.
    fn binary_gcd(mut a: u32, mut b: u32) -> u32 {
        if b == 0 {
            return a;
        }

        // Power of two shared by both operands; restored on the way out.
        let common_twos = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        loop {
            if a == 0 {
                break b << common_twos;
            }
            a >>= a.trailing_zeros();

            // Keep `a` as the larger value so the subtraction cannot underflow.
            if a < b {
                ::std::mem::swap(&mut a, &mut b);
            }
            a -= b;
        }
    }

    // Property test: an `Ok` inverse multiplies back to 1; an `Err` carries gcd(a, n).
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in 2..=u32::MAX, a: u32) {
            let m = Modulus32Any::new(n);
            match m.inv(a) {
                Ok(inverse) => assert_eq!(m.mul(a, inverse), 1, "!"),
                Err(divisor) => assert_eq!(divisor, binary_gcd(a, n), "?"),
            }
        }
    }
}