use crate::{
Choice, Gcd, Int, NonZero, NonZeroUint, Odd, OddUint, Uint, Xgcd,
modular::{bingcd::xgcd::PatternXgcdOutput, safegcd},
primitives::u32_min,
};
impl<const LIMBS: usize> Uint<LIMBS> {
#[must_use]
pub const fn gcd(&self, rhs: &Self) -> Self {
let self_is_nz = self.is_nonzero();
let self_nz = NonZero(Uint::select(&Uint::ONE, self, self_is_nz));
Uint::select(rhs, self_nz.gcd_unsigned(rhs).as_ref(), self_is_nz)
}
#[must_use]
pub const fn gcd_vartime(&self, rhs: &Self) -> Self {
if self.is_zero_vartime() {
return *rhs;
}
NonZero(*self).gcd_unsigned_vartime(rhs).0
}
#[must_use]
pub const fn xgcd(&self, rhs: &Self) -> UintXgcdOutput<LIMBS> {
let self_is_zero = self.is_nonzero().not();
let self_nz = NonZero(Uint::select(self, &Uint::ONE, self_is_zero));
let rhs_is_zero = rhs.is_nonzero().not();
let rhs_nz = NonZero(Uint::select(rhs, &Uint::ONE, rhs_is_zero));
let NonZeroUintXgcdOutput {
gcd,
mut x,
mut y,
mut lhs_on_gcd,
mut rhs_on_gcd,
} = self_nz.xgcd(&rhs_nz);
let mut gcd = *gcd.as_ref();
gcd = Uint::select(&gcd, rhs, self_is_zero);
gcd = Uint::select(&gcd, self, rhs_is_zero);
x = Int::select(&x, &Int::ZERO, self_is_zero);
y = Int::select(&y, &Int::ONE, self_is_zero);
x = Int::select(&x, &Int::ONE, rhs_is_zero);
y = Int::select(&y, &Int::ZERO, rhs_is_zero);
lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ZERO, self_is_zero);
rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ONE, self_is_zero);
lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ONE, rhs_is_zero);
rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ZERO, rhs_is_zero);
UintXgcdOutput {
gcd,
x,
y,
lhs_on_gcd,
rhs_on_gcd,
}
}
}
impl<const LIMBS: usize> NonZeroUint<LIMBS> {
#[must_use]
pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
let lhs = self.as_ref();
let i = lhs.trailing_zeros();
let j = rhs.trailing_zeros();
let k = u32_min(i, j);
let odd_lhs = Odd(lhs.shr(i));
let gcd_div_2k = odd_lhs.gcd_unsigned(rhs);
NonZero(gcd_div_2k.as_ref().shl(k))
}
#[must_use]
pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
let lhs = self.as_ref();
let i = lhs.trailing_zeros_vartime();
let j = rhs.trailing_zeros_vartime();
let k = u32_min(i, j);
let odd_lhs = Odd(lhs.shr_vartime(i));
let gcd_div_2k = odd_lhs.gcd_unsigned_vartime(rhs);
NonZero(gcd_div_2k.as_ref().shl_vartime(k))
}
#[must_use]
pub const fn xgcd(&self, rhs: &Self) -> NonZeroUintXgcdOutput<LIMBS> {
let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref());
let i = lhs.trailing_zeros();
let j = rhs.trailing_zeros();
let k = u32_min(i, j);
lhs = lhs.shr(k);
rhs = rhs.shr(k);
let swap = Choice::from_u32_lt(j, i);
Uint::conditional_swap(&mut lhs, &mut rhs, swap);
let lhs = lhs.to_odd().expect_copied("odd by construction");
let rhs = rhs.to_nz().expect_copied("non-zero by construction");
let odd_output = OddUintXgcdOutput::from_pattern_output(lhs.binxgcd_nz(&rhs));
odd_output.to_nz_output(k, swap)
}
}
impl<const LIMBS: usize> OddUint<LIMBS> {
#[inline(always)]
#[must_use]
pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
if LIMBS == 1 {
Self::classic_bingcd(self, rhs)
} else {
Self::safegcd(self, rhs)
}
}
#[inline(always)]
#[must_use]
pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
if LIMBS == 1 {
Self::classic_bingcd_vartime(self, rhs)
} else {
Self::safegcd_vartime(self, rhs)
}
}
#[doc(hidden)]
#[inline(always)]
#[must_use]
pub const fn bingcd(&self, rhs: &Uint<LIMBS>) -> Self {
if LIMBS < 4 {
self.classic_bingcd(rhs)
} else {
self.optimized_bingcd(rhs)
}
}
#[doc(hidden)]
#[inline(always)]
#[must_use]
pub const fn bingcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
if LIMBS < 4 {
self.classic_bingcd_vartime(rhs)
} else {
self.optimized_bingcd_vartime(rhs)
}
}
#[doc(hidden)]
#[inline]
#[must_use]
pub const fn safegcd(&self, rhs: &Uint<LIMBS>) -> Self {
safegcd::gcd_odd::<LIMBS, false>(self, rhs)
}
#[doc(hidden)]
#[inline]
#[must_use]
pub const fn safegcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
safegcd::gcd_odd::<LIMBS, true>(self, rhs)
}
#[must_use]
pub const fn xgcd(&self, rhs: &Self) -> OddUintXgcdOutput<LIMBS> {
OddUintXgcdOutput::from_pattern_output(self.binxgcd_odd(rhs))
}
}
pub type UintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, Uint<LIMBS>>;
pub type NonZeroUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, NonZeroUint<LIMBS>>;
pub type OddUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, OddUint<LIMBS>>;
#[derive(Debug, Copy, Clone)]
pub struct XgcdOutput<const LIMBS: usize, GCD: Copy> {
pub gcd: GCD,
pub x: Int<LIMBS>,
pub y: Int<LIMBS>,
pub lhs_on_gcd: Uint<LIMBS>,
pub rhs_on_gcd: Uint<LIMBS>,
}
impl<const LIMBS: usize, GCD: Copy> XgcdOutput<LIMBS, GCD> {
pub const fn gcd(&self) -> GCD {
self.gcd
}
pub const fn bezout_coefficients(&self) -> (Int<LIMBS>, Int<LIMBS>) {
(self.x, self.y)
}
pub const fn quotients(&self) -> (Uint<LIMBS>, Uint<LIMBS>) {
(self.lhs_on_gcd, self.rhs_on_gcd)
}
}
impl<const LIMBS: usize> OddUintXgcdOutput<LIMBS> {
pub(crate) const fn from_pattern_output(output: PatternXgcdOutput<LIMBS>) -> Self {
let gcd = output.gcd();
let (x, y) = output.bezout_coefficients();
let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
OddUintXgcdOutput {
gcd,
x,
y,
lhs_on_gcd,
rhs_on_gcd,
}
}
pub(crate) const fn to_nz_output(self, k: u32, swap: Choice) -> NonZeroUintXgcdOutput<LIMBS> {
let Self {
ref gcd,
mut x,
mut y,
mut lhs_on_gcd,
mut rhs_on_gcd,
} = self;
let gcd = gcd
.as_ref()
.shl(k)
.to_nz()
.expect_copied("is non-zero by construction");
Int::conditional_swap(&mut x, &mut y, swap);
Uint::conditional_swap(&mut lhs_on_gcd, &mut rhs_on_gcd, swap);
NonZeroUintXgcdOutput {
gcd,
x,
y,
lhs_on_gcd,
rhs_on_gcd,
}
}
}
macro_rules! impl_gcd {
($slf:ty, [$($rhs:ty),+]) => {
$(
impl_gcd!($slf, $rhs, $rhs);
)+
};
($slf:ty, $rhs:ty, $out:ty) => {
impl<const LIMBS: usize> Gcd<$rhs> for $slf {
type Output = $out;
#[inline]
fn gcd(&self, rhs: &$rhs) -> Self::Output {
rhs.gcd(self)
}
#[inline]
fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
rhs.gcd_vartime(self)
}
}
};
}
macro_rules! impl_gcd_unsigned_lhs {
($slf:ty, [$($rhs:ty),+]) => {
$(
impl_gcd_unsigned_lhs!($slf, $rhs, $slf);
)+
};
($slf:ty, $rhs:ty, $out:ty) => {
impl<const LIMBS: usize> Gcd<$rhs> for $slf {
type Output = $out;
#[inline]
fn gcd(&self, rhs: &$rhs) -> Self::Output {
self.gcd_unsigned(&rhs)
}
#[inline]
fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
self.gcd_unsigned_vartime(&rhs)
}
}
};
}
macro_rules! impl_gcd_unsigned_rhs {
($slf:ty, [$($rhs:ty),+]) => {
$(
impl_gcd_unsigned_rhs!($slf, $rhs, $rhs);
)+
};
($slf:ty, $rhs:ty, $out:ty) => {
impl<const LIMBS: usize> Gcd<$rhs> for $slf {
type Output = $out;
#[inline]
fn gcd(&self, rhs: &$rhs) -> Self::Output {
rhs.gcd_unsigned(self)
}
#[inline]
fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
rhs.gcd_unsigned_vartime(self)
}
}
};
}
pub(crate) use impl_gcd_unsigned_lhs;
pub(crate) use impl_gcd_unsigned_rhs;
impl_gcd!(
Uint<LIMBS>,
[Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]
);
impl_gcd_unsigned_lhs!(NonZeroUint<LIMBS>, [Uint<LIMBS>]);
impl_gcd_unsigned_rhs!(
NonZeroUint<LIMBS>,
[NonZeroUint<LIMBS>, OddUint<LIMBS>]
);
impl_gcd_unsigned_lhs!(OddUint<LIMBS>, [Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]);
impl<const LIMBS: usize> Xgcd for Uint<LIMBS> {
type Output = UintXgcdOutput<LIMBS>;
fn xgcd(&self, rhs: &Uint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
fn xgcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
}
impl<const LIMBS: usize> Xgcd for NonZeroUint<LIMBS> {
type Output = NonZeroUintXgcdOutput<LIMBS>;
fn xgcd(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
fn xgcd_vartime(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
}
impl<const LIMBS: usize> Xgcd for OddUint<LIMBS> {
type Output = OddUintXgcdOutput<LIMBS>;
fn xgcd(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
fn xgcd_vartime(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
self.xgcd(rhs)
}
}
#[cfg(all(test, not(miri)))]
mod tests {
mod gcd {
use crate::{U64, U128, U256, U512, U1024, U2048, U4096, Uint};
fn test<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>, target: Uint<LIMBS>) {
assert_eq!(lhs.gcd(&rhs), target);
assert_eq!(lhs.gcd_vartime(&rhs), target);
}
fn run_tests<const LIMBS: usize>() {
test(Uint::<LIMBS>::ZERO, Uint::ZERO, Uint::ZERO);
test(Uint::<LIMBS>::ZERO, Uint::ONE, Uint::ONE);
test(Uint::<LIMBS>::ZERO, Uint::MAX, Uint::MAX);
test(Uint::<LIMBS>::ONE, Uint::ZERO, Uint::ONE);
test(Uint::<LIMBS>::ONE, Uint::ONE, Uint::ONE);
test(Uint::<LIMBS>::ONE, Uint::MAX, Uint::ONE);
test(Uint::<LIMBS>::MAX, Uint::ZERO, Uint::MAX);
test(Uint::<LIMBS>::MAX, Uint::ONE, Uint::ONE);
test(Uint::<LIMBS>::MAX, Uint::MAX, Uint::MAX);
}
#[test]
fn gcd_sizes() {
run_tests::<{ U64::LIMBS }>();
run_tests::<{ U128::LIMBS }>();
run_tests::<{ U256::LIMBS }>();
run_tests::<{ U512::LIMBS }>();
run_tests::<{ U1024::LIMBS }>();
run_tests::<{ U2048::LIMBS }>();
run_tests::<{ U4096::LIMBS }>();
}
}
mod xgcd {
use crate::{Concat, Int, U64, U128, U256, U512, U1024, U2048, U4096, U8192, U16384, Uint};
use core::ops::Div;
fn test<const LIMBS: usize, const DOUBLE: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>)
where
Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
{
let output = lhs.xgcd(&rhs);
assert_eq!(output.gcd, lhs.gcd(&rhs));
if output.gcd > Uint::ZERO {
assert_eq!(output.lhs_on_gcd, lhs.div(output.gcd.to_nz().unwrap()));
assert_eq!(output.rhs_on_gcd, rhs.div(output.gcd.to_nz().unwrap()));
}
let (x, y) = output.bezout_coefficients();
assert_eq!(
x.concatenating_mul_unsigned(&lhs) + y.concatenating_mul_unsigned(&rhs),
*output.gcd.resize().as_int()
);
}
fn run_tests<const LIMBS: usize, const DOUBLE: usize>()
where
Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
{
let min = Int::MIN.abs();
test(Uint::ZERO, Uint::ZERO);
test(Uint::ZERO, Uint::ONE);
test(Uint::ZERO, min);
test(Uint::ZERO, Uint::MAX);
test(Uint::ONE, Uint::ZERO);
test(Uint::ONE, Uint::ONE);
test(Uint::ONE, min);
test(Uint::ONE, Uint::MAX);
test(min, Uint::ZERO);
test(min, Uint::ONE);
test(min, Int::MIN.abs());
test(min, Uint::MAX);
test(Uint::MAX, Uint::ZERO);
test(Uint::MAX, Uint::ONE);
test(Uint::MAX, min);
test(Uint::MAX, Uint::MAX);
}
#[test]
fn binxgcd() {
run_tests::<{ U64::LIMBS }, { U128::LIMBS }>();
run_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
run_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
run_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
run_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
run_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
run_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
run_tests::<{ U8192::LIMBS }, { U16384::LIMBS }>();
}
#[test]
fn regression_tests() {
let a = U256::from_be_hex(
"000000000000000000000000000000000000001B5DFB3BA1D549DFAF611B8D4C",
);
let b = U256::from_be_hex(
"000000000000345EAEDFA8CA03C1F0F5B578A787FE2D23B82A807F178B37FD8E",
);
test(a, b);
let a = U256::from_be_hex(
"000000000000000000000000000000000000001A0DEEF6F3AC2566149D925044",
);
let b = U256::from_be_hex(
"000000000000072B69C9DD0AA15F135675EA9C5180CF8FF0A59298CFC92E87FA",
);
test(a, b);
let a = U512::from_be_hex(concat![
"7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364142",
"4EB38E6AC0E34DE2F34BFAF22DE683E1F4B92847B6871C780488D797042229E1"
]);
let b = U512::from_be_hex(concat![
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C8283",
"9D671CD581C69BC5E697F5E45BCD07C52EC373A8BDC598B4493F50A1380E1281"
]);
test(a, b);
}
}
mod traits {
use crate::{Gcd, I256, U256};
#[test]
fn gcd_relatively_prime() {
let f = U256::from(59u32 * 67);
let g = U256::from(61u32 * 71);
let gcd = f.gcd(&g);
assert_eq!(gcd, U256::ONE);
}
#[test]
fn gcd_nonprime() {
let f = U256::from(4391633u32);
let g = U256::from(2022161u32);
let gcd = f.gcd(&g);
assert_eq!(gcd, U256::from(1763u32));
}
#[test]
fn gcd_zero() {
assert_eq!(U256::ZERO.gcd(&U256::ZERO), U256::ZERO);
assert_eq!(U256::ZERO.gcd(&U256::ONE), U256::ONE);
assert_eq!(U256::ONE.gcd(&U256::ZERO), U256::ONE);
}
#[test]
fn gcd_one() {
let f = U256::ONE;
assert_eq!(U256::ONE, f.gcd(&U256::ONE));
assert_eq!(U256::ONE, f.gcd(&U256::from(2u8)));
}
#[test]
fn gcd_two() {
let f = U256::from_u8(2);
assert_eq!(f, f.gcd(&f));
let g = U256::from_u8(4);
assert_eq!(f, f.gcd(&g));
assert_eq!(f, g.gcd(&f));
}
#[test]
fn gcd_unsigned_int() {
let f = U256::from(61u32 * 71);
let g = I256::from(59i32 * 61);
let sixty_one = U256::from(61u32);
assert_eq!(sixty_one, <U256 as Gcd<I256>>::gcd(&f, &g));
assert_eq!(sixty_one, <U256 as Gcd<I256>>::gcd(&f, &g.wrapping_neg()));
}
#[test]
fn xgcd_expected() {
let f = U256::from(61u32 * 71);
let g = U256::from(59u32 * 61);
let actual = f.xgcd(&g);
assert_eq!(U256::from(61u32), actual.gcd);
assert_eq!(I256::from(5i32), actual.x);
assert_eq!(I256::from(-6i32), actual.y);
}
}
}