Skip to main content

lib_modulo/
lib.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3pub mod factorize;
4pub mod prime;
5
6pub type Context64 = Context<u64>;
7pub type Context32 = Context<u32>;
8
9pub type Modulo64<'a> = Modulo<'a, u64>;
10pub type Modulo32<'a> = Modulo<'a, u32>;
11
/// Precomputed parameters for Montgomery multiplication with a fixed modulus.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Context<U> {
    /// The modulus `n`.
    n: U,
    /// Inverse of `n` modulo `r`, i.e. `n inv_n = 1 (mod r)` where `r = 2^32` or `2^64`.
    inv_n: U,
    /// `r^2 mod n`, the factor used to bring a value into Montgomery form.
    r2_mod_n: U,
}
20
21impl<U> Context<U> {
22    pub const fn modulus(&self) -> &U {
23        &self.n
24    }
25}
26
27/// Modulo with a runtime-specified odd modulus.
28///
29/// # Usage
30///
31/// ```
32/// use lib_modulo::Context64;
33///
34/// // runtime-specified *odd* modulus
35/// let modulus = 5;
36///
37/// let ctx = Context64::new(modulus); // slow
38/// let n = ctx.modulo(2) * ctx.modulo(3); // fast
39/// assert_eq!(n.get(), 1);
40/// ```
41///
42/// # Caution
43///
44/// [`Modulo`] values created from different [`Context`]s can technically interact,
45/// but the results will be meaningless.
46/// It is recommended to use a block to ensure that each [`Context`] is dropped
47/// before another one is introduced.
48#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
49pub struct Modulo<'a, U> {
50    // x r (mod n)
51    value: U,
52    ctx: &'a Context<U>,
53}
54
55macro_rules! montgomery_impl {
56    ( $single:ty, $double:ty ) => {
57        impl Context<$single> {
58            /// Calculates some parameters for Montgomery multiplication.
59            ///
60            /// # Panics
61            ///
62            /// - modulus `n` should be an odd number.
63            #[inline]
64            pub const fn new(n: $single) -> Self {
65                assert!(n & 1 == 1, "modulus should be an odd number");
66
67                let inv_n = {
68                    const TABLE: u32 = {
69                        // | n     | 1 | 3  | 5  | 7 | 9 | 11 | 13 | 15 |
70                        // | inv_n | 1 | 11 | 13 | 7 | 9 | 3  | 5  | 15 | <- 4 bits * 8
71                        let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
72
73                        let mut table = 0;
74                        let mut i = 0;
75                        while i < 8 {
76                            table |= inv_n[i] << (i * 4);
77                            i += 1;
78                        }
79
80                        table
81                    };
82                    // n inv_n = 1 (mod 8)
83                    let mut inv_n = ((TABLE >> (n & 0b1110) * 2) & 0b1111) as $single;
84
85                    let mut d = const { <$single>::BITS.ilog2() - 2 };
86                    while d > 0 {
87                        inv_n =
88                            inv_n.wrapping_mul((2 as $single).wrapping_sub(n.wrapping_mul(inv_n)));
89                        d -= 1;
90                    }
91                    debug_assert!(n.wrapping_mul(inv_n) == 1);
92
93                    inv_n
94                };
95                let r2_mod_n = ((n as $double).wrapping_neg() % (n as $double)) as $single;
96
97                Self { n, inv_n, r2_mod_n }
98            }
99
100            #[inline(always)]
101            pub const fn modulo(&self, x: $single) -> Modulo<'_, $single> {
102                // `x r2 < r n`
103                let x = self.mul(x, self.r2_mod_n);
104
105                Modulo {
106                    value: x,
107                    ctx: &self,
108                }
109            }
110
111            /// Performs Montgomery multiplication.
112            ///
113            /// if `lhs rhs < n r`, then `result < n`
114            #[inline(always)]
115            const fn mul(&self, lhs: $single, rhs: $single) -> $single {
116                self.mul_add(lhs, rhs, 0)
117            }
118
119            /// Performs `lhs rhs + add`.
120            ///
121            /// If `lhs rhs + add < n r`, then the result is less than `n`.
122            #[inline(always)]
123            const fn mul_add(&self, lhs: $single, rhs: $single, add: $single) -> $single {
124                // FIXME: use `a.widening_mul(b)`
125                let (x_hi, x_lo) = {
126                    let x = lhs as $double * rhs as $double + add as $double;
127                    ((x >> <$single>::BITS) as $single, x as $single)
128                };
129                // FIXME: use `mul_hi()`
130                // y = x n nn = x (mod r) => yl = x_lo
131                let y_hi = ((x_lo.wrapping_mul(self.inv_n) as $double * self.n as $double)
132                    >> <$single>::BITS) as $single;
133                // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
134                let (z, b) = x_hi.overflowing_sub(y_hi);
135
136                // x < n r, y < n r => |z| < n
137                if b {
138                    z.wrapping_add(self.n)
139                } else {
140                    z
141                }
142            }
143
144            /// Checks whether `x` is multiple of `self`.
145            ///
146            /// # Example
147            ///
148            /// ```
149            /// use lib_modulo::Context;
150            ///
151            /// for n in (1..1 << 10).step_by(2) {
152            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
153            ///
154            ///     (0..1 << 10).for_each(|k| assert!(ctx.can_divide(n * k)));
155            /// }
156            /// ```
157            #[inline]
158            pub const fn can_divide(&self, x: $single) -> bool {
159                // x < n r
160                let x = self.mul(x, 1);
161                x == 0
162            }
163        }
164
165        impl<'a> Modulo<'a, $single> {
166            /// Returns value.
167            ///
168            /// # Example
169            ///
170            /// ```
171            /// use lib_modulo::Context;
172            ///
173            /// let n = 101;
174            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
175            ///
176            /// let n = ctx.modulo(99);
177            ///
178            /// assert_eq!(n.get(), 99);
179            /// assert_eq!(n.modulus(), 101);
180            /// ```
181            #[inline(always)]
182            pub const fn get(&self) -> $single {
183                self.ctx.mul(self.value, 1)
184            }
185
186            /// Returns modulus.
187            ///
188            /// # Example
189            ///
190            /// ```
191            /// use lib_modulo::Context;
192            ///
193            /// let n = 101;
194            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
195            ///
196            /// let n = ctx.modulo(99);
197            ///
198            /// assert_eq!(n.get(), 99);
199            /// assert_eq!(n.modulus(), 101);
200            /// ```
201            #[inline(always)]
202            pub const fn modulus(&self) -> $single {
203                self.ctx.n
204            }
205
206            /// Returns `true` if `self` is `0`.
207            ///
208            /// # Example
209            ///
210            /// ```
211            /// use lib_modulo::Context;
212            ///
213            /// for n in (1..100_000).step_by(2) {
214            #[doc = concat!("    let ctx = Context::<", stringify!($single), ">::new(n);")]
215            ///     assert!(ctx.modulo(0).is_zero());
216            /// }
217            /// ```
218            #[inline(always)]
219            pub const fn is_zero(self) -> bool {
220                self.value == 0
221            }
222
223            /// Returns `0`.
224            ///
225            /// # Example
226            ///
227            /// ```
228            /// use lib_modulo::{Context, Modulo};
229            ///
230            /// for n in (1..100_000).step_by(2) {
231            #[doc = concat!("    let ctx = Context::<", stringify!($single), ">::new(n);")]
232            #[doc = concat!("    assert_eq!(Modulo::<'_, ", stringify!($single), ">::zero(&ctx).get(), 0);")]
233            /// }
234            /// ```
235            #[inline(always)]
236            pub const fn zero(ctx: &'a Context<$single>) -> Self {
237                Self { value: 0, ctx }
238            }
239
240            /// Raises `self` to the power of `exp`, using exponentiation by squaring.
241            ///
242            /// # Time complexity
243            ///
244            /// *O*(log `exp`)
245            ///
246            /// # Example
247            ///
248            /// ```
249            /// use lib_modulo::Context;
250            ///
251            /// let n = 12_345;
252            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
253            ///
254            /// let mut pow10 = 1;
255            /// for i in 0..1_000 {
256            ///     assert_eq!(ctx.modulo(10).pow(i).get(), pow10);
257            ///     pow10 = pow10 * 10 % n;
258            /// }
259            /// ```
260            #[inline]
261            pub const fn pow(mut self, mut exp: $single) -> Self {
262                // r inv_r = 1 (mod n)
263                let mut result = self.ctx.modulo(1).value;
264
265                while exp > 0 {
266                    if exp & 1 == 1 {
267                        // n < r
268                        result = self.ctx.mul(result, self.value)
269                    }
270
271                    exp >>= 1;
272                    // n < r
273                    self.value = self.ctx.mul(self.value, self.value)
274                }
275                self.value = result;
276
277                self
278            }
279
280            /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
281            ///
282            /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
283            ///
284            /// - `Ok(x)` : `x` is the modular inverse.
285            /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
286            /// where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
287            ///
288            /// # Time complexity
289            ///
290            /// *O*(log `self`)
291            ///
292            /// # Example
293            ///
294            /// ```
295            /// use lib_modulo::Context;
296            ///
297            /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
298            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(998_244_353);")]
299            ///
300            /// for n in 1..500_000 {
301            ///     let n = ctx.modulo(n);
302            ///     assert!(n.try_inv().is_ok_and(|i| (i * n).get() == 1));
303            /// }
304            /// // 0 n = 0 != 1 for any integer n
305            /// assert!(ctx.modulo(0).try_inv().is_err());
306            /// ```
307            #[inline]
308            pub const fn try_inv(self) -> Result<Self, $single> {
309                let mut a = self.get();
310                let Self { ctx, .. } = self;
311
312                // performs extended binary gcd
313                //
314                // invariants: a = [a] x,  b = [a] y (mod n) where [a] is initial value
315                let mut b = ctx.n;
316                let mut x = ctx.modulo(1).value; // 1 r mod n
317                let mut y = 0; // 0 r mod n
318                let frac_1_2 = ctx.modulo(ctx.n.div_ceil(2));
319
320                while a > 0 {
321                    x = ctx.mul(x, frac_1_2.pow(a.trailing_zeros() as $single).value);
322                    a >>= a.trailing_zeros();
323
324                    if a < b {
325                        (a, b) = (b, a);
326                        (x, y) = (y, x);
327                    }
328                    a -= b;
329                    let (diff, b) = x.overflowing_sub(y);
330                    x = if b { diff.wrapping_add(ctx.n) } else { diff };
331                }
332
333                // b = gcd([a], [b])
334                if b == 1 {
335                    Ok(Self { value: y, ctx })
336                } else {
337                    Err(b)
338                }
339            }
340        }
341
342        impl<'a> Add for Modulo<'a, $single> {
343            type Output = Self;
344
345            #[inline(always)]
346            fn add(mut self, rhs: Self) -> Self {
347                let (sum, b) = self.value.overflowing_add(rhs.value);
348                self.value = if b || sum >= self.ctx.n {
349                    sum.wrapping_sub(self.ctx.n)
350                } else {
351                    sum
352                };
353
354                self
355            }
356        }
357
358        impl<'a> Sub for Modulo<'a, $single> {
359            type Output = Self;
360
361            #[inline(always)]
362            fn sub(mut self, rhs: Self) -> Self {
363                let (diff, b) = self.value.overflowing_sub(rhs.value);
364                self.value = if b {
365                    diff.wrapping_add(self.ctx.n)
366                } else {
367                    diff
368                };
369
370                self
371            }
372        }
373
374        impl<'a> Mul for Modulo<'a, $single> {
375            type Output = Self;
376
377            #[inline(always)]
378            fn mul(mut self, rhs: Self) -> Self {
379                // n < r
380                self.value = self.ctx.mul(self.value, rhs.value);
381
382                self
383            }
384        }
385
386        impl<'a> Neg for Modulo<'a, $single> {
387            type Output = Self;
388
389            #[inline(always)]
390            fn neg(mut self) -> Self::Output {
391                // (x - x) r = 0 (mod n)
392                self.value = if self.value == 0 {
393                    self.value
394                } else {
395                    self.ctx.n - self.value
396                };
397
398                self
399            }
400        }
401    };
402}
403montgomery_impl!(u64, u128);
404montgomery_impl!(u32, u64);
405
406impl<'a, U> AddAssign for Modulo<'a, U>
407where
408    Self: Add<Output = Self> + Copy,
409{
410    #[inline(always)]
411    fn add_assign(&mut self, rhs: Self) {
412        *self = *self + rhs
413    }
414}
415
416impl<'a, U> SubAssign for Modulo<'a, U>
417where
418    Self: Sub<Output = Self> + Copy,
419{
420    #[inline(always)]
421    fn sub_assign(&mut self, rhs: Self) {
422        *self = *self - rhs
423    }
424}
425
426impl<'a, U> MulAssign for Modulo<'a, U>
427where
428    Self: Mul<Output = Self> + Copy,
429{
430    #[inline(always)]
431    fn mul_assign(&mut self, rhs: Self) {
432        *self = *self * rhs
433    }
434}