1use crate::{
4 Choice, Gcd, Int, NonZero, NonZeroUint, Odd, OddUint, Uint, Xgcd,
5 modular::{bingcd::xgcd::PatternXgcdOutput, safegcd},
6 primitives::u32_min,
7};
8
9impl<const LIMBS: usize> Uint<LIMBS> {
10 #[must_use]
12 pub const fn gcd(&self, rhs: &Self) -> Self {
13 let self_is_nz = self.is_nonzero();
14 let self_nz = NonZero(Uint::select(&Uint::ONE, self, self_is_nz));
16 Uint::select(rhs, self_nz.gcd_unsigned(rhs).as_ref(), self_is_nz)
17 }
18
19 #[must_use]
23 pub const fn gcd_vartime(&self, rhs: &Self) -> Self {
24 if self.is_zero_vartime() {
25 return *rhs;
26 }
27 NonZero(*self).gcd_unsigned_vartime(rhs).0
28 }
29
30 #[must_use]
34 pub const fn xgcd(&self, rhs: &Self) -> UintXgcdOutput<LIMBS> {
35 let self_is_zero = self.is_nonzero().not();
37 let self_nz = NonZero(Uint::select(self, &Uint::ONE, self_is_zero));
38 let rhs_is_zero = rhs.is_nonzero().not();
39 let rhs_nz = NonZero(Uint::select(rhs, &Uint::ONE, rhs_is_zero));
40
41 let NonZeroUintXgcdOutput {
42 gcd,
43 mut x,
44 mut y,
45 mut lhs_on_gcd,
46 mut rhs_on_gcd,
47 } = self_nz.xgcd(&rhs_nz);
48
49 let mut gcd = *gcd.as_ref();
51 gcd = Uint::select(&gcd, rhs, self_is_zero);
52 gcd = Uint::select(&gcd, self, rhs_is_zero);
53
54 x = Int::select(&x, &Int::ZERO, self_is_zero);
56 y = Int::select(&y, &Int::ONE, self_is_zero);
57 x = Int::select(&x, &Int::ONE, rhs_is_zero);
58 y = Int::select(&y, &Int::ZERO, rhs_is_zero);
59
60 lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ZERO, self_is_zero);
62 rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ONE, self_is_zero);
63 lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ONE, rhs_is_zero);
64 rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ZERO, rhs_is_zero);
65
66 UintXgcdOutput {
67 gcd,
68 x,
69 y,
70 lhs_on_gcd,
71 rhs_on_gcd,
72 }
73 }
74}
75
76impl<const LIMBS: usize> NonZeroUint<LIMBS> {
77 #[must_use]
79 pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
80 let lhs = self.as_ref();
81
82 let i = lhs.trailing_zeros();
93 let j = rhs.trailing_zeros();
94 let k = u32_min(i, j);
95
96 let odd_lhs = Odd(lhs.shr(i));
97 let gcd_div_2k = odd_lhs.gcd_unsigned(rhs);
98 NonZero(gcd_div_2k.as_ref().shl(k))
99 }
100
101 #[must_use]
105 pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
106 let lhs = self.as_ref();
107
108 let i = lhs.trailing_zeros_vartime();
109 let j = rhs.trailing_zeros_vartime();
110 let k = u32_min(i, j);
111
112 let odd_lhs = Odd(lhs.shr_vartime(i));
113 let gcd_div_2k = odd_lhs.gcd_unsigned_vartime(rhs);
114 NonZero(gcd_div_2k.as_ref().shl_vartime(k))
115 }
116
117 #[must_use]
121 pub const fn xgcd(&self, rhs: &Self) -> NonZeroUintXgcdOutput<LIMBS> {
122 let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref());
123
124 let i = lhs.trailing_zeros();
126 let j = rhs.trailing_zeros();
127 let k = u32_min(i, j);
128 lhs = lhs.shr(k);
129 rhs = rhs.shr(k);
130
131 let swap = Choice::from_u32_lt(j, i);
133 Uint::conditional_swap(&mut lhs, &mut rhs, swap);
134 let lhs = lhs.to_odd().expect_copied("odd by construction");
135 let rhs = rhs.to_nz().expect_copied("non-zero by construction");
136
137 let odd_output = OddUintXgcdOutput::from_pattern_output(lhs.binxgcd_nz(&rhs));
138 odd_output.to_nz_output(k, swap)
139 }
140}
141
142impl<const LIMBS: usize> OddUint<LIMBS> {
143 #[inline(always)]
145 #[must_use]
146 pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
147 if LIMBS == 1 {
148 Self::classic_bingcd(self, rhs)
149 } else {
150 Self::safegcd(self, rhs)
151 }
152 }
153
154 #[inline(always)]
158 #[must_use]
159 pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
160 if LIMBS == 1 {
161 Self::classic_bingcd_vartime(self, rhs)
162 } else {
163 Self::safegcd_vartime(self, rhs)
164 }
165 }
166
167 #[doc(hidden)]
173 #[inline(always)]
174 #[must_use]
175 pub const fn bingcd(&self, rhs: &Uint<LIMBS>) -> Self {
176 if LIMBS < 4 {
177 self.classic_bingcd(rhs)
178 } else {
179 self.optimized_bingcd(rhs)
180 }
181 }
182
183 #[doc(hidden)]
191 #[inline(always)]
192 #[must_use]
193 pub const fn bingcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
194 if LIMBS < 4 {
195 self.classic_bingcd_vartime(rhs)
196 } else {
197 self.optimized_bingcd_vartime(rhs)
198 }
199 }
200
201 #[doc(hidden)]
203 #[inline]
204 #[must_use]
205 pub const fn safegcd(&self, rhs: &Uint<LIMBS>) -> Self {
206 safegcd::gcd_odd::<LIMBS, false>(self, rhs)
207 }
208
209 #[doc(hidden)]
213 #[inline]
214 #[must_use]
215 pub const fn safegcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
216 safegcd::gcd_odd::<LIMBS, true>(self, rhs)
217 }
218
219 #[must_use]
223 pub const fn xgcd(&self, rhs: &Self) -> OddUintXgcdOutput<LIMBS> {
224 OddUintXgcdOutput::from_pattern_output(self.binxgcd_odd(rhs))
225 }
226}
227
/// Extended-GCD output whose gcd is a plain [`Uint`] (it may be zero).
pub type UintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, Uint<LIMBS>>;
/// Extended-GCD output whose gcd is a [`NonZeroUint`].
pub type NonZeroUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, NonZeroUint<LIMBS>>;
/// Extended-GCD output whose gcd is an [`OddUint`].
pub type OddUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, OddUint<LIMBS>>;

/// Result of an extended GCD computation on operands `lhs` and `rhs`.
///
/// `GCD` is the type used to carry the greatest common divisor.
#[derive(Debug, Copy, Clone)]
pub struct XgcdOutput<const LIMBS: usize, GCD: Copy> {
    /// Greatest common divisor of `lhs` and `rhs`.
    pub gcd: GCD,
    /// Coefficient of `lhs` in the relation `x * lhs + y * rhs = gcd`.
    pub x: Int<LIMBS>,
    /// Coefficient of `rhs` in the relation `x * lhs + y * rhs = gcd`.
    pub y: Int<LIMBS>,
    /// Quotient `lhs / gcd`.
    pub lhs_on_gcd: Uint<LIMBS>,
    /// Quotient `rhs / gcd`.
    pub rhs_on_gcd: Uint<LIMBS>,
}
246
247impl<const LIMBS: usize, GCD: Copy> XgcdOutput<LIMBS, GCD> {
248 pub const fn gcd(&self) -> GCD {
250 self.gcd
251 }
252
253 pub const fn bezout_coefficients(&self) -> (Int<LIMBS>, Int<LIMBS>) {
255 (self.x, self.y)
256 }
257
258 pub const fn quotients(&self) -> (Uint<LIMBS>, Uint<LIMBS>) {
260 (self.lhs_on_gcd, self.rhs_on_gcd)
261 }
262}
263
264impl<const LIMBS: usize> OddUintXgcdOutput<LIMBS> {
265 pub(crate) const fn from_pattern_output(output: PatternXgcdOutput<LIMBS>) -> Self {
266 let gcd = output.gcd();
267 let (x, y) = output.bezout_coefficients();
268 let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
269
270 OddUintXgcdOutput {
271 gcd,
272 x,
273 y,
274 lhs_on_gcd,
275 rhs_on_gcd,
276 }
277 }
278
279 pub(crate) const fn to_nz_output(self, k: u32, swap: Choice) -> NonZeroUintXgcdOutput<LIMBS> {
280 let Self {
281 ref gcd,
282 mut x,
283 mut y,
284 mut lhs_on_gcd,
285 mut rhs_on_gcd,
286 } = self;
287
288 let gcd = gcd
290 .as_ref()
291 .shl(k)
292 .to_nz()
293 .expect_copied("is non-zero by construction");
294 Int::conditional_swap(&mut x, &mut y, swap);
295 Uint::conditional_swap(&mut lhs_on_gcd, &mut rhs_on_gcd, swap);
296
297 NonZeroUintXgcdOutput {
298 gcd,
299 x,
300 y,
301 lhs_on_gcd,
302 rhs_on_gcd,
303 }
304 }
305}
306
/// Implements `Gcd<$rhs>` for `$slf` by delegating to a `gcd` / `gcd_vartime` call on the
/// right-hand operand, with `self` as its argument.
///
/// All generated impls are generic over `const LIMBS: usize`.
macro_rules! impl_gcd {
    // List form: one impl per `$rhs`, using `$rhs` itself as the output type.
    // This arm must stay first: a bracketed list would otherwise be parsed as a `ty`.
    ($slf:ty, [$($rhs:ty),+]) => {
        $(
            impl_gcd!($slf, $rhs, $rhs);
        )+
    };
    // Single form: explicit right-hand type and output type.
    ($slf:ty, $rhs:ty, $out:ty) => {
        impl<const LIMBS: usize> Gcd<$rhs> for $slf {
            type Output = $out;

            #[inline]
            fn gcd(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd(self)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd_vartime(self)
            }
        }
    };
}
329
/// Implements `Gcd<$rhs>` for `$slf` by calling `gcd_unsigned` / `gcd_unsigned_vartime` on
/// `self` (the left-hand operand), passing a reference to `rhs`.
///
/// All generated impls are generic over `const LIMBS: usize`.
macro_rules! impl_gcd_unsigned_lhs {
    // List form: one impl per `$rhs`, using `$slf` as the output type.
    // This arm must stay first: a bracketed list would otherwise be parsed as a `ty`.
    ($slf:ty, [$($rhs:ty),+]) => {
        $(
            impl_gcd_unsigned_lhs!($slf, $rhs, $slf);
        )+
    };
    // Single form: explicit right-hand type and output type.
    ($slf:ty, $rhs:ty, $out:ty) => {
        impl<const LIMBS: usize> Gcd<$rhs> for $slf {
            type Output = $out;

            #[inline]
            fn gcd(&self, rhs: &$rhs) -> Self::Output {
                self.gcd_unsigned(&rhs)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
                self.gcd_unsigned_vartime(&rhs)
            }
        }
    };
}
352
/// Implements `Gcd<$rhs>` for `$slf` by calling `gcd_unsigned` / `gcd_unsigned_vartime` on
/// `rhs` (the right-hand operand), with `self` as its argument.
///
/// All generated impls are generic over `const LIMBS: usize`.
macro_rules! impl_gcd_unsigned_rhs {
    // List form: one impl per `$rhs`, using `$rhs` itself as the output type.
    // This arm must stay first: a bracketed list would otherwise be parsed as a `ty`.
    ($slf:ty, [$($rhs:ty),+]) => {
        $(
            impl_gcd_unsigned_rhs!($slf, $rhs, $rhs);
        )+
    };
    // Single form: explicit right-hand type and output type.
    ($slf:ty, $rhs:ty, $out:ty) => {
        impl<const LIMBS: usize> Gcd<$rhs> for $slf {
            type Output = $out;

            #[inline]
            fn gcd(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd_unsigned(self)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd_unsigned_vartime(self)
            }
        }
    };
}
375
376pub(crate) use impl_gcd_unsigned_lhs;
377pub(crate) use impl_gcd_unsigned_rhs;
378
// `Uint` on the left: the output type equals the right-hand type, and the call is
// forwarded to `gcd` / `gcd_vartime` on the right-hand operand.
impl_gcd!(
    Uint<LIMBS>,
    [Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]
);
// `NonZeroUint` with a `Uint` on the right: output is `NonZeroUint`, computed from `self`.
impl_gcd_unsigned_lhs!(NonZeroUint<LIMBS>, [Uint<LIMBS>]);
// `NonZeroUint` with a `NonZeroUint` or `OddUint` on the right: output is the right-hand
// type, computed from `rhs`.
impl_gcd_unsigned_rhs!(
    NonZeroUint<LIMBS>,
    [NonZeroUint<LIMBS>, OddUint<LIMBS>]
);
// `OddUint` on the left: output is always `OddUint`, computed from `self`.
impl_gcd_unsigned_lhs!(OddUint<LIMBS>, [Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]);
389
390impl<const LIMBS: usize> Xgcd for Uint<LIMBS> {
391 type Output = UintXgcdOutput<LIMBS>;
392
393 fn xgcd(&self, rhs: &Uint<LIMBS>) -> Self::Output {
394 self.xgcd(rhs)
395 }
396
397 fn xgcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self::Output {
398 self.xgcd(rhs)
400 }
401}
402
403impl<const LIMBS: usize> Xgcd for NonZeroUint<LIMBS> {
404 type Output = NonZeroUintXgcdOutput<LIMBS>;
405
406 fn xgcd(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
407 self.xgcd(rhs)
408 }
409
410 fn xgcd_vartime(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
411 self.xgcd(rhs)
413 }
414}
415
416impl<const LIMBS: usize> Xgcd for OddUint<LIMBS> {
417 type Output = OddUintXgcdOutput<LIMBS>;
418
419 fn xgcd(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
420 self.xgcd(rhs)
421 }
422
423 fn xgcd_vartime(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
424 self.xgcd(rhs)
426 }
427}
428
#[cfg(all(test, not(miri)))]
mod tests {
    mod gcd {
        use crate::{U64, U128, U256, U512, U1024, U2048, U4096, Uint};

        /// Check that `gcd` and `gcd_vartime` both map `(lhs, rhs)` to `target`.
        fn test<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>, target: Uint<LIMBS>) {
            let constant_time = lhs.gcd(&rhs);
            assert_eq!(constant_time, target);
            let variable_time = lhs.gcd_vartime(&rhs);
            assert_eq!(variable_time, target);
        }

        /// Exercise every combination of zero, one and the maximum value.
        fn run_tests<const LIMBS: usize>() {
            let zero = Uint::<LIMBS>::ZERO;
            let one = Uint::<LIMBS>::ONE;
            let max = Uint::<LIMBS>::MAX;

            // (lhs, rhs, expected gcd)
            let cases = [
                (zero, zero, zero),
                (zero, one, one),
                (zero, max, max),
                (one, zero, one),
                (one, one, one),
                (one, max, one),
                (max, zero, max),
                (max, one, one),
                (max, max, max),
            ];
            for &(lhs, rhs, target) in &cases {
                test(lhs, rhs, target);
            }
        }

        #[test]
        fn gcd_sizes() {
            run_tests::<{ U64::LIMBS }>();
            run_tests::<{ U128::LIMBS }>();
            run_tests::<{ U256::LIMBS }>();
            run_tests::<{ U512::LIMBS }>();
            run_tests::<{ U1024::LIMBS }>();
            run_tests::<{ U2048::LIMBS }>();
            run_tests::<{ U4096::LIMBS }>();
        }
    }

    mod xgcd {
        use crate::{Concat, Int, U64, U128, U256, U512, U1024, U2048, U4096, U8192, U16384, Uint};
        use core::ops::Div;

        /// Check the `xgcd` output for `(lhs, rhs)`: the gcd matches `gcd`, the quotients
        /// match plain division when the gcd is non-zero, and `x * lhs + y * rhs` equals
        /// the gcd.
        fn test<const LIMBS: usize, const DOUBLE: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>)
        where
            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
        {
            let output = lhs.xgcd(&rhs);
            assert_eq!(output.gcd, lhs.gcd(&rhs));

            // The quotients are only compared when there is something to divide by.
            if output.gcd > Uint::ZERO {
                let divisor = output.gcd.to_nz().unwrap();
                assert_eq!(output.lhs_on_gcd, lhs.div(divisor));
                assert_eq!(output.rhs_on_gcd, rhs.div(divisor));
            }

            // Evaluate the linear combination at double width.
            let (x, y) = output.bezout_coefficients();
            assert_eq!(
                x.concatenating_mul_unsigned(&lhs) + y.concatenating_mul_unsigned(&rhs),
                *output.gcd.resize().as_int()
            );
        }

        /// Run `test` on every ordered pair drawn from zero, one, `|Int::MIN|` and the
        /// maximum value.
        fn run_tests<const LIMBS: usize, const DOUBLE: usize>()
        where
            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
        {
            let min = Int::<LIMBS>::MIN.abs();
            let operands = [Uint::<LIMBS>::ZERO, Uint::ONE, min, Uint::MAX];

            for &lhs in &operands {
                for &rhs in &operands {
                    test::<LIMBS, DOUBLE>(lhs, rhs);
                }
            }
        }

        #[test]
        fn binxgcd() {
            run_tests::<{ U64::LIMBS }, { U128::LIMBS }>();
            run_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
            run_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
            run_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
            run_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
            run_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
            run_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
            run_tests::<{ U8192::LIMBS }, { U16384::LIMBS }>();
        }

        #[test]
        fn regression_tests() {
            // First pinned 256-bit operand pair.
            let lhs = U256::from_be_hex(
                "000000000000000000000000000000000000001B5DFB3BA1D549DFAF611B8D4C",
            );
            let rhs = U256::from_be_hex(
                "000000000000345EAEDFA8CA03C1F0F5B578A787FE2D23B82A807F178B37FD8E",
            );
            test(lhs, rhs);

            // Second pinned 256-bit operand pair.
            let lhs = U256::from_be_hex(
                "000000000000000000000000000000000000001A0DEEF6F3AC2566149D925044",
            );
            let rhs = U256::from_be_hex(
                "000000000000072B69C9DD0AA15F135675EA9C5180CF8FF0A59298CFC92E87FA",
            );
            test(lhs, rhs);

            // Pinned 512-bit operand pair.
            let lhs = U512::from_be_hex(concat![
                "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364142",
                "4EB38E6AC0E34DE2F34BFAF22DE683E1F4B92847B6871C780488D797042229E1"
            ]);
            let rhs = U512::from_be_hex(concat![
                "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C8283",
                "9D671CD581C69BC5E697F5E45BCD07C52EC373A8BDC598B4493F50A1380E1281"
            ]);
            test(lhs, rhs);
        }
    }

    mod traits {
        use crate::{Gcd, I256, U256};

        #[test]
        fn gcd_relatively_prime() {
            // Four distinct primes, two per operand.
            let lhs = U256::from(59u32 * 67);
            let rhs = U256::from(61u32 * 71);
            assert_eq!(lhs.gcd(&rhs), U256::ONE);
        }

        #[test]
        fn gcd_nonprime() {
            let lhs = U256::from(4391633u32);
            let rhs = U256::from(2022161u32);
            assert_eq!(lhs.gcd(&rhs), U256::from(1763u32));
        }

        #[test]
        fn gcd_zero() {
            // (lhs, rhs, expected gcd)
            let cases = [
                (U256::ZERO, U256::ZERO, U256::ZERO),
                (U256::ZERO, U256::ONE, U256::ONE),
                (U256::ONE, U256::ZERO, U256::ONE),
            ];
            for &(lhs, rhs, expected) in &cases {
                assert_eq!(lhs.gcd(&rhs), expected);
            }
        }

        #[test]
        fn gcd_one() {
            let one = U256::ONE;
            for rhs in &[U256::ONE, U256::from(2u8)] {
                assert_eq!(U256::ONE, one.gcd(rhs));
            }
        }

        #[test]
        fn gcd_two() {
            let two = U256::from_u8(2);
            assert_eq!(two, two.gcd(&two));

            // The order of the operands does not matter.
            let four = U256::from_u8(4);
            assert_eq!(two, two.gcd(&four));
            assert_eq!(two, four.gcd(&two));
        }

        #[test]
        fn gcd_unsigned_int() {
            let unsigned = U256::from(61u32 * 71);
            let signed = I256::from(59i32 * 61);
            let sixty_one = U256::from(61u32);

            // The expected gcd is the same for the signed operand and its negation.
            for rhs in &[signed, signed.wrapping_neg()] {
                assert_eq!(sixty_one, <U256 as Gcd<I256>>::gcd(&unsigned, rhs));
            }
        }

        #[test]
        fn xgcd_expected() {
            let lhs = U256::from(61u32 * 71);
            let rhs = U256::from(59u32 * 61);

            // 5 * 71 - 6 * 59 = 1, so 5 * lhs - 6 * rhs = 61.
            let output = lhs.xgcd(&rhs);
            assert_eq!(U256::from(61u32), output.gcd);
            assert_eq!(I256::from(5i32), output.x);
            assert_eq!(I256::from(-6i32), output.y);
        }
    }
}