// lib_modulo/lib.rs
1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3pub mod factorize;
4pub mod prime;
5
6pub type Context64 = Context<u64>;
7pub type Context32 = Context<u32>;
8
9pub type Modulo64<'a> = Modulo<'a, u64>;
10pub type Modulo32<'a> = Modulo<'a, u32>;
11
/// Storage of parameters for Montgomery multiplication.
///
/// Throughout, `r` denotes `2^BITS` of the word type `U` (`2^32` or `2^64`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Context<U> {
    /// The modulus `n`.
    n: U,
    /// The inverse of `n` modulo `r`: `n inv_n = 1 (mod r)`.
    inv_n: U,
    /// `r^2 mod n`.
    r2_mod_n: U,
}
20
21impl<U> Context<U> {
22 pub const fn modulus(&self) -> &U {
23 &self.n
24 }
25}
26
27/// Modulo with a runtime-specified odd modulus.
28///
29/// # Usage
30///
31/// ```
32/// use lib_modulo::Context64;
33///
34/// // runtime-specified *odd* modulus
35/// let modulus = 5;
36///
37/// let ctx = Context64::new(modulus); // slow
38/// let n = ctx.modulo(2) * ctx.modulo(3); // fast
39/// assert_eq!(n.get(), 1);
40/// ```
41///
42/// # Caution
43///
44/// [`Modulo`] values created from different [`Context`]s can technically interact,
45/// but the results will be meaningless.
46/// It is recommended to use a block to ensure that each [`Context`] is dropped
47/// before another one is introduced.
48#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
49pub struct Modulo<'a, U> {
50 // x r (mod n)
51 value: U,
52 ctx: &'a Context<U>,
53}
54
55macro_rules! montgomery_impl {
56 ( $single:ty, $double:ty ) => {
57 impl Context<$single> {
58 /// Calculates some parameters for Montgomery multiplication.
59 ///
60 /// # Panics
61 ///
62 /// - modulus `n` should be an odd number.
63 #[inline]
64 pub const fn new(n: $single) -> Self {
65 assert!(n & 1 == 1, "modulus should be an odd number");
66
67 let inv_n = {
68 const TABLE: u32 = {
69 // | n | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 |
70 // | inv_n | 1 | 11 | 13 | 7 | 9 | 3 | 5 | 15 | <- 4 bits * 8
71 let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
72
73 let mut table = 0;
74 let mut i = 0;
75 while i < 8 {
76 table |= inv_n[i] << (i * 4);
77 i += 1;
78 }
79
80 table
81 };
82 // n inv_n = 1 (mod 8)
83 let mut inv_n = ((TABLE >> (n & 0b1110) * 2) & 0b1111) as $single;
84
85 let mut d = const { <$single>::BITS.ilog2() - 2 };
86 while d > 0 {
87 inv_n =
88 inv_n.wrapping_mul((2 as $single).wrapping_sub(n.wrapping_mul(inv_n)));
89 d -= 1;
90 }
91 debug_assert!(n.wrapping_mul(inv_n) == 1);
92
93 inv_n
94 };
95 let r2_mod_n = ((n as $double).wrapping_neg() % (n as $double)) as $single;
96
97 Self { n, inv_n, r2_mod_n }
98 }
99
100 #[inline(always)]
101 pub const fn modulo(&self, x: $single) -> Modulo<'_, $single> {
102 // `x r2 < r n`
103 let x = self.mul(x, self.r2_mod_n);
104
105 Modulo {
106 value: x,
107 ctx: &self,
108 }
109 }
110
111 /// Performs Montgomery multiplication.
112 ///
113 /// if `lhs rhs < n r`, then `result < n`
114 #[inline(always)]
115 const fn mul(&self, lhs: $single, rhs: $single) -> $single {
116 self.mul_add(lhs, rhs, 0)
117 }
118
119 /// Performs `lhs rhs + add`.
120 ///
121 /// If `lhs rhs + add < n r`, then the result is less than `n`.
122 #[inline(always)]
123 const fn mul_add(&self, lhs: $single, rhs: $single, add: $single) -> $single {
124 // FIXME: use `a.widening_mul(b)`
125 let (x_hi, x_lo) = {
126 let x = lhs as $double * rhs as $double + add as $double;
127 ((x >> <$single>::BITS) as $single, x as $single)
128 };
129 // FIXME: use `mul_hi()`
130 // y = x n nn = x (mod r) => yl = x_lo
131 let y_hi = ((x_lo.wrapping_mul(self.inv_n) as $double * self.n as $double)
132 >> <$single>::BITS) as $single;
133 // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
134 let (z, b) = x_hi.overflowing_sub(y_hi);
135
136 // x < n r, y < n r => |z| < n
137 if b {
138 z.wrapping_add(self.n)
139 } else {
140 z
141 }
142 }
143
144 /// Checks whether `x` is multiple of `self`.
145 ///
146 /// # Example
147 ///
148 /// ```
149 /// use lib_modulo::Context;
150 ///
151 /// for n in (1..1 << 10).step_by(2) {
152 #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
153 ///
154 /// (0..1 << 10).for_each(|k| assert!(ctx.can_divide(n * k)));
155 /// }
156 /// ```
157 #[inline]
158 pub const fn can_divide(&self, x: $single) -> bool {
159 // x < n r
160 let x = self.mul(x, 1);
161 x == 0
162 }
163 }
164
165 impl<'a> Modulo<'a, $single> {
166 /// Returns value.
167 ///
168 /// # Example
169 ///
170 /// ```
171 /// use lib_modulo::Context;
172 ///
173 /// let n = 101;
174 #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
175 ///
176 /// let n = ctx.modulo(99);
177 ///
178 /// assert_eq!(n.get(), 99);
179 /// assert_eq!(n.modulus(), 101);
180 /// ```
181 #[inline(always)]
182 pub const fn get(&self) -> $single {
183 self.ctx.mul(self.value, 1)
184 }
185
186 /// Returns modulus.
187 ///
188 /// # Example
189 ///
190 /// ```
191 /// use lib_modulo::Context;
192 ///
193 /// let n = 101;
194 #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
195 ///
196 /// let n = ctx.modulo(99);
197 ///
198 /// assert_eq!(n.get(), 99);
199 /// assert_eq!(n.modulus(), 101);
200 /// ```
201 #[inline(always)]
202 pub const fn modulus(&self) -> $single {
203 self.ctx.n
204 }
205
206 /// Returns `true` if `self` is `0`.
207 ///
208 /// # Example
209 ///
210 /// ```
211 /// use lib_modulo::Context;
212 ///
213 /// for n in (1..100_000).step_by(2) {
214 #[doc = concat!(" let ctx = Context::<", stringify!($single), ">::new(n);")]
215 /// assert!(ctx.modulo(0).is_zero());
216 /// }
217 /// ```
218 #[inline(always)]
219 pub const fn is_zero(self) -> bool {
220 self.value == 0
221 }
222
223 /// Returns `0`.
224 ///
225 /// # Example
226 ///
227 /// ```
228 /// use lib_modulo::{Context, Modulo};
229 ///
230 /// for n in (1..100_000).step_by(2) {
231 #[doc = concat!(" let ctx = Context::<", stringify!($single), ">::new(n);")]
232 #[doc = concat!(" assert_eq!(Modulo::<'_, ", stringify!($single), ">::zero(&ctx).get(), 0);")]
233 /// }
234 /// ```
235 #[inline(always)]
236 pub const fn zero(ctx: &'a Context<$single>) -> Self {
237 Self { value: 0, ctx }
238 }
239
240 /// Raises `self` to the power of `exp`, using exponentiation by squaring.
241 ///
242 /// # Time complexity
243 ///
244 /// *O*(log `exp`)
245 ///
246 /// # Example
247 ///
248 /// ```
249 /// use lib_modulo::Context;
250 ///
251 /// let n = 12_345;
252 #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
253 ///
254 /// let mut pow10 = 1;
255 /// for i in 0..1_000 {
256 /// assert_eq!(ctx.modulo(10).pow(i).get(), pow10);
257 /// pow10 = pow10 * 10 % n;
258 /// }
259 /// ```
260 #[inline]
261 pub const fn pow(mut self, mut exp: $single) -> Self {
262 // r inv_r = 1 (mod n)
263 let mut result = self.ctx.modulo(1).value;
264
265 while exp > 0 {
266 if exp & 1 == 1 {
267 // n < r
268 result = self.ctx.mul(result, self.value)
269 }
270
271 exp >>= 1;
272 // n < r
273 self.value = self.ctx.mul(self.value, self.value)
274 }
275 self.value = result;
276
277 self
278 }
279
280 /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
281 ///
282 /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
283 ///
284 /// - `Ok(x)` : `x` is the modular inverse.
285 /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
286 /// where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
287 ///
288 /// # Time complexity
289 ///
290 /// *O*(log `self`)
291 ///
292 /// # Example
293 ///
294 /// ```
295 /// use lib_modulo::Context;
296 ///
297 /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
298 #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(998_244_353);")]
299 ///
300 /// for n in 1..500_000 {
301 /// let n = ctx.modulo(n);
302 /// assert!(n.try_inv().is_ok_and(|i| (i * n).get() == 1));
303 /// }
304 /// // 0 n = 0 != 1 for any integer n
305 /// assert!(ctx.modulo(0).try_inv().is_err());
306 /// ```
307 #[inline]
308 pub const fn try_inv(self) -> Result<Self, $single> {
309 let mut a = self.get();
310 let Self { ctx, .. } = self;
311
312 // performs extended binary gcd
313 //
314 // invariants: a = [a] x, b = [a] y (mod n) where [a] is initial value
315 let mut b = ctx.n;
316 let mut x = ctx.modulo(1).value; // 1 r mod n
317 let mut y = 0; // 0 r mod n
318 let frac_1_2 = ctx.modulo(ctx.n.div_ceil(2));
319
320 while a > 0 {
321 x = ctx.mul(x, frac_1_2.pow(a.trailing_zeros() as $single).value);
322 a >>= a.trailing_zeros();
323
324 if a < b {
325 (a, b) = (b, a);
326 (x, y) = (y, x);
327 }
328 a -= b;
329 let (diff, b) = x.overflowing_sub(y);
330 x = if b { diff.wrapping_add(ctx.n) } else { diff };
331 }
332
333 // b = gcd([a], [b])
334 if b == 1 {
335 Ok(Self { value: y, ctx })
336 } else {
337 Err(b)
338 }
339 }
340 }
341
342 impl<'a> Add for Modulo<'a, $single> {
343 type Output = Self;
344
345 #[inline(always)]
346 fn add(mut self, rhs: Self) -> Self {
347 let (sum, b) = self.value.overflowing_add(rhs.value);
348 self.value = if b || sum >= self.ctx.n {
349 sum.wrapping_sub(self.ctx.n)
350 } else {
351 sum
352 };
353
354 self
355 }
356 }
357
358 impl<'a> Sub for Modulo<'a, $single> {
359 type Output = Self;
360
361 #[inline(always)]
362 fn sub(mut self, rhs: Self) -> Self {
363 let (diff, b) = self.value.overflowing_sub(rhs.value);
364 self.value = if b {
365 diff.wrapping_add(self.ctx.n)
366 } else {
367 diff
368 };
369
370 self
371 }
372 }
373
374 impl<'a> Mul for Modulo<'a, $single> {
375 type Output = Self;
376
377 #[inline(always)]
378 fn mul(mut self, rhs: Self) -> Self {
379 // n < r
380 self.value = self.ctx.mul(self.value, rhs.value);
381
382 self
383 }
384 }
385
386 impl<'a> Neg for Modulo<'a, $single> {
387 type Output = Self;
388
389 #[inline(always)]
390 fn neg(mut self) -> Self::Output {
391 // (x - x) r = 0 (mod n)
392 self.value = if self.value == 0 {
393 self.value
394 } else {
395 self.ctx.n - self.value
396 };
397
398 self
399 }
400 }
401 };
402}
403montgomery_impl!(u64, u128);
404montgomery_impl!(u32, u64);
405
406impl<'a, U> AddAssign for Modulo<'a, U>
407where
408 Self: Add<Output = Self> + Copy,
409{
410 #[inline(always)]
411 fn add_assign(&mut self, rhs: Self) {
412 *self = *self + rhs
413 }
414}
415
416impl<'a, U> SubAssign for Modulo<'a, U>
417where
418 Self: Sub<Output = Self> + Copy,
419{
420 #[inline(always)]
421 fn sub_assign(&mut self, rhs: Self) {
422 *self = *self - rhs
423 }
424}
425
426impl<'a, U> MulAssign for Modulo<'a, U>
427where
428 Self: Mul<Output = Self> + Copy,
429{
430 #[inline(always)]
431 fn mul_assign(&mut self, rhs: Self) {
432 *self = *self * rhs
433 }
434}