use core::cmp::Ordering;
/// Arbitrary-precision unsigned integer.
///
/// `limbs` stores base-2^64 digits in little-endian order: `limbs[0]` is the
/// least-significant word. Operations in this module trim their results so the
/// most-significant limb is non-zero, with zero represented as a single `0` limb.
#[derive(Debug, Clone)]
pub struct BigInt {
    /// Little-endian 64-bit digits.
    pub limbs: Vec<u64>,
}
impl BigInt {
    /// Returns the value `0` (a single zero limb).
    pub fn zero() -> Self {
        Self { limbs: vec![0] }
    }

    /// Builds a single-limb value equal to `v`.
    pub fn from_u64(v: u64) -> Self {
        Self { limbs: vec![v] }
    }

    /// Parses an unsigned big-endian byte string.
    ///
    /// Leading zero bytes are accepted and an empty slice yields zero. The
    /// result is trimmed.
    pub fn from_be_bytes(bytes: &[u8]) -> Self {
        if bytes.is_empty() {
            return Self::zero();
        }
        // `rchunks` walks from the least-significant end, so limbs come out
        // little-endian; only the final (most-significant) chunk can be short,
        // and it is left-padded with zero bytes.
        let limbs = bytes
            .rchunks(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word[8 - chunk.len()..].copy_from_slice(chunk);
                u64::from_be_bytes(word)
            })
            .collect();
        let mut r = Self { limbs };
        r.trim();
        r
    }

    /// Serialises to big-endian bytes without leading zeros, left-padded with
    /// zeros to at least `min_len` bytes. Zero with `min_len == 0` yields an
    /// empty vector.
    pub fn to_be_bytes(&self, min_len: usize) -> Vec<u8> {
        let mut raw = Vec::with_capacity(self.limbs.len() * 8);
        for limb in self.limbs.iter().rev() {
            raw.extend_from_slice(&limb.to_be_bytes());
        }
        let start = raw.iter().position(|&b| b != 0).unwrap_or(raw.len());
        let significant = &raw[start..];
        let mut out = vec![0u8; min_len.saturating_sub(significant.len())];
        out.extend_from_slice(significant);
        out
    }

    /// Number of significant bits (`0` for zero).
    ///
    /// Looks for the highest non-zero limb rather than trusting the last one,
    /// so values with untrimmed high zero limbs (possible because `limbs` is
    /// public) still report their true length.
    pub fn bit_len(&self) -> usize {
        match self.limbs.iter().rposition(|&l| l != 0) {
            Some(top) => top * 64 + (64 - self.limbs[top].leading_zeros() as usize),
            None => 0,
        }
    }

    /// Returns bit `i` (bit 0 is the least significant); bits beyond the
    /// stored limbs read as `false`.
    pub fn bit(&self, i: usize) -> bool {
        self.limbs
            .get(i / 64)
            .map_or(false, |&limb| (limb >> (i % 64)) & 1 == 1)
    }

    /// Sets bit `i`, growing the limb vector with zero limbs if needed.
    pub fn set_bit(&mut self, i: usize) {
        let (limb_idx, bit_idx) = (i / 64, i % 64);
        if self.limbs.len() <= limb_idx {
            self.limbs.resize(limb_idx + 1, 0);
        }
        self.limbs[limb_idx] |= 1u64 << bit_idx;
    }

    /// `true` if every limb is zero (an empty limb vector also counts as zero).
    pub fn is_zero(&self) -> bool {
        self.limbs.iter().all(|&l| l == 0)
    }

    /// `true` if the lowest bit is clear (an empty limb vector counts as even).
    pub fn is_even(&self) -> bool {
        self.limbs.first().map_or(true, |&l| l & 1 == 0)
    }

    /// `true` if the lowest bit is set.
    pub fn is_odd(&self) -> bool {
        !self.is_even()
    }

    /// Drops high zero limbs, always keeping at least one limb.
    fn trim(&mut self) {
        while self.limbs.len() > 1 && *self.limbs.last().unwrap() == 0 {
            self.limbs.pop();
        }
    }

    /// Number of bytes needed to hold the significant bits.
    pub fn byte_len(&self) -> usize {
        (self.bit_len() + 7) / 8
    }

    /// Random value with exactly `bits` bits: bytes come from `rng`, bits above
    /// the requested width are cleared and the top bit is forced on.
    ///
    /// # Panics
    /// Panics if `bits == 0`.
    pub fn random(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> Self {
        assert!(bits > 0, "BigInt::random: bits must be at least 1");
        let byte_len = (bits + 7) / 8;
        let mut buf = vec![0u8; byte_len];
        rng(&mut buf);
        let top_bit = (bits - 1) % 8;
        // Keep bits 0..=top_bit of the leading byte, then force the top one.
        buf[0] &= 0xFFu8 >> (7 - top_bit);
        buf[0] |= 1u8 << top_bit;
        Self::from_be_bytes(&buf)
    }

    /// Like [`BigInt::random`], with the lowest bit also forced on.
    ///
    /// # Panics
    /// Panics if `bits == 0`.
    pub fn random_odd(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> Self {
        let mut n = Self::random(bits, rng);
        n.limbs[0] |= 1;
        n
    }
}
impl PartialEq for BigInt {
    /// Numeric equality: high zero limbs are ignored, so `[5]` equals `[5, 0]`.
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp_to(other), Ordering::Equal)
    }
}
impl Eq for BigInt {}
impl PartialOrd for BigInt {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // The order is total, so defer to `Ord`.
        Some(self.cmp(other))
    }
}
impl Ord for BigInt {
    fn cmp(&self, other: &Self) -> Ordering {
        self.cmp_to(other)
    }
}
impl BigInt {
    /// Compares by numeric value, scanning from the most-significant limb
    /// downwards; limbs missing from the shorter operand count as zero.
    pub fn cmp_to(&self, other: &Self) -> Ordering {
        let width = self.limbs.len().max(other.limbs.len());
        let limb = |x: &Self, i: usize| x.limbs.get(i).copied().unwrap_or(0);
        (0..width)
            .rev()
            .map(|i| limb(self, i).cmp(&limb(other, i)))
            .find(|&ord| ord != Ordering::Equal)
            .unwrap_or(Ordering::Equal)
    }
}
impl BigInt {
    /// Returns `self + other`, trimmed.
    pub fn add(&self, other: &BigInt) -> BigInt {
        let width = self.limbs.len().max(other.limbs.len());
        let limb_at = |x: &BigInt, i: usize| x.limbs.get(i).copied().unwrap_or(0);
        let mut limbs: Vec<u64> = Vec::with_capacity(width + 1);
        let mut carry = 0u64;
        for i in 0..width {
            // Two limbs plus a carry always fit in 128 bits; the high word is the next carry.
            let wide = u128::from(limb_at(self, i)) + u128::from(limb_at(other, i)) + u128::from(carry);
            limbs.push(wide as u64);
            carry = (wide >> 64) as u64;
        }
        if carry != 0 {
            limbs.push(carry);
        }
        let mut sum = BigInt { limbs };
        sum.trim();
        sum
    }

    /// Returns `self + v`.
    pub fn add_u64(&self, v: u64) -> BigInt {
        let addend = BigInt::from_u64(v);
        self.add(&addend)
    }
}
impl BigInt {
    /// Returns `self - other`, trimmed.
    ///
    /// # Panics
    /// In debug builds, panics if `other > self`. Release builds skip the
    /// check and return a wrapped (meaningless) result instead.
    pub fn sub(&self, other: &BigInt) -> BigInt {
        debug_assert!(self >= other, "BigInt::sub: underflow");
        let mut borrow = 0u64;
        let limbs: Vec<u64> = self
            .limbs
            .iter()
            .enumerate()
            .map(|(i, &minuend)| {
                let subtrahend = other.limbs.get(i).copied().unwrap_or(0);
                let (partial, b1) = minuend.overflowing_sub(subtrahend);
                let (digit, b2) = partial.overflowing_sub(borrow);
                borrow = u64::from(b1) + u64::from(b2);
                digit
            })
            .collect();
        let mut difference = BigInt { limbs };
        difference.trim();
        difference
    }

    /// Returns `self - 1` (same underflow contract as [`BigInt::sub`]).
    pub fn sub_one(&self) -> BigInt {
        let one = BigInt::from_u64(1);
        self.sub(&one)
    }
}
impl BigInt {
    /// Schoolbook product `self * other`, trimmed.
    pub fn mul(&self, other: &BigInt) -> BigInt {
        let width = other.limbs.len();
        let mut product = vec![0u64; self.limbs.len() + width];
        for (i, &x) in self.limbs.iter().enumerate() {
            let mut carry = 0u64;
            for (j, &y) in other.limbs.iter().enumerate() {
                // acc + x*y + carry <= (2^64-1)^2 + 2(2^64-1) = 2^128 - 1, so u128 never overflows.
                let wide = u128::from(product[i + j]) + u128::from(x) * u128::from(y) + u128::from(carry);
                product[i + j] = wide as u64;
                carry = (wide >> 64) as u64;
            }
            // Row i has not touched this slot yet, so the carry is stored, not added.
            product[i + width] = carry;
        }
        let mut r = BigInt { limbs: product };
        r.trim();
        r
    }
}
/// Full 64×64 → 128-bit product of `a` and `b`, returned as `(low, high)` words.
#[inline]
fn mul_u64(a: u64, b: u64) -> (u64, u64) {
    let product = u128::from(a) * u128::from(b);
    let high = (product >> 64) as u64;
    let low = product as u64;
    (low, high)
}
impl BigInt {
    /// Truncating division: returns `(self / divisor, self % divisor)`.
    ///
    /// Single-limb divisors take a 128-bit fast path; wider divisors use
    /// Knuth's Algorithm D. A divisor carrying high zero limbs (possible
    /// because `limbs` is public) is trimmed first, because both the fast-path
    /// test and Algorithm D's normalising shift read the top limb.
    ///
    /// # Panics
    /// Panics if `divisor` is zero.
    pub fn div_rem(&self, divisor: &BigInt) -> (BigInt, BigInt) {
        assert!(!divisor.is_zero(), "BigInt: division by zero");
        if self < divisor {
            return (BigInt::zero(), self.clone());
        }
        if divisor.limbs.len() > 1 && divisor.limbs.last() == Some(&0) {
            // Without this, `[d, 0]` would miss the u64 path and reach Algorithm D
            // with a zero top limb, whose shift of 64 leaves the divisor unnormalised.
            let mut trimmed = divisor.clone();
            trimmed.trim();
            return self.div_rem(&trimmed);
        }
        if divisor.limbs.len() == 1 {
            return self.div_rem_u64(divisor.limbs[0]);
        }
        self.div_rem_knuth(divisor)
    }

    /// Divides by a single non-zero limb `d`, walking limbs from most to least
    /// significant with a 128-bit running remainder.
    fn div_rem_u64(&self, d: u64) -> (BigInt, BigInt) {
        let mut rem: u128 = 0;
        let mut quotient = vec![0u64; self.limbs.len()];
        for i in (0..self.limbs.len()).rev() {
            // rem < d here, so the shifted value fits in u128 and the digit fits in u64.
            rem = (rem << 64) | (self.limbs[i] as u128);
            quotient[i] = (rem / d as u128) as u64;
            rem %= d as u128;
        }
        let mut q = BigInt { limbs: quotient };
        q.trim();
        (q, BigInt::from_u64(rem as u64))
    }

    /// Knuth's Algorithm D (TAOCP vol. 2, §4.3.1) for divisors of two or more
    /// limbs. Expects `self >= divisor` and a divisor whose top limb is non-zero.
    fn div_rem_knuth(&self, divisor: &BigInt) -> (BigInt, BigInt) {
        // D1: normalise so the divisor's top limb has its high bit set.
        let shift = divisor.limbs.last().unwrap().leading_zeros() as usize;
        let a = self.shl(shift);
        let b = divisor.shl(shift);
        let n = b.limbs.len();
        let m = a.limbs.len() - n;
        let mut q_limbs = vec![0u64; m + 1];
        let mut u = a.limbs.clone();
        // The dividend gets an extra leading zero limb (m + n + 1 limbs in total).
        if u.len() <= n + m {
            u.resize(n + m + 1, 0);
        }
        let b_top = *b.limbs.last().unwrap() as u128;
        for j in (0..=m).rev() {
            // D3: estimate q̂ from the top two dividend limbs, then refine it
            // against the second divisor limb (per Knuth, it is then at most one too large).
            let u_hi = ((u[j + n] as u128) << 64) | (u[j + n - 1] as u128);
            let mut q_hat = u_hi / b_top;
            let mut r_hat = u_hi % b_top;
            if n >= 2 {
                let b_second = b.limbs[n - 2] as u128;
                while q_hat >= (1u128 << 64) || q_hat * b_second > (r_hat << 64) | (u[j + n - 2] as u128) {
                    q_hat -= 1;
                    r_hat += b_top;
                    if r_hat >= (1u128 << 64) {
                        break;
                    }
                }
            }
            // D4: subtract q̂·b from u[j..=j+n], propagating a signed borrow.
            let mut borrow: i128 = 0;
            for i in 0..n {
                let prod = q_hat * (b.limbs[i] as u128);
                let diff = (u[j + i] as i128) - borrow - (prod as u64 as i128);
                u[j + i] = diff as u64;
                borrow = (prod >> 64) as i128 - (diff >> 64) as i128;
            }
            let diff = (u[j + n] as i128) - borrow;
            u[j + n] = diff as u64;
            q_limbs[j] = q_hat as u64;
            // D6: a negative result means q̂ was one too large; add b back once.
            if diff < 0 {
                q_limbs[j] -= 1;
                let mut carry: u64 = 0;
                for i in 0..n {
                    let (s1, c1) = u[j + i].overflowing_add(b.limbs[i]);
                    let (s2, c2) = s1.overflowing_add(carry);
                    u[j + i] = s2;
                    carry = (c1 as u64) + (c2 as u64);
                }
                u[j + n] = u[j + n].wrapping_add(carry);
            }
        }
        let mut q = BigInt { limbs: q_limbs };
        q.trim();
        // D8: the low n limbs of u hold the normalised remainder; undo the D1 shift.
        u.truncate(n);
        let mut r = BigInt { limbs: u };
        r.trim();
        r = r.shr(shift);
        (q, r)
    }

    /// Returns `self << bits`, trimmed (a zero shift returns an untouched clone).
    pub fn shl(&self, bits: usize) -> BigInt {
        if bits == 0 {
            return self.clone();
        }
        let limb_shift = bits / 64;
        let bit_shift = bits % 64;
        let mut result = vec![0u64; self.limbs.len() + limb_shift + 1];
        let mut carry: u64 = 0;
        for i in 0..self.limbs.len() {
            if bit_shift == 0 {
                result[i + limb_shift] = self.limbs[i];
            } else {
                result[i + limb_shift] = (self.limbs[i] << bit_shift) | carry;
                carry = self.limbs[i] >> (64 - bit_shift);
            }
        }
        if carry != 0 {
            result[self.limbs.len() + limb_shift] = carry;
        }
        let mut r = BigInt { limbs: result };
        r.trim();
        r
    }

    /// Returns `self >> bits` (floor division by `2^bits`), trimmed; a zero
    /// shift returns an untouched clone.
    pub fn shr(&self, bits: usize) -> BigInt {
        if bits == 0 {
            return self.clone();
        }
        let limb_shift = bits / 64;
        let bit_shift = bits % 64;
        if limb_shift >= self.limbs.len() {
            return BigInt::zero();
        }
        let new_len = self.limbs.len() - limb_shift;
        let mut result = vec![0u64; new_len];
        for i in 0..new_len {
            let src = i + limb_shift;
            result[i] = if bit_shift == 0 {
                self.limbs[src]
            } else {
                let lo = self.limbs[src] >> bit_shift;
                let hi = if src + 1 < self.limbs.len() {
                    self.limbs[src + 1] << (64 - bit_shift)
                } else {
                    0
                };
                lo | hi
            };
        }
        let mut r = BigInt { limbs: result };
        r.trim();
        r
    }

    /// Returns `self % modulus`.
    ///
    /// # Panics
    /// Panics if `modulus` is zero.
    pub fn rem(&self, modulus: &BigInt) -> BigInt {
        self.div_rem(modulus).1
    }
}
/// Precomputed constants for Montgomery arithmetic modulo an odd `n`, with
/// `R = 2^(64 * n_limbs)`.
#[derive(Clone, Debug)]
pub struct MontParams {
    /// The modulus.
    pub n: BigInt,
    /// Limb count of `n`; fixes `R = 2^(64 * n_limbs)`.
    pub n_limbs: usize,
    /// `-n^{-1} mod 2^64`, derived from the lowest limb of `n`.
    pub n_inv_neg: u64,
    /// `R mod n`: the value 1 in Montgomery form.
    pub r_mod_n: BigInt,
    /// `R^2 mod n`: multiplier used to convert into Montgomery form.
    pub r2_mod_n: BigInt,
}
impl MontParams {
    /// Precomputes the Montgomery constants for the odd modulus `n`, with
    /// `R = 2^(64 * n.limbs.len())`.
    ///
    /// # Panics
    /// Panics if `n` is even (including zero or an empty limb vector). No
    /// inverse of `n` mod 2^64 exists then, and every later product would be
    /// silently wrong.
    pub fn new(n: &BigInt) -> Self {
        // A hard assert: a debug-only check would let release builds continue
        // with a meaningless `n_inv_neg`.
        assert!(n.is_odd(), "Montgomery requires odd modulus");
        let n_limbs = n.limbs.len();
        let n_inv_neg = mod_inv_u64_neg(n.limbs[0]);
        let mut r_val = BigInt::zero();
        r_val.set_bit(64 * n_limbs);
        let r_mod_n = r_val.rem(n);
        let r2_mod_n = r_val.mul(&r_val).rem(n);
        MontParams {
            n: n.clone(),
            n_limbs,
            n_inv_neg,
            r_mod_n,
            r2_mod_n,
        }
    }

    /// Converts `a` (expected to be below `n`) into Montgomery form `a·R mod n`.
    pub fn to_mont(&self, a: &BigInt) -> BigInt {
        self.mont_mul(a, &self.r2_mod_n)
    }

    /// Converts `a` out of Montgomery form: `a·R^{-1} mod n`.
    pub fn from_mont(&self, a: &BigInt) -> BigInt {
        self.mont_mul(a, &BigInt::from_u64(1))
    }

    /// Montgomery product `a·b·R^{-1} mod n`.
    ///
    /// Each outer step accumulates `a·b[i]` and then reduces by one limb
    /// (interleaved multiply/reduce, CIOS-style). Only the low `n_limbs` limbs
    /// of `a` and `b` are read, and a single conditional subtraction of `n` is
    /// applied at the end, so operands should already be reduced below `n`.
    pub fn mont_mul(&self, a: &BigInt, b: &BigInt) -> BigInt {
        let n = self.n_limbs;
        let mut t = vec![0u64; n + 2];
        for i in 0..n {
            let bi = if i < b.limbs.len() { b.limbs[i] } else { 0 };
            // t += a * b[i]
            let mut carry: u64 = 0;
            for j in 0..n {
                let aj = if j < a.limbs.len() { a.limbs[j] } else { 0 };
                let (lo, hi) = mul_u64(aj, bi);
                let (s1, c1) = t[j].overflowing_add(lo);
                let (s2, c2) = s1.overflowing_add(carry);
                t[j] = s2;
                carry = hi + (c1 as u64) + (c2 as u64);
            }
            let (s, c) = t[n].overflowing_add(carry);
            t[n] = s;
            t[n + 1] = c as u64;
            // Pick m so that t + m*n has a zero low limb, then add m*n and
            // shift t down by one limb.
            let m = t[0].wrapping_mul(self.n_inv_neg);
            let mut carry: u64 = 0;
            {
                let (lo, hi) = mul_u64(m, self.n.limbs[0]);
                let (s1, c1) = t[0].overflowing_add(lo);
                let (_s2, c2) = s1.overflowing_add(carry);
                carry = hi + (c1 as u64) + (c2 as u64);
            }
            for j in 1..n {
                let nj = self.n.limbs[j];
                let (lo, hi) = mul_u64(m, nj);
                let (s1, c1) = t[j].overflowing_add(lo);
                let (s2, c2) = s1.overflowing_add(carry);
                t[j - 1] = s2;
                carry = hi + (c1 as u64) + (c2 as u64);
            }
            let (s1, c1) = t[n].overflowing_add(carry);
            t[n - 1] = s1;
            t[n] = t[n + 1] + (c1 as u64);
            t[n + 1] = 0;
        }
        t.truncate(n + 1);
        let mut result = BigInt { limbs: t };
        result.trim();
        if result >= self.n {
            result = result.sub(&self.n);
        }
        result.trim();
        result
    }

    /// Computes `base^exp mod n` with a Montgomery ladder.
    ///
    /// `base` may be any size; it is reduced mod `n` first, because
    /// `mont_mul` only reads the low `n_limbs` limbs of its operands. With
    /// `exp == 0` the result is `1 mod n` (so `0` when `n == 1`).
    pub fn mod_exp(&self, base: &BigInt, exp: &BigInt) -> BigInt {
        let base_mont = if base >= &self.n {
            self.to_mont(&base.rem(&self.n))
        } else {
            self.to_mont(base)
        };
        // Ladder invariant (Montgomery form): r1 = r0 · base. Starts at r0 = 1·R mod n.
        let mut r0 = self.r_mod_n.clone();
        let mut r1 = base_mont;
        for i in (0..exp.bit_len()).rev() {
            if exp.bit(i) {
                r0 = self.mont_mul(&r0, &r1);
                r1 = self.mont_mul(&r1, &r1);
            } else {
                r1 = self.mont_mul(&r0, &r1);
                r0 = self.mont_mul(&r0, &r0);
            }
        }
        self.from_mont(&r0)
    }
}
/// Returns `-n0^{-1} mod 2^64`, the per-limb Montgomery constant, for odd `n0`.
///
/// Uses Newton–Hensel lifting `x ← x·(2 − n0·x)`, which doubles the number of
/// correct low bits each step. Starting from `x = 1` (correct mod 2 for odd
/// `n0`), six steps reach 64 bits.
fn mod_inv_u64_neg(n0: u64) -> u64 {
    let inverse = (0..6).fold(1u64, |x, _| {
        let correction = 2u64.wrapping_sub(n0.wrapping_mul(x));
        x.wrapping_mul(correction)
    });
    inverse.wrapping_neg()
}
impl BigInt {
    /// Computes `self^exp mod modulus`.
    ///
    /// Odd moduli use Montgomery exponentiation via [`MontParams`]. Even moduli,
    /// which Montgomery reduction cannot handle, fall back to left-to-right
    /// square-and-multiply with `rem`. The base is reduced first, so `self`
    /// may exceed `modulus`; `exp == 0` yields `1 mod modulus`.
    ///
    /// # Panics
    /// Panics if `modulus` is zero (division by zero in the initial reduction).
    pub fn mod_exp(&self, exp: &BigInt, modulus: &BigInt) -> BigInt {
        let base = self.rem(modulus);
        let one = BigInt::from_u64(1).rem(modulus);
        if exp.is_zero() {
            return one;
        }
        if modulus.is_odd() {
            return MontParams::new(modulus).mod_exp(&base, exp);
        }
        let mut acc = one;
        for i in (0..exp.bit_len()).rev() {
            acc = acc.mul(&acc).rem(modulus);
            if exp.bit(i) {
                acc = acc.mul(&base).rem(modulus);
            }
        }
        acc
    }

    /// Returns `Some(x)` with `self·x ≡ 1 (mod modulus)`, or `None` when
    /// `gcd(self, modulus) != 1`.
    ///
    /// # Panics
    /// Panics if `modulus` is zero and `self` is non-zero.
    pub fn mod_inv(&self, modulus: &BigInt) -> Option<BigInt> {
        let (g, x, _neg) = extended_gcd(self, modulus);
        if g != BigInt::from_u64(1) {
            return None;
        }
        Some(x)
    }
}
/// Extended Euclid on `(a, b)`.
///
/// Returns `(gcd, x, false)`, where `x` is the Bézout coefficient of `a` reduced
/// modulo `b`, so `a·x ≡ gcd (mod b)`. The third element is always `false`;
/// the caller in this file ignores it. Because `BigInt` is unsigned, the
/// coefficients are carried as (magnitude, is_negative) pairs.
fn extended_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, bool) {
    if a.is_zero() {
        return (b.clone(), BigInt::zero(), false);
    }
    let (mut prev_r, mut cur_r) = (a.clone(), b.clone());
    let (mut prev_c, mut prev_c_neg) = (BigInt::from_u64(1), false);
    let (mut cur_c, mut cur_c_neg) = (BigInt::zero(), false);
    while !cur_r.is_zero() {
        let (quot, next_r) = prev_r.div_rem(&cur_r);
        prev_r = core::mem::replace(&mut cur_r, next_r);
        // next coefficient = prev − quot · cur, with signs tracked explicitly.
        let (next_c, next_c_neg) = signed_sub(&prev_c, prev_c_neg, &quot.mul(&cur_c), cur_c_neg);
        prev_c = core::mem::replace(&mut cur_c, next_c);
        prev_c_neg = core::mem::replace(&mut cur_c_neg, next_c_neg);
    }
    let reduced = prev_c.rem(b);
    let x = if prev_c_neg { b.sub(&reduced) } else { reduced };
    (prev_r, x, false)
}
fn signed_sub(a: &BigInt, a_neg: bool, b: &BigInt, b_neg: bool) -> (BigInt, bool) {
if a_neg == b_neg {
if a >= b { (a.sub(b), a_neg) } else { (b.sub(a), !a_neg) }
} else {
(a.add(b), a_neg)
}
}
impl BigInt {
    /// Miller–Rabin probabilistic primality test with `rounds` random bases.
    ///
    /// Values below 4 and even values are decided directly. For odd `n >= 5`,
    /// each round draws a base uniformly from `[2, n - 2]` using `rng`. A `false`
    /// answer is always correct. For uniformly random bases the standard
    /// Miller–Rabin bound on a wrong `true` is `4^-rounds`.
    pub fn is_probably_prime(&self, rounds: usize, rng: &mut dyn FnMut(&mut [u8])) -> bool {
        let one = BigInt::from_u64(1);
        let two = BigInt::from_u64(2);
        // 0 and 1 are not prime; 2 and 3 are. Comparing by value rather than by
        // limb count also handles untrimmed inputs (e.g. `[1, 0]`).
        if *self < BigInt::from_u64(4) {
            return *self >= two;
        }
        if self.is_even() {
            return false;
        }
        let n_minus_1 = self.sub(&one);
        let n_minus_2 = self.sub(&two);
        // Write n - 1 = d · 2^s with d odd (s >= 1 because n is odd).
        let mut d = n_minus_1.clone();
        let mut s: usize = 0;
        while d.is_even() {
            d = d.shr(1);
            s += 1;
        }
        let mont = MontParams::new(self);
        'next_round: for _ in 0..rounds {
            let a = Self::random_in_range(&two, &n_minus_2, rng);
            let mut x = mont.mod_exp(&a, &d);
            if x == one || x == n_minus_1 {
                continue 'next_round;
            }
            for _ in 1..s {
                x = mont.mod_exp(&x, &two);
                if x == n_minus_1 {
                    continue 'next_round;
                }
            }
            return false;
        }
        true
    }

    /// Draws a value uniformly from `[low, high]` by rejection sampling
    /// `high.bit_len()`-bit strings from `rng`.
    ///
    /// Unlike [`BigInt::random`], the top bit is NOT forced. Forcing it made
    /// `[2, n - 2]` unreachable for `n = 2^k + 1` (5, 17, 257, 65537, …) and
    /// hung the caller. Since `high >= 2^(bits-1)`, the sampling window is at
    /// most twice as wide as `[0, high]`. Requires `low <= high` and
    /// `high >= 1`.
    fn random_in_range(low: &BigInt, high: &BigInt, rng: &mut dyn FnMut(&mut [u8])) -> BigInt {
        debug_assert!(low <= high && !high.is_zero(), "BigInt::random_in_range: empty range");
        let bits = high.bit_len();
        let byte_len = (bits + 7) / 8;
        // Clear the unused high bits of the leading byte so draws lie in [0, 2^bits).
        let top_mask = 0xFFu8 >> (byte_len * 8 - bits);
        let mut buf = vec![0u8; byte_len];
        loop {
            rng(&mut buf);
            buf[0] &= top_mask;
            let candidate = BigInt::from_be_bytes(&buf);
            if &candidate >= low && &candidate <= high {
                return candidate;
            }
        }
    }

    /// Generates a random probable prime with exactly `bits` bits.
    ///
    /// Odd candidates are first screened by trial division with small primes,
    /// then tested with Miller–Rabin (4 rounds at 1024+ bits, 8 otherwise).
    ///
    /// # Panics
    /// Panics if `bits < 2`: no prime has fewer than 2 bits, so the search
    /// would never terminate.
    pub fn random_prime(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> BigInt {
        assert!(bits >= 2, "BigInt::random_prime: bits must be at least 2");
        const SMALL_PRIMES: &[u64] = &[
            3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103,
            107, 109, 113,
        ];
        let rounds = if bits >= 1024 { 4 } else { 8 };
        'candidates: loop {
            let candidate = BigInt::random_odd(bits, rng);
            for &p in SMALL_PRIMES {
                let p_big = BigInt::from_u64(p);
                if candidate.rem(&p_big).is_zero() {
                    // Divisible by p: prime only if it *is* p.
                    if candidate == p_big {
                        return candidate;
                    }
                    continue 'candidates;
                }
            }
            if candidate.is_probably_prime(rounds, rng) {
                return candidate;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-limb shorthand used throughout the tests.
    fn big(v: u64) -> BigInt {
        BigInt::from_u64(v)
    }

    #[test]
    fn test_add_sub() {
        let max = big(u64::MAX);
        let one = big(1);
        let sum = max.add(&one);
        // u64::MAX + 1 carries into a second limb.
        assert_eq!(sum.limbs, vec![0, 1]);
        assert_eq!(sum.sub(&one), max);
    }

    #[test]
    fn test_mul() {
        let x = big(0xFFFF_FFFF);
        assert_eq!(x.mul(&x).limbs[0], 0xFFFF_FFFE_0000_0001);
    }

    #[test]
    fn test_div_rem() {
        assert_eq!(big(100).div_rem(&big(7)), (big(14), big(2)));
    }

    #[test]
    fn test_mod_exp() {
        // 3^10 = 59049 ≡ 4 (mod 7)
        assert_eq!(big(3).mod_exp(&big(10), &big(7)), big(4));
    }

    #[test]
    fn test_mod_inv() {
        // 3 · 5 = 15 ≡ 1 (mod 7)
        assert_eq!(big(3).mod_inv(&big(7)), Some(big(5)));
    }

    #[test]
    fn test_from_be_bytes_roundtrip() {
        let bytes = [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01];
        assert_eq!(BigInt::from_be_bytes(&bytes).to_be_bytes(8), bytes.to_vec());
    }

    #[test]
    fn test_bit_ops() {
        let mut n = BigInt::zero();
        n.set_bit(65);
        assert!(n.bit(65));
        assert!(!n.bit(64));
        assert_eq!(n.bit_len(), 66);
    }

    /// Deterministic 64-bit LCG byte source so the probabilistic tests are reproducible.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        let mut state: u64 = 0xdeadbeefcafebabe;
        move |buf: &mut [u8]| {
            buf.iter_mut().for_each(|byte| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                *byte = (state >> 33) as u8;
            })
        }
    }

    #[test]
    fn test_primality() {
        let mut rng = test_rng();
        // Order matters: all cases share one deterministic rng stream.
        let cases = [(7u64, true), (15, false), (104729, true)];
        for &(value, expected) in cases.iter() {
            assert_eq!(big(value).is_probably_prime(10, &mut rng), expected, "n = {}", value);
        }
    }
}