lib_modulo/residue32.rs
1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue32`].
4///
5/// See documentation of [`Residue32`] for details.
6#[derive(Debug, Clone, Hash, Eq)]
7pub struct Modulus32 {
8 // n inv_n = 1 (mod 2^64)
9 n: u64,
10 inv_n: u64,
11 // 2^128 (mod n)
12 init: u64,
13}
14
15impl Modulus32 {
16 /// Maximum available modulus.
17 pub const MAX_MODULUS: u32 = 2_654_435_769;
18
19 /// Creates new context for modular arithmetics.
20 ///
21 /// # Panics
22 ///
23 /// - modulus `n` should be an odd integer.
24 /// - modulus `n` should be no more than `2_654_435_769`,
25 /// which is the floor of `2^32 / GOLDEN_RATIO`.
26 ///
27 /// # Example
28 ///
29 /// ```
30 /// use lib_modulo::Modulus32;
31 ///
32 /// // odd integer less than or equal to 2_654_435_769 is allowed.
33 /// let modulus = Modulus32::new(Modulus32::MAX_MODULUS);
34 /// let modulus = Modulus32::new(3);
35 ///
36 /// // modulus should be an odd integer!
37 /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
38 /// ```
39 #[inline]
40 pub const fn new(n: u32) -> Self {
41 assert!(
42 n & 1 == 1,
43 "invalid modulus: modulus should be an odd integer."
44 );
45 assert!(
46 n <= Self::MAX_MODULUS,
47 "invalid modulus: modulus should be no more than 2_654_435_769."
48 );
49
50 let n = n as u64;
51
52 let inv_n = {
53 // 1 * 1 = 3 * 3 = 1 (mod 4)
54 let mut inv_n = n & 3;
55 // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
56 // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
57 let mut i = u64::BITS.ilog2() - 1;
58 while i > 0 {
59 i -= 1;
60 inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
61 }
62 debug_assert!(n.wrapping_mul(inv_n) == 1);
63
64 inv_n
65 };
66
67 let init = if std::mem::size_of::<usize>() >= std::mem::size_of::<u64>() {
68 ((n as u128).wrapping_neg() % n as u128) as u64
69 } else {
70 // 128 bit division on 32-bit machine will be slow (no experiment)
71 let pow_2_64 = n.wrapping_neg() % n;
72 let pow_2_128 = pow_2_64 * pow_2_64 % n;
73 pow_2_128
74 };
75
76 Self { n, inv_n, init }
77 }
78
79 /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
80 ///
81 /// If `x y < self.n`, then returned value is less than `self.n`.
82 #[inline(always)]
83 const fn mul(&self, x: u64, y: u64) -> u64 {
84 // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
85 // let z = x.wrapping_mul(y).wrapping_mul(self.inv_n) >> 32;
86 // let z = (z + 1).wrapping_mul(self.n) >> 32;
87
88 // if z == self.n {
89 // 0
90 // } else {
91 // z
92 // }
93
94 // modified Plantard multiplication
95 // z = (Q1 2^32 + (2^32 - 1)) P >> 64 (notation is the same as the paper)
96 let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) | u32::MAX as u64;
97 let z = ((z as u128 * self.n as u128) >> 64) as u64;
98 debug_assert!(z < self.n, "this is a bug in lib-modulo");
99 z
100 }
101
102 /// Calculates the residue of `x` modulo `self`.
103 ///
104 /// # Example
105 ///
106 /// ```
107 /// use lib_modulo::Modulus32;
108 ///
109 /// let modulus = Modulus32::new(5);
110 /// assert_eq!(modulus.residue(8).get(), 3)
111 /// ```
112 #[inline(always)]
113 pub const fn residue(&self, x: u32) -> Residue32<'_> {
114 let mut x = x as u64;
115
116 if x * self.init > !(self.n << 32) {
117 x %= self.n
118 }
119
120 Residue32 {
121 x: self.mul(x, self.init),
122 modulus: &self,
123 }
124 }
125
126 /// Checks whether `x` is divisible by `self`.
127 ///
128 /// # Example
129 ///
130 /// ```
131 /// use lib_modulo::Modulus32;
132 ///
133 /// let modulus = Modulus32::new(9);
134 /// assert!(modulus.can_divide(18));
135 /// assert!(!modulus.can_divide(19));
136 /// ```
137 #[inline(always)]
138 pub const fn can_divide(&self, x: u32) -> bool {
139 self.residue(x).is_zero()
140 }
141}
142
143impl PartialEq for Modulus32 {
144 fn eq(&self, other: &Self) -> bool {
145 self.n == other.n
146 }
147}
148
149/// Residue with odd modulus which is no more than `2_654_435_769`.
150///
151/// # Fast modular multiplication
152///
153/// [`Residue32`] provides fast modular multiplication called [Plantard multiplication].
154/// This method saves one multiplication when either of two values of a multiplication is used multiple times.
155/// Therefore, [`Residue32::pow`] will be faster than that using [Montgomery multiplication].
156///
157/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
158/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
159///
160/// # Usage
161///
162/// ```
163/// use lib_modulo::Modulus32;
164///
165/// // set modulus
166/// let modulus = Modulus32::new(3);
167///
168/// // performs modular arithmetics
169/// let one = modulus.residue(1);
170/// let two = modulus.residue(2);
171/// let five = modulus.residue(5);
172/// assert_eq!(two * five, one)
173/// ```
174///
175/// Two residues with different modulus can interact, but the result will be meaningless.
176/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
177#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
178pub struct Residue32<'a> {
179 // compare modulus first
180 modulus: &'a Modulus32,
181 x: u64,
182}
183
184impl<'a> Residue32<'a> {
185 /// Checks whether `self` is `0`.
186 ///
187 /// # Example
188 ///
189 /// ```
190 /// use lib_modulo::Modulus32;
191 ///
192 /// let modulus = Modulus32::new(5);
193 /// assert!(modulus.residue(10).is_zero())
194 /// ```
195 #[inline(always)]
196 pub const fn is_zero(self) -> bool {
197 self.x == 0
198 }
199
200 /// Returns the residue.
201 ///
202 /// # Example
203 ///
204 /// ```
205 /// use lib_modulo::Modulus32;
206 ///
207 /// let modulus = Modulus32::new(7);
208 /// assert_eq!(modulus.residue(10).get(), 3)
209 /// ```
210 #[inline(always)]
211 pub const fn get(self) -> u64 {
212 self.modulus.mul(self.x, 1)
213 }
214
215 /// Returns the modulus.
216 ///
217 /// # Example
218 ///
219 /// ```
220 /// use lib_modulo::Modulus32;
221 ///
222 /// let modulus = Modulus32::new(11);
223 /// assert_eq!(modulus.residue(2).modulus(), 11);
224 /// ```
225 #[inline(always)]
226 pub const fn modulus(&self) -> u64 {
227 self.modulus.n
228 }
229
230 /// Raises `self` to the power of `exp`, using exponentiation by squaring.
231 ///
232 /// # Time complexity
233 ///
234 /// *Θ*(log `exp`)
235 ///
236 /// # Example
237 ///
238 /// ```
239 /// use lib_modulo::Modulus32;
240 ///
241 /// let modulus = Modulus32::new(1001);
242 /// let residue = modulus.residue(2);
243 /// for exp in 0..64 {
244 /// assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
245 /// }
246 /// ```
247 #[inline(always)]
248 pub const fn pow(self, mut exp: u32) -> Self {
249 let Self { mut x, modulus } = self;
250 // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
251 let mut prod = modulus.mul(1, modulus.init);
252
253 while exp > 1 {
254 if exp & 1 == 1 {
255 // インライン展開されると,掛け算を1回節約できる。
256 prod = modulus.mul(prod, x)
257 }
258
259 exp >>= 1;
260 x = modulus.mul(x, x); // skip last useless one
261 }
262 if exp != 0 {
263 prod = modulus.mul(prod, x);
264 }
265
266 Self { x: prod, modulus }
267 }
268
269 /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
270 ///
271 /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
272 ///
273 /// - `Ok(x)` : `x` is the modular inverse.
274 /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
275 /// where `gcd(0, a)` is defined to be `a`.
276 ///
277 /// # Time complexity
278 ///
279 /// *O*(log `self`)
280 ///
281 /// # Example
282 ///
283 /// ```
284 /// use lib_modulo::Modulus32;
285 ///
286 /// let modulus = Modulus32::new(3 * 5);
287 ///
288 /// let residue = modulus.residue(2);
289 /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
290 ///
291 /// let residue = modulus.residue(6);
292 /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
293 /// ```
294 pub const fn try_inv(self) -> Result<Self, u64> {
295 // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
296 let mut a = self.get();
297 let mut b = self.modulus();
298 let Self { modulus, .. } = self;
299 let mut x = modulus.residue(1).x;
300 let mut y = 0;
301 let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
302
303 while a > 0 {
304 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
305 a >>= a.trailing_zeros();
306
307 if a < b {
308 (a, b) = (b, a);
309 (x, y) = (y, x);
310 }
311 a -= b;
312 let (z, b) = x.overflowing_sub(y);
313 x = if b { z.wrapping_add(modulus.n) } else { z };
314 }
315
316 // b = gcd([a], n)
317 if b == 1 {
318 Ok(Self { x: y, modulus })
319 } else {
320 Err(b)
321 }
322 }
323}
324
325impl<'a> Add for Residue32<'a> {
326 type Output = Self;
327
328 fn add(mut self, rhs: Self) -> Self::Output {
329 let (x, b) = self.x.overflowing_add(rhs.x);
330 self.x = if b || x >= self.modulus() {
331 x.wrapping_sub(self.modulus())
332 } else {
333 x
334 };
335
336 self
337 }
338}
339
340impl<'a> AddAssign for Residue32<'a> {
341 fn add_assign(&mut self, rhs: Self) {
342 *self = *self + rhs
343 }
344}
345
346impl<'a> Sub for Residue32<'a> {
347 type Output = Self;
348
349 fn sub(mut self, rhs: Self) -> Self::Output {
350 let (x, b) = self.x.overflowing_sub(rhs.x);
351 self.x = if b { x.wrapping_add(self.modulus()) } else { x };
352
353 self
354 }
355}
356
357impl<'a> SubAssign for Residue32<'a> {
358 fn sub_assign(&mut self, rhs: Self) {
359 *self = *self - rhs
360 }
361}
362
363impl<'a> Mul for Residue32<'a> {
364 type Output = Self;
365
366 fn mul(mut self, rhs: Self) -> Self::Output {
367 self.x = self.modulus.mul(self.x, rhs.x);
368 self
369 }
370}
371
372impl<'a> MulAssign for Residue32<'a> {
373 fn mul_assign(&mut self, rhs: Self) {
374 *self = *self * rhs
375 }
376}
377
378impl<'a> Neg for Residue32<'a> {
379 type Output = Self;
380
381 fn neg(mut self) -> Self::Output {
382 self.x = if self.x == 0 {
383 0
384 } else {
385 self.modulus() - self.x
386 };
387
388 self
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=2_654_435_769_u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `1 % n` rather than `1`: for `n == 1` every power (even `x^0`) is 0.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference GCD used to validate the `Err` branch of `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `1 % n`: modulo 1, the product of any two residues is 0.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    /// Deterministic check of the degenerate modulus `n = 1`, which the
    /// proptest strategies above only hit by chance.
    #[test]
    fn modulus_one() {
        let modulus = Modulus32::new(1);
        let res = modulus.residue(5);
        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(res.try_inv().is_ok_and(|inv| inv.get() == 0));
        assert!(modulus.can_divide(12345));
    }
}