1use crate::{ModularAbs, ModularCoreOps, ModularPow, ModularSymbols, ModularUnaryOps};
2use core::convert::TryInto;
3use num_integer::Integer;
4use num_traits::{One, ToPrimitive, Zero};
5
// Generates the by-value and mixed-reference flavours of the modular
// arithmetic traits for the type `$T`.
//
// Every generated method borrows whatever operand it received by value and
// forwards to the fully by-reference implementation (`&$T` receiver, `&$T`
// operands), which must already exist for `$T`. The modulus is always taken
// by reference and passed through unchanged.
macro_rules! impl_mod_ops_by_ref {
    ($T:ty) => {
        // ModularCoreOps: borrowed receiver, owned right-hand side.
        impl ModularCoreOps<$T, &$T> for &$T {
            type Output = $T;
            #[inline]
            fn addm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::addm(self, &rhs, m)
            }
            #[inline]
            fn subm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::subm(self, &rhs, m)
            }
            #[inline]
            fn mulm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::mulm(self, &rhs, m)
            }
        }

        // ModularCoreOps: owned receiver, borrowed right-hand side.
        impl ModularCoreOps<&$T, &$T> for $T {
            type Output = $T;
            #[inline]
            fn addm(self, rhs: &$T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::addm(&self, rhs, m)
            }
            #[inline]
            fn subm(self, rhs: &$T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::subm(&self, rhs, m)
            }
            #[inline]
            fn mulm(self, rhs: &$T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::mulm(&self, rhs, m)
            }
        }

        // ModularCoreOps: owned receiver, owned right-hand side.
        impl ModularCoreOps<$T, &$T> for $T {
            type Output = $T;
            #[inline]
            fn addm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::addm(&self, &rhs, m)
            }
            #[inline]
            fn subm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::subm(&self, &rhs, m)
            }
            #[inline]
            fn mulm(self, rhs: $T, m: &$T) -> $T {
                <&$T as ModularCoreOps<&$T, &$T>>::mulm(&self, &rhs, m)
            }
        }

        // ModularPow: the same three ownership combinations for the exponent.
        impl ModularPow<$T, &$T> for &$T {
            type Output = $T;
            #[inline]
            fn powm(self, exp: $T, m: &$T) -> $T {
                <&$T as ModularPow<&$T, &$T>>::powm(self, &exp, m)
            }
        }
        impl ModularPow<&$T, &$T> for $T {
            type Output = $T;
            #[inline]
            fn powm(self, exp: &$T, m: &$T) -> $T {
                <&$T as ModularPow<&$T, &$T>>::powm(&self, exp, m)
            }
        }
        impl ModularPow<$T, &$T> for $T {
            type Output = $T;
            #[inline]
            fn powm(self, exp: $T, m: &$T) -> $T {
                <&$T as ModularPow<&$T, &$T>>::powm(&self, &exp, m)
            }
        }

        // ModularUnaryOps: owned receiver forwards to the borrowed one.
        impl ModularUnaryOps<&$T> for $T {
            type Output = $T;
            #[inline]
            fn negm(self, m: &$T) -> $T {
                <&$T as ModularUnaryOps<&$T>>::negm(&self, m)
            }
            #[inline]
            fn invm(self, m: &$T) -> Option<$T> {
                <&$T as ModularUnaryOps<&$T>>::invm(&self, m)
            }
            #[inline]
            fn dblm(self, m: &$T) -> $T {
                <&$T as ModularUnaryOps<&$T>>::dblm(&self, m)
            }
            #[inline]
            fn sqm(self, m: &$T) -> $T {
                <&$T as ModularUnaryOps<&$T>>::sqm(&self, m)
            }
        }
    };
}
104
105#[cfg(feature = "num-bigint")]
106mod _num_bigint {
107 use super::*;
108 use num_bigint::{BigInt, BigUint};
109 use num_traits::Signed;
110
111 impl ModularCoreOps<&BigUint, &BigUint> for &BigUint {
112 type Output = BigUint;
113
114 #[inline]
115 fn addm(self, rhs: &BigUint, m: &BigUint) -> BigUint {
116 (self + rhs) % m
117 }
118 fn subm(self, rhs: &BigUint, m: &BigUint) -> BigUint {
119 let (lhs, rhs) = (self % m, rhs % m);
120 if lhs >= rhs {
121 lhs - rhs
122 } else {
123 m - (rhs - lhs)
124 }
125 }
126
127 fn mulm(self, rhs: &BigUint, m: &BigUint) -> BigUint {
128 let a = self % m;
129 let b = rhs % m;
130
131 if let Some(sm) = m.to_usize() {
132 let a_usize = a.to_usize().unwrap();
133 let srhs = b.to_usize().unwrap();
134 return BigUint::from(a_usize.mulm(srhs, &sm));
135 }
136
137 (a * b) % m
138 }
139 }
140
141 impl ModularUnaryOps<&BigUint> for &BigUint {
142 type Output = BigUint;
143 #[inline]
144 fn negm(self, m: &BigUint) -> BigUint {
145 let x = self % m;
146 if x.is_zero() {
147 BigUint::zero()
148 } else {
149 m - x
150 }
151 }
152
153 fn invm(self, m: &BigUint) -> Option<Self::Output> {
154 let x = if self >= m { self % m } else { self.clone() };
155
156 let (mut last_r, mut r) = (m.clone(), x);
157 let (mut last_t, mut t) = (BigUint::zero(), BigUint::one());
158
159 while r > BigUint::zero() {
160 let (quo, rem) = last_r.div_rem(&r);
161 last_r = r;
162 r = rem;
163
164 let new_t = last_t.subm(&quo.mulm(&t, m), m);
165 last_t = t;
166 t = new_t;
167 }
168
169 if last_r > BigUint::one() {
171 None
172 } else {
173 Some(last_t)
174 }
175 }
176
177 #[inline]
178 fn dblm(self, m: &BigUint) -> BigUint {
179 let x = self % m;
180 let d = x << 1;
181 if &d >= m {
182 d - m
183 } else {
184 d
185 }
186 }
187
188 #[inline]
189 fn sqm(self, m: &BigUint) -> BigUint {
190 (self * self) % m
191 }
192 }
193
194 impl ModularPow<&BigUint, &BigUint> for &BigUint {
195 type Output = BigUint;
196 #[inline]
197 fn powm(self, exp: &BigUint, m: &BigUint) -> BigUint {
198 self.modpow(exp, m)
199 }
200 }
201
202 impl ModularSymbols<&BigUint> for BigUint {
203 #[inline]
204 fn checked_legendre(&self, n: &BigUint) -> Option<i8> {
205 let r = self.powm((n - 1u8) >> 1u8, n);
206 if r.is_zero() {
207 Some(0)
208 } else if r.is_one() {
209 Some(1)
210 } else if &(r + 1u8) == n {
211 Some(-1)
212 } else {
213 None
214 }
215 }
216
217 fn checked_jacobi(&self, n: &BigUint) -> Option<i8> {
218 if n.is_even() {
219 return None;
220 }
221 if self.is_zero() {
222 return Some(if n.is_one() { 1 } else { 0 });
223 }
224 if self.is_one() {
225 return Some(1);
226 }
227
228 let three = BigUint::from(3u8);
229 let five = BigUint::from(5u8);
230 let seven = BigUint::from(7u8);
231
232 let mut a = self % n;
233 let mut n = n.clone();
234 let mut t = 1;
235 while a > BigUint::zero() {
236 while a.is_even() {
237 a >>= 1;
238 if &n & &seven == three || &n & &seven == five {
239 t *= -1;
240 }
241 }
242 core::mem::swap(&mut a, &mut n);
243 if (&a & &three) == three && (&n & &three) == three {
244 t *= -1;
245 }
246 a %= &n;
247 }
248 Some(if n.is_one() { t } else { 0 })
249 }
250
251 #[inline]
252 fn kronecker(&self, n: &BigUint) -> i8 {
253 if n.is_zero() {
254 return if self.is_one() { 1 } else { 0 };
255 }
256 if n.is_one() {
257 return 1;
258 }
259 if n == &BigUint::from(2u8) {
260 return if self.is_even() {
261 0
262 } else {
263 let seven = BigUint::from(7u8);
264 if (self & &seven).is_one() || self & &seven == seven {
265 1
266 } else {
267 -1
268 }
269 };
270 }
271
272 let f = n.trailing_zeros().unwrap_or(0);
273 let n = n >> f;
274 let t1 = self.kronecker(&BigUint::from(2u8));
275 let t2 = self.jacobi(&n);
276 t1.pow(f.try_into().unwrap()) * t2
277 }
278 }
279
280 impl ModularSymbols<&BigInt> for BigInt {
281 #[inline]
282 fn checked_legendre(&self, n: &BigInt) -> Option<i8> {
283 if n < &BigInt::one() {
284 return None;
285 }
286 self.mod_floor(n)
287 .magnitude()
288 .checked_legendre(n.magnitude())
289 }
290
291 fn checked_jacobi(&self, n: &BigInt) -> Option<i8> {
292 if n < &BigInt::one() {
293 return None;
294 }
295 self.mod_floor(n).magnitude().checked_jacobi(n.magnitude())
296 }
297
298 #[inline]
299 fn kronecker(&self, n: &BigInt) -> i8 {
300 if n.is_negative() {
301 return if n.magnitude().is_one() {
302 if self.is_negative() {
303 -1
304 } else {
305 1
306 }
307 } else {
308 self.kronecker(&-BigInt::one()) * self.kronecker(&-n)
309 };
310 }
311
312 let n = n.magnitude();
314 if n.is_zero() {
315 return if self.is_one() { 1 } else { 0 };
316 }
317 if n.is_one() {
318 return 1;
319 }
320 if n == &BigUint::from(2u8) {
321 return if self.is_even() {
322 0
323 } else {
324 let eight = BigInt::from(8u8);
325 if (self.mod_floor(&eight)).is_one()
326 || self.mod_floor(&eight) == BigInt::from(7u8)
327 {
328 1
329 } else {
330 -1
331 }
332 };
333 }
334
335 let f = n.trailing_zeros().unwrap_or(0);
336 let n = n >> f;
337 let t1 = self.kronecker(&BigInt::from(2u8));
338 let t2 = self.jacobi(&n.into());
339 t1.pow(f.try_into().unwrap()) * t2
340 }
341 }
342
// Generate the by-value and mixed-reference ModularCoreOps / ModularPow /
// ModularUnaryOps impls for `BigUint`, all forwarding to the `&BigUint`
// implementations defined above in this module.
impl_mod_ops_by_ref!(BigUint);
344
345 impl ModularAbs<BigUint> for BigInt {
346 fn absm(self, m: &BigUint) -> BigUint {
347 if self.is_negative() {
348 self.magnitude().negm(m)
349 } else {
350 self.magnitude() % m
351 }
352 }
353 }
354
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    /// Number of random trials each randomized test performs.
    const NRANDOM: u32 = 10;

    #[test]
    fn basic_tests() {
        for _ in 0..NRANDOM {
            // Core ops on full-width operands, checked against plain BigUint arithmetic.
            let (a, b) = (random::<u128>(), random::<u128>());
            let m = random::<u128>() | 1;
            let (ra, rb, rm) = (&BigUint::from(a), &BigUint::from(b), &BigUint::from(m));
            assert_eq!(ra.addm(rb, rm), (ra + rb) % rm);
            assert_eq!(ra.mulm(rb, rm), (ra * rb) % rm);

            // Small base and exponent so the unreduced power stays cheap.
            let base = random::<u8>();
            let exp = random::<u8>();
            let modulus = random::<u128>() | 1;
            let rbase = &BigUint::from(base);
            let rexp = &BigUint::from(exp);
            let rmod = &BigUint::from(modulus);
            assert_eq!(rbase.powm(rexp, rmod), rbase.pow(u32::from(exp)) % rmod);
        }
    }

    #[test]
    fn test_against_prim() {
        for _ in 0..NRANDOM {
            // Unsigned: every operation must agree with the u128 implementation.
            let a = random::<u128>();
            let b = random::<u128>();
            let m = random::<u128>();
            let (ra, rb, rm) = (&BigUint::from(a), &BigUint::from(b), &BigUint::from(m));
            assert_eq!(ra.addm(rb, rm), a.addm(b, &m).into());
            assert_eq!(ra.subm(rb, rm), a.subm(b, &m).into());
            assert_eq!(ra.mulm(rb, rm), a.mulm(b, &m).into());
            assert_eq!(ra.negm(rm), a.negm(&m).into());
            assert_eq!(ra.invm(rm), a.invm(&m).map(|v| v.into()));
            assert_eq!(ra.checked_legendre(rm), a.checked_legendre(&m));
            assert_eq!(ra.checked_jacobi(rm), a.checked_jacobi(&m));
            assert_eq!(ra.kronecker(rm), a.kronecker(&m));

            let e = random::<u8>();
            let re = &BigUint::from(e);
            assert_eq!(ra.powm(re, rm), a.powm(u128::from(e), &m).into());

            // Signed: symbols must agree with the i128 implementation.
            let sa = random::<i128>();
            let sm = random::<i128>();
            let (rsa, rsm) = (&BigInt::from(sa), &BigInt::from(sm));
            assert_eq!(rsa.checked_legendre(rsm), sa.checked_legendre(&sm));
            assert_eq!(rsa.checked_jacobi(rsm), sa.checked_jacobi(&sm));
            assert_eq!(rsa.kronecker(rsm), sa.kronecker(&sm));
        }
    }

    #[test]
    fn dblm_edge_case() {
        // Doubling exactly half of the modulus must wrap to zero, not yield `m`.
        for &(half, modulus) in &[(5u32, 10u32), (50, 100)] {
            let m = &BigUint::from(modulus);
            let x = &BigUint::from(half);
            assert_eq!(x.dblm(m), BigUint::from(0u8), "dblm(m/2, m) should be 0");
        }

        // Ordinary values below and above m/2.
        let m = &BigUint::from(10u8);
        for &(x, expected) in &[(3u8, 6u8), (7, 4), (9, 8)] {
            assert_eq!(BigUint::from(x).dblm(m), BigUint::from(expected));
        }
    }

    #[test]
    fn sqm_matches_mulm() {
        for _ in 0..NRANDOM {
            let x = &BigUint::from(random::<u128>());
            let m = &BigUint::from(random::<u128>() | 1);
            assert_eq!(x.sqm(m), x.mulm(x, m), "sqm should match mulm");
        }
    }
}
446}