extern crate alloc;
use core::cmp::Ordering;
use core::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign,
};
use alloc::vec::Vec;
/// Classifies the low bits of a number relative to half of the cut-off
/// position (see `BigInt::get_loss_kind_for_bit`).
#[derive(Debug, Clone, Copy)]
pub(crate) enum LossFraction {
    /// The examined bits are all zero.
    ExactlyZero,
    /// The examined bits are below the half-way value.
    LessThanHalf,
    /// The examined bits equal the half-way value.
    ExactlyHalf,
    /// The examined bits are above the half-way value.
    MoreThanHalf,
}

impl LossFraction {
    /// True only for `ExactlyZero`.
    pub fn is_exactly_zero(&self) -> bool {
        matches!(*self, LossFraction::ExactlyZero)
    }

    /// True for `LessThanHalf` and also for `ExactlyZero`.
    pub fn is_lt_half(&self) -> bool {
        matches!(*self, LossFraction::ExactlyZero | LossFraction::LessThanHalf)
    }

    /// True only for `ExactlyHalf`.
    pub fn is_exactly_half(&self) -> bool {
        matches!(*self, LossFraction::ExactlyHalf)
    }

    /// True only for `MoreThanHalf`.
    pub fn is_mt_half(&self) -> bool {
        matches!(*self, LossFraction::MoreThanHalf)
    }

    /// True for every variant except `MoreThanHalf`.
    #[allow(dead_code)]
    pub fn is_lte_half(&self) -> bool {
        self.is_exactly_half() || self.is_lt_half()
    }

    /// True for `ExactlyHalf` and `MoreThanHalf`.
    pub fn is_gte_half(&self) -> bool {
        matches!(*self, LossFraction::ExactlyHalf | LossFraction::MoreThanHalf)
    }

    /// Swaps `LessThanHalf` and `MoreThanHalf`; the other variants are
    /// returned unchanged.
    pub fn invert(&self) -> LossFraction {
        match *self {
            LossFraction::LessThanHalf => LossFraction::MoreThanHalf,
            LossFraction::MoreThanHalf => LossFraction::LessThanHalf,
            LossFraction::ExactlyZero => LossFraction::ExactlyZero,
            LossFraction::ExactlyHalf => LossFraction::ExactlyHalf,
        }
    }
}
/// An unsigned integer of arbitrary width.
///
/// The value is stored as 64-bit words in `parts`, least significant word
/// first (`parts[0]` holds bits 0..64).
#[derive(Clone, Debug)]
pub struct BigInt {
    /// The 64-bit words of the number, lowest word at index 0.
    parts: Vec<u64>,
}
impl BigInt {
pub fn zero() -> Self {
BigInt::from_u64(0)
}
pub fn one() -> Self {
Self::from_u64(1)
}
/// Creates a number in which only bit number `bit` (zero-based) is set.
pub fn one_hot(bit: usize) -> Self {
    let mut value = Self::zero();
    value.flip_bit(bit);
    value
}
/// Creates a number whose lowest `bits` bits are all set, i.e. `2^bits - 1`.
/// Returns zero when `bits` is zero.
pub fn all1s(bits: usize) -> Self {
    match bits {
        0 => Self::zero(),
        _ => {
            // Compute (1 << bits) - 1.
            let mut ones = Self::one();
            ones.shift_left(bits);
            let _ = ones.inplace_sub(&Self::one());
            debug_assert_eq!(ones.msb_index(), bits);
            ones
        }
    }
}
/// Creates a single-word number holding `val`.
pub fn from_u64(val: u64) -> Self {
    BigInt {
        parts: [val].to_vec(),
    }
}
/// Creates a two-word number holding `val`: the low 64 bits go into word 0
/// and the high 64 bits into word 1.
pub fn from_u128(val: u128) -> Self {
    let low = val as u64;
    let high = (val >> 64) as u64;
    BigInt {
        parts: [low, high].to_vec(),
    }
}
pub fn pseudorandom(parts: usize, seed: u32) -> Self {
use crate::utils::Lfsr;
let mut ll = Lfsr::new_with_seed(seed);
BigInt::from_iter(&mut ll, parts)
}
pub fn len(&self) -> usize {
self.parts.len()
}
pub fn is_empty(&self) -> bool {
self.parts.is_empty()
}
/// Returns the lowest word of the number.
///
/// Debug builds assert that every higher word is zero, i.e. that the value
/// fits in 64 bits. Panics if the number stores no words.
pub fn as_u64(&self) -> u64 {
    for high_word in self.parts.iter().skip(1) {
        debug_assert_eq!(*high_word, 0);
    }
    self.parts[0]
}
/// Returns the two lowest words of the number combined into a `u128`.
///
/// Debug builds assert that every word above the second is zero. Panics if
/// the number stores no words.
pub fn as_u128(&self) -> u128 {
    if self.len() < 2 {
        return self.parts[0] as u128;
    }
    for high_word in self.parts.iter().skip(2) {
        debug_assert_eq!(*high_word, 0);
    }
    let low = self.parts[0] as u128;
    let high = (self.parts[1] as u128) << 64;
    low + high
}
/// Returns true when every stored word is zero.
pub fn is_zero(&self) -> bool {
    self.parts.iter().all(|&word| word == 0)
}
/// Returns true when the lowest bit is clear. Panics if no words are stored.
pub fn is_even(&self) -> bool {
    let lowest_bit = self.parts[0] & 1;
    lowest_bit == 0
}
/// Returns true when the lowest bit is set. Panics if no words are stored.
pub fn is_odd(&self) -> bool {
    let lowest_bit = self.parts[0] & 1;
    lowest_bit == 1
}
/// Toggles bit number `bit_num` (zero-based), growing the storage first so
/// that the word containing the bit exists.
pub fn flip_bit(&mut self, bit_num: usize) {
    const WORD_BITS: usize = u64::BITS as usize;
    let (word_idx, bit_idx) = (bit_num / WORD_BITS, bit_num % WORD_BITS);
    self.grow(word_idx + 1);
    debug_assert!(word_idx < self.len(), "Bit out of bounds");
    self.parts[word_idx] ^= 1 << bit_idx;
}
/// Keeps the lowest `bits` bits of the number and clears every bit above
/// them. The number of stored words is unchanged.
pub fn mask(&mut self, bits: usize) {
    // Number of bits that still have to be kept, counted from the current word.
    let mut remaining = bits;
    for word in self.parts.iter_mut() {
        if remaining >= 64 {
            // The whole word is kept.
            remaining -= 64;
        } else if remaining == 0 {
            // Entirely above the cut: clear it.
            *word = 0;
        } else {
            // The cut falls inside this word.
            *word &= (1u64 << remaining) - 1;
            remaining = 0;
        }
    }
}
/// Classifies the lowest `bit` bits of the number relative to `2^(bit - 1)`,
/// i.e. half of the value of bit number `bit`.
pub(crate) fn get_loss_kind_for_bit(&self, bit: usize) -> LossFraction {
    if self.is_zero() {
        return LossFraction::ExactlyZero;
    }
    // The cut lies beyond the stored words, so the whole number is below half.
    if bit > self.len() * 64 {
        return LossFraction::LessThanHalf;
    }
    let mut low_bits = self.clone();
    low_bits.mask(bit);
    if low_bits.is_zero() {
        return LossFraction::ExactlyZero;
    }
    let half_way = Self::one_hot(bit - 1);
    match low_bits.cmp(&half_way) {
        Ordering::Greater => LossFraction::MoreThanHalf,
        Ordering::Equal => LossFraction::ExactlyHalf,
        Ordering::Less => LossFraction::LessThanHalf,
    }
}
/// Returns the one-based position of the highest set bit (the bit length of
/// the number), or 0 when the number is zero.
pub fn msb_index(&self) -> usize {
    match self.parts.iter().rposition(|&word| word != 0) {
        Some(top) => {
            let bits_in_word = 64 - self.parts[top].leading_zeros() as usize;
            top * 64 + bits_in_word
        }
        None => 0,
    }
}
/// Returns the number of zero bits below the lowest set bit.
///
/// # Panics
/// Panics when the number is zero.
pub fn trailing_zeros(&self) -> usize {
    debug_assert!(!self.is_zero());
    match self.parts.iter().position(|&word| word != 0) {
        Some(lowest) => lowest * 64 + self.parts[lowest].trailing_zeros() as usize,
        None => panic!("Expected a non-zero number"),
    }
}
/// Creates a number from a slice of words, lowest word first. The slice is
/// copied.
pub fn from_parts(parts: &[u64]) -> Self {
    BigInt {
        parts: Vec::from(parts),
    }
}
/// Creates a number from at most `k` words pulled from `iter`, lowest word
/// first. The iterator is only borrowed, so it can be used again afterwards.
pub fn from_iter<I: Iterator<Item = u64>>(iter: &mut I, k: usize) -> Self {
    BigInt {
        parts: iter.take(k).collect(),
    }
}
/// Ensures at least `size` words are stored by appending zero words. Never
/// removes words.
pub fn grow(&mut self, size: usize) {
    if size > self.parts.len() {
        self.parts.resize(size, 0);
    }
}
/// Drops zero words from the top of the number, but never goes below two
/// stored words.
fn shrink(&mut self) {
    while self.parts.len() > 2 {
        match self.parts.last() {
            Some(&0) => {
                self.parts.pop();
            }
            _ => break,
        }
    }
}
/// Adds `rhs` to `self` in place.
pub fn inplace_add(&mut self, rhs: &Self) {
    let words: &[u64] = rhs.parts.as_slice();
    self.inplace_add_slice(words);
}
/// Adds the number given as a word slice (lowest word first) to `self` in
/// place. A carry out of the top word is stored in a newly appended word.
pub(crate) fn inplace_add_slice(&mut self, rhs: &[u64]) {
    self.grow(rhs.len());
    let mut carry = false;
    let (overlap, upper) = self.parts.split_at_mut(rhs.len());
    // Words present in both operands: add with carry.
    for (dst, &src) in overlap.iter_mut().zip(rhs) {
        let (sum, wrapped_a) = dst.overflowing_add(src);
        let (sum, wrapped_b) = sum.overflowing_add(carry as u64);
        *dst = sum;
        carry = wrapped_a || wrapped_b;
    }
    // Words only present in `self`: just propagate the carry.
    for dst in upper.iter_mut() {
        let (sum, wrapped) = dst.overflowing_add(carry as u64);
        *dst = sum;
        carry = wrapped;
    }
    if carry {
        self.parts.push(1);
    }
    self.shrink()
}
/// Subtracts `rhs` from `self` in place, wrapping on underflow.
///
/// Returns the final borrow: `true` when a borrow propagated out of the top
/// word.
#[must_use]
pub fn inplace_sub(&mut self, rhs: &Self) -> bool {
    let words: &[u64] = rhs.parts.as_slice();
    self.inplace_sub_slice(words, 0)
}
/// Subtracts the number given as a word slice (lowest word first) from
/// `self` in place and returns the final borrow.
///
/// The lowest `bottom_zeros` words of `rhs` are skipped entirely; the caller
/// uses this when those words are known to be zero.
fn inplace_sub_slice(&mut self, rhs: &[u64], bottom_zeros: usize) -> bool {
    self.grow(rhs.len());
    let mut borrow = false;
    // Words present in both operands, minus the skipped low words.
    for (dst, &src) in self.parts.iter_mut().zip(rhs).skip(bottom_zeros) {
        let (diff, wrapped_a) = dst.overflowing_sub(src);
        let (diff, wrapped_b) = diff.overflowing_sub(borrow as u64);
        *dst = diff;
        borrow = wrapped_a || wrapped_b;
    }
    // Words only present in `self`: just propagate the borrow.
    for dst in self.parts.iter_mut().skip(rhs.len()) {
        let (diff, wrapped) = dst.overflowing_sub(borrow as u64);
        *dst = diff;
        borrow = wrapped;
    }
    self.shrink();
    borrow
}
/// Returns a vector of `size` zero words.
fn zeros(size: usize) -> Vec<u64> {
    let mut words = Vec::new();
    words.resize(size, 0);
    words
}
/// Multiplies `self` by `rhs` in place.
///
/// When either operand has more than `KARATSUBA_SIZE_THRESHOLD` words the
/// product is computed by `mul_karatsuba`; otherwise by `inplace_mul_slice`.
pub fn inplace_mul(&mut self, rhs: &Self) {
    let longest = self.len().max(rhs.len());
    if longest > KARATSUBA_SIZE_THRESHOLD {
        *self = Self::mul_karatsuba(self, rhs);
    } else {
        self.inplace_mul_slice(rhs);
    }
}
/// Multiplies `self` in place by the number given as a word slice (lowest
/// word first) using the schoolbook method.
///
/// Partial products are accumulated with wrapping adds; the number of wraps
/// per position is counted separately and folded in during a final pass.
fn inplace_mul_slice(&mut self, rhs: &[u64]) {
    let size = self.len() + rhs.len() + 1;
    let mut acc = Self::zeros(size);
    // wraps[k] counts how many additions into acc[k] overflowed.
    let mut wraps = Self::zeros(size);
    for (i, &lhs_word) in self.parts.iter().enumerate() {
        for (j, &rhs_word) in rhs.iter().enumerate() {
            let product = (lhs_word as u128) * (rhs_word as u128);
            let low_pos = i + j;
            let high_pos = low_pos + 1;

            let (low_sum, low_wrapped) = acc[low_pos].overflowing_add(product as u64);
            acc[low_pos] = low_sum;
            wraps[low_pos] += low_wrapped as u64;

            let (high_sum, high_wrapped) =
                acc[high_pos].overflowing_add((product >> 64) as u64);
            acc[high_pos] = high_sum;
            wraps[high_pos] += high_wrapped as u64;
        }
    }
    self.grow(size);
    // Fold the recorded wraps into the result, each one position higher.
    let mut carry: u64 = 0;
    for (dst, (&word, &pending)) in self.parts.iter_mut().zip(acc.iter().zip(wraps.iter())) {
        let (sum, wrapped) = word.overflowing_add(carry);
        *dst = sum;
        carry = wrapped as u64 + pending;
    }
    self.shrink();
    assert!(carry == 0);
}
/// Divides `self` by `divisor` in place: `self` becomes the quotient and the
/// remainder is returned.
///
/// # Panics
/// Panics when `divisor` is zero.
pub fn inplace_div(&mut self, divisor: &Self) -> Self {
let mut dividend = self.clone();
let mut divisor = divisor.clone();
let mut quotient = Self::zero();
// Fast path: both operands are stored in a single word, so native u64
// division can be used.
if self.len() == 1 && divisor.parts.len() == 1 {
let a = dividend.get_part(0);
let b = divisor.get_part(0);
let res = a / b;
let rem = a % b;
self.parts[0] = res;
return Self::from_u64(rem);
}
let dividend_msb = dividend.msb_index();
let divisor_msb = divisor.msb_index();
assert_ne!(divisor_msb, 0, "division by zero");
// The divisor has more significant bits than the dividend: the quotient
// is zero and the whole dividend is the remainder.
if divisor_msb > dividend_msb {
let ret = self.clone();
*self = Self::zero();
return ret;
}
// Align the top bit of the divisor with the top bit of the dividend, then
// perform shift-and-subtract long division, producing one quotient bit per
// iteration from the highest bit down.
let bits = dividend_msb - divisor_msb;
divisor.shift_left(bits);
for i in (0..bits + 1).rev() {
// At this point `divisor` is the original divisor shifted left by `i`
// bits, so its lowest `i / 64` words are zero and the subtraction can
// skip them.
let low_zeros = i / 64;
if dividend >= divisor {
let overflow = dividend.inplace_sub_slice(&divisor, low_zeros);
debug_assert!(!overflow);
quotient.flip_bit(i);
}
divisor.shift_right(1);
}
*self = quotient;
self.shrink();
// What is left of the dividend after all subtractions is the remainder.
dividend
}
/// Shifts the number left by `bits` bits in place.
///
/// The storage is extended with zero words before shifting, so no bits are
/// lost. The storage is not shrunk afterwards.
pub fn shift_left(&mut self, bits: usize) {
let words_to_shift = bits / u64::BITS as usize;
let bits_in_word = bits % u64::BITS as usize;
// Make room for the shifted-in words plus one word for bits that spill
// over a word boundary.
for _ in 0..words_to_shift + 1 {
self.parts.push(0);
}
// Whole-word shift: only move words. Iterating from the top down means
// each source word is read before it is overwritten.
if bits_in_word == 0 {
for i in (0..self.len()).rev() {
self.parts[i] = if i >= words_to_shift {
self.parts[i - words_to_shift]
} else {
0
};
}
return;
}
// General case: every output word combines one source word shifted up
// with the top bits of the source word below it.
for i in (0..self.len()).rev() {
let left_val = if i >= words_to_shift {
self.parts[i - words_to_shift]
} else {
0
};
let right_val = if i > words_to_shift {
self.parts[i - words_to_shift - 1]
} else {
0
};
let right = right_val >> (u64::BITS as usize - bits_in_word);
let left = left_val << bits_in_word;
self.parts[i] = left | right;
}
}
/// Shifts the number right by `bits` bits in place. Bits shifted below bit 0
/// are discarded, and zero words at the top are dropped afterwards via
/// `shrink`.
pub fn shift_right(&mut self, bits: usize) {
let words_to_shift = bits / u64::BITS as usize;
let bits_in_word = bits % u64::BITS as usize;
// Whole-word shift: only move words. Iterating from the bottom up means
// each source word is read before it is overwritten.
if bits_in_word == 0 {
for i in 0..self.len() {
self.parts[i] = if i + words_to_shift < self.len() {
self.parts[i + words_to_shift]
} else {
0
};
}
self.shrink();
return;
}
// General case: every output word combines one source word shifted down
// with the low bits of the source word above it.
for i in 0..self.len() {
let left_val = if i + words_to_shift < self.len() {
self.parts[i + words_to_shift]
} else {
0
};
let right_val = if i + 1 + words_to_shift < self.len() {
self.parts[i + 1 + words_to_shift]
} else {
0
};
let right = right_val << (u64::BITS as usize - bits_in_word);
let left = left_val >> bits_in_word;
self.parts[i] = left | right;
}
self.shrink();
}
/// Returns `self` raised to the power `exp`, computed by repeated squaring.
/// An exponent of zero yields one.
pub fn powi(&self, mut exp: u64) -> Self {
    let mut result = Self::one();
    let mut square = self.clone();
    loop {
        // Multiply in the current square when the lowest exponent bit is set.
        if exp % 2 == 1 {
            result.inplace_mul(&square);
        }
        exp /= 2;
        if exp == 0 {
            return result;
        }
        let copy = square.clone();
        square.inplace_mul(&copy);
    }
}
pub fn get_part(&self, idx: usize) -> u64 {
self.parts[idx]
}
/// Prints the result of `as_binary()`, wrapped in square brackets, followed
/// by a newline. Only available with the `std` feature.
#[cfg(feature = "std")]
pub fn dump(&self) {
    std::println!("[{}]", self.as_binary());
}
/// Does nothing: without the `std` feature there is nowhere to print to.
#[cfg(not(feature = "std"))]
pub fn dump(&self) {}
}
impl Default for BigInt {
fn default() -> Self {
Self::zero()
}
}
/// Checks `powi` against known powers of 5, 15 and 3.
#[test]
fn test_powi5() {
    let powers_of_five = [1, 5, 25, 125, 625, 3125, 15625, 78125];
    for (exp, expected) in powers_of_five.iter().enumerate() {
        let base = BigInt::from_u64(5);
        assert_eq!(base.powi(exp as u64).as_u64(), *expected);
    }
    let fifteen = BigInt::from_u64(15);
    assert_eq!(fifteen.powi(16).as_u64(), 6568408355712890625);
    let three = BigInt::from_u64(3);
    assert_eq!(three.powi(21).as_u64(), 10460353203);
}
/// Checks left shifts by sub-word and whole-word amounts.
#[test]
fn test_shl() {
    let mut value = BigInt::from_u64(0xff00ff);
    assert_eq!(value.get_part(0), 0xff00ff);

    value.shift_left(17);
    assert_eq!(value.get_part(0), 0x1fe01fe0000);

    value.shift_left(17);
    assert_eq!(value.get_part(0), 0x3fc03fc00000000);

    // A whole-word shift moves the value into the next word.
    value.shift_left(64);
    assert_eq!(value.get_part(1), 0x3fc03fc00000000);
}
/// Checks right shifts by sub-word and whole-word amounts.
#[test]
fn test_shr() {
    let mut value = BigInt::from_u64(0xff00ff);
    value.shift_left(128);
    assert_eq!(value.get_part(2), 0xff00ff);

    value.shift_right(17);
    assert_eq!(value.get_part(1), 0x807f800000000000);

    value.shift_right(17);
    assert_eq!(value.get_part(1), 0x03fc03fc0000000);

    // A whole-word shift moves the value into the previous word.
    value.shift_right(64);
    assert_eq!(value.get_part(0), 0x03fc03fc0000000);
}
/// Checks multi-word multiplication: (2^64 - 1)^2 * 25.
#[test]
fn test_mul_basic() {
    let mut product = BigInt::from_u64(0xffff_ffff_ffff_ffff);
    let factor = BigInt::from_u64(25);
    let squared_operand = product.clone();
    product.inplace_mul(&squared_operand);
    product.inplace_mul(&factor);
    assert_eq!(product.get_part(0), 0x19);
    assert_eq!(product.get_part(1), 0xffff_ffff_ffff_ffce);
    assert_eq!(product.get_part(2), 0x18);
}
/// Checks addition without and with a carry into a second word.
#[test]
fn test_add_basic() {
    let mut sum = BigInt::from_u64(0xffffffff00000000);
    let low_half = BigInt::from_u64(0xffffffff);
    let fifteen = BigInt::from_u64(0xf);

    sum.inplace_add(&low_half);
    assert_eq!(sum.get_part(0), 0xffffffffffffffff);

    // This addition overflows the first word.
    sum.inplace_add(&fifteen);
    assert_eq!(sum.get_part(0), 0xe);
    assert_eq!(sum.get_part(1), 0x1);
}
/// Checks quotient and remainder for small divisions by 7.
#[test]
fn test_div_basic() {
    let seven = BigInt::from_u64(7);

    let mut exact = BigInt::from_u64(49);
    let mut inexact = BigInt::from_u64(703);

    let rem = exact.inplace_div(&seven);
    assert_eq!(exact.as_u64(), 7);
    assert_eq!(rem.as_u64(), 0);

    let rem = inexact.inplace_div(&seven);
    assert_eq!(inexact.as_u64(), 100);
    assert_eq!(rem.as_u64(), 3);
}
/// Peels decimal digits off 19940521 by repeated division by ten.
#[test]
fn test_div_10() {
    let mut number = BigInt::from_u64(19940521);
    let ten = BigInt::from_u64(10);
    // The remainders are the decimal digits, lowest digit first.
    assert_eq!(number.inplace_div(&ten).as_u64(), 1);
    assert_eq!(number.inplace_div(&ten).as_u64(), 2);
    assert_eq!(number.inplace_div(&ten).as_u64(), 5);
    assert_eq!(number.inplace_div(&ten).as_u64(), 0);
    assert_eq!(number.inplace_div(&ten).as_u64(), 4);
}
/// Feeds the same 50000 pseudo-random pairs of 128-bit operands to `correct`
/// and `test` and asserts that both return the same value and the same
/// carry/overflow flag.
///
/// The operands are built from four consecutive 64-bit outputs of an `Lfsr`
/// generator.
#[allow(dead_code)]
fn test_with_random_values(
    correct: fn(u128, u128) -> (u128, bool),
    test: fn(u128, u128) -> (u128, bool),
) {
    use super::utils::Lfsr;
    let mut lfsr = Lfsr::new();
    for _ in 0..50000 {
        let v0 = lfsr.get64();
        let v1 = lfsr.get64();
        let v2 = lfsr.get64();
        let v3 = lfsr.get64();
        let n1 = (v0 as u128) + ((v1 as u128) << 64);
        let n2 = (v2 as u128) + ((v3 as u128) << 64);
        let expected = correct(n1, n2);
        let actual = test(n1, n2);
        assert_eq!(expected.0, actual.0, "Incorrect value");
        // Compare the flag itself; previously the value was compared twice
        // and the carry was never checked.
        assert_eq!(expected.1, actual.1, "Incorrect carry");
    }
}
/// Checks subtraction with a borrow across words and with operands of
/// different lengths.
#[test]
fn test_sub_basic() {
    // 2^64 - 1: the borrow propagates out of the second word.
    let mut minuend = BigInt::from_parts(&[0x0, 0x1, 0]);
    let subtrahend = BigInt::from_u64(0x1);
    let borrow = minuend.inplace_sub(&subtrahend);
    assert!(!borrow);
    assert_eq!(minuend.get_part(0), 0xffffffffffffffff);
    assert_eq!(minuend.get_part(1), 0);

    // The subtrahend is stored in more words than the minuend.
    let mut minuend = BigInt::from_parts(&[0x1, 0x1]);
    let subtrahend = BigInt::from_parts(&[0x0, 0x1, 0x0]);
    let borrow = minuend.inplace_sub(&subtrahend);
    assert!(!borrow);
    assert_eq!(minuend.get_part(0), 0x1);
    assert_eq!(minuend.get_part(1), 0);

    // Only the middle word changes.
    let mut minuend = BigInt::from_parts(&[0x1, 0x1, 0x1]);
    let subtrahend = BigInt::from_parts(&[0x0, 0x1, 0x0]);
    let borrow = minuend.inplace_sub(&subtrahend);
    assert!(!borrow);
    assert_eq!(minuend.get_part(0), 0x1);
    assert_eq!(minuend.get_part(1), 0);
    assert_eq!(minuend.get_part(2), 0x1);
}
/// Checks that masking to 69 bits keeps word 0, keeps the low 5 bits of
/// word 1 and clears word 2.
#[test]
fn test_mask_basic() {
    let mut value = BigInt::from_parts(&[0b11111, 0b10101010101010, 0b111]);
    value.mask(69);
    assert_eq!(value.get_part(0), 0b11111);
    assert_eq!(value.get_part(1), 0b01010);
    assert_eq!(value.get_part(2), 0b0);
}
/// Cross-checks `BigInt` subtraction, addition, multiplication, division and
/// comparison against the native `u128` operations on pseudo-random
/// operands (see `test_with_random_values`).
#[test]
fn test_basic_operations() {
    // Reference results computed with native u128 arithmetic.
    fn correct_sub(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_sub(b)
    }
    fn correct_add(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_add(b)
    }
    fn correct_mul(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_mul(b)
    }
    fn correct_div(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_div(b)
    }

    /// Reports whether `value` needs more than 128 bits and clears every
    /// word above the low two so `as_u128` yields the wrapped result.
    ///
    /// All words from index 2 up are cleared by iteration, so this is safe
    /// for any length (the previous code indexed words 2 and 3 directly and
    /// would panic on a three-word result).
    fn truncate_to_u128(value: &mut BigInt) -> bool {
        let overflowed = value.len() > 2;
        for word in value.parts.iter_mut().skip(2) {
            *word = 0;
        }
        overflowed
    }

    fn test_sub(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        let c = a.inplace_sub(&b);
        (a.as_u128(), c)
    }
    fn test_add(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_add(&b);
        let carry = truncate_to_u128(&mut a);
        (a.as_u128(), carry)
    }
    fn test_mul(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_mul(&b);
        let carry = truncate_to_u128(&mut a);
        (a.as_u128(), carry)
    }
    fn test_div(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_div(&b);
        (a.as_u128(), false)
    }

    /// Encodes an ordering as a number so comparisons fit the
    /// `(u128, bool)` shape used by `test_with_random_values`.
    fn ordering_code(ordering: Ordering) -> u128 {
        match ordering {
            Ordering::Less => 1,
            Ordering::Equal => 2,
            Ordering::Greater => 3,
        }
    }
    fn correct_cmp(a: u128, b: u128) -> (u128, bool) {
        (ordering_code(a.cmp(&b)), false)
    }
    fn test_cmp(a: u128, b: u128) -> (u128, bool) {
        let a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        (ordering_code(a.cmp(&b)), false)
    }

    test_with_random_values(correct_mul, test_mul);
    test_with_random_values(correct_div, test_div);
    test_with_random_values(correct_add, test_add);
    test_with_random_values(correct_sub, test_sub);
    test_with_random_values(correct_cmp, test_cmp);
}
/// Checks `msb_index` for zero, single-word values and shifted single bits.
#[test]
fn test_msb() {
    assert_eq!(BigInt::from_u64(0xffffffff00000000).msb_index(), 64);
    assert_eq!(BigInt::from_u64(0x0).msb_index(), 0);
    assert_eq!(BigInt::from_u64(0x1).msb_index(), 1);

    let mut shifted = BigInt::from_u64(0x1);
    shifted.shift_left(189);
    assert_eq!(shifted.msb_index(), 189 + 1);

    for shift in 0..256 {
        let mut single_bit = BigInt::from_u64(0x1);
        single_bit.shift_left(shift);
        assert_eq!(single_bit.msb_index(), shift + 1);
    }
}
/// Checks `trailing_zeros` for single-word values and shifted single bits.
#[test]
fn test_trailing_zero() {
    assert_eq!(BigInt::from_u64(0xffffffff00000000).trailing_zeros(), 32);
    assert_eq!(BigInt::from_u64(0x1).trailing_zeros(), 0);
    assert_eq!(BigInt::from_u64(0x8).trailing_zeros(), 3);

    let mut shifted = BigInt::from_u64(0x1);
    shifted.shift_left(189);
    assert_eq!(shifted.trailing_zeros(), 189);

    for shift in 0..256 {
        let mut single_bit = BigInt::from_u64(0x1);
        single_bit.shift_left(shift);
        assert_eq!(single_bit.trailing_zeros(), shift);
    }
}
impl Eq for BigInt {}
impl PartialEq for BigInt {
fn eq(&self, other: &BigInt) -> bool {
self.cmp(other).is_eq()
}
}
/// The ordering is total, so this simply wraps `Ord::cmp`.
impl PartialOrd for BigInt {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = self.cmp(other);
        Some(ordering)
    }
}
/// Compares two numbers by value; zero words at the top do not affect the
/// result.
impl Ord for BigInt {
    fn cmp(&self, other: &Self) -> Ordering {
        let shared = self.len().min(other.len());
        let (self_low, self_extra) = self.parts.split_at(shared);
        let (other_low, other_extra) = other.parts.split_at(shared);

        // A non-zero word beyond the other operand's length decides at once.
        if self_extra.iter().any(|&word| word != 0) {
            return Ordering::Greater;
        }
        if other_extra.iter().any(|&word| word != 0) {
            return Ordering::Less;
        }

        // Otherwise the highest differing word among the shared ones decides.
        self_low
            .iter()
            .rev()
            .zip(other_low.iter().rev())
            .map(|(lhs, rhs)| lhs.cmp(rhs))
            .find(|ordering| *ordering != Ordering::Equal)
            .unwrap_or(Ordering::Equal)
    }
}
// Implements a binary operator trait for `BigInt` in terms of one of the
// `inplace_*` methods.
//
// `$op` is the trait (e.g. `Add`), `$method` its method (e.g. `add`) and
// `$inplace` the in-place method that performs the work. Whatever the
// in-place method returns (borrow flag, remainder, ...) is discarded.
//
// Generated impls: `BigInt op BigInt`, `BigInt op &BigInt`,
// `&BigInt op &BigInt` and `BigInt op u64`.
macro_rules! declare_operator {
    ($op:ident, $method:ident, $inplace:ident) => {
        // BigInt op BigInt: forwards to the `&BigInt` right-hand-side impl.
        impl $op for BigInt {
            type Output = Self;
            fn $method(self, rhs: Self) -> Self::Output {
                self.$method(&rhs)
            }
        }

        // BigInt op &BigInt: reuses the left operand's storage.
        impl $op<&Self> for BigInt {
            type Output = Self;
            fn $method(self, rhs: &Self) -> Self::Output {
                let mut result = self;
                let _ = result.$inplace(rhs);
                result
            }
        }

        // &BigInt op &BigInt: works on a clone of the left operand.
        impl $op<Self> for &BigInt {
            type Output = BigInt;
            fn $method(self, rhs: Self) -> Self::Output {
                let mut result = self.clone();
                let _ = result.$inplace(rhs);
                result
            }
        }

        // BigInt op u64: the right operand is converted first.
        impl $op<u64> for BigInt {
            type Output = Self;
            fn $method(self, rhs: u64) -> Self::Output {
                let mut result = self;
                let _ = result.$inplace(&Self::from_u64(rhs));
                result
            }
        }
    };
}

declare_operator!(Add, add, inplace_add);
declare_operator!(Sub, sub, inplace_sub);
declare_operator!(Mul, mul, inplace_mul);
declare_operator!(Div, div, inplace_div);
// Implements a compound-assignment operator trait for `BigInt` in terms of
// one of the `inplace_*` methods.
//
// `$op` is the trait (e.g. `AddAssign`), `$method` its method (e.g.
// `add_assign`) and `$inplace` the in-place method that performs the work.
// Whatever the in-place method returns is discarded.
//
// Generated impls: `BigInt op= BigInt` and `BigInt op= &BigInt`.
macro_rules! declare_assign_operator {
    ($op:ident, $method:ident, $inplace:ident) => {
        impl $op for BigInt {
            fn $method(&mut self, rhs: Self) {
                let _ = self.$inplace(&rhs);
            }
        }

        impl $op<&BigInt> for BigInt {
            fn $method(&mut self, rhs: &Self) {
                let _ = self.$inplace(rhs);
            }
        }
    };
}

declare_assign_operator!(AddAssign, add_assign, inplace_add);
declare_assign_operator!(SubAssign, sub_assign, inplace_sub);
declare_assign_operator!(MulAssign, mul_assign, inplace_mul);
declare_assign_operator!(DivAssign, div_assign, inplace_div);
/// Checks the operator overloads on owned values, references and `u64`.
#[test]
fn test_bigint_operators() {
    type BI = BigInt;
    let ten = BI::from_u64(10);
    let one = BI::from_u64(1);
    // (10 - 1) * 10 / 2 == 45
    let result = ((&ten - &one) * ten) / 2;
    assert_eq!(result.as_u64(), 45);
    assert_eq!((&one + &one).as_u64(), 2);
}
/// Checks `all1s` for widths 0, 1, 5 and 32.
#[test]
fn test_all1s_ctor() {
    type BI = BigInt;
    let none = BI::all1s(0);
    let single = BI::all1s(1);
    let five = BI::all1s(5);
    let thirty_two = BI::all1s(32);
    assert_eq!(none.get_part(0), 0b0);
    assert_eq!(single.get_part(0), 0b1);
    assert_eq!(five.get_part(0), 0b11111);
    assert_eq!(thirty_two.get_part(0), 0xffffffff);
}
/// Checks `flip_bit` for toggling, for a bit inside the first word and for a
/// bit beyond it.
#[test]
fn test_flip_bit() {
    type BI = BigInt;

    // Flipping the same bit twice restores the value.
    {
        let mut value = BI::zero();
        assert_eq!(value.get_part(0), 0);
        value.flip_bit(0);
        assert_eq!(value.get_part(0), 1);
        value.flip_bit(0);
        assert_eq!(value.get_part(0), 0);
    }

    // Bit 16 is 65536.
    {
        let mut value = BI::zero();
        value.flip_bit(16);
        assert_eq!(value.get_part(0), 65536);
    }

    // A bit in the second word, shifted back down to bit 0.
    {
        let mut value = BI::zero();
        value.flip_bit(95);
        value.shift_right(95);
        assert_eq!(value.get_part(0), 1);
    }
}
/// Encodes a sequence of base-5 digits into one big number with
/// multiply-and-add, then decodes it again with repeated division and checks
/// that the digits come back in reverse order.
#[cfg(feature = "std")]
#[test]
fn test_mul_div_encode_decode() {
    use alloc::vec::Vec;
    const BASE: u64 = 5;
    type BI = BigInt;

    let base = BI::from_u64(BASE);
    let message: Vec<u64> = (0..275).map(|i| ((i + 6) * 17) % BASE).collect();

    // Encode: bitstream = bitstream * BASE + digit.
    let mut bitstream = BI::from_u64(0);
    for digit in message.iter() {
        bitstream.inplace_mul(&base);
        bitstream.inplace_add(&BI::from_u64(*digit));
    }

    // Decode: the remainders are the digits, last one first.
    for expected in message.iter().rev() {
        let rem = bitstream.inplace_div(&base);
        assert_eq!(*expected, rem.as_u64());
    }
}
// Conversion of a `BigInt` into a sequence of digits in base `DIGIT`.
impl BigInt {
/// Extracts `num_digits` base-`DIGIT` digits from `num` and appends them to
/// `output`, least significant digit first. `num` is divided down as the
/// digits are removed. Returns the number of digits appended.
///
/// Numbers longer than `SPLIT_WORD_THRESHOLD` words are split by dividing
/// by a power of `DIGIT`, and the remainder (low digits) and quotient (high
/// digits) are processed recursively.
fn to_digits_impl<const DIGIT: u8>(
num: &mut BigInt,
num_digits: usize,
output: &mut Vec<u8>,
) -> usize {
const SPLIT_WORD_THRESHOLD: usize = 5;
// Number of bits needed to write DIGIT itself; since DIGIT < 2^bits_per_digit,
// DIGIT^digits_per_word fits in a u64.
let bits_per_digit = (8 - DIGIT.leading_zeros()) as usize;
let digits_per_word = 64 / bits_per_digit;
let digit = DIGIT as u64;
let len = num.len();
if len > SPLIT_WORD_THRESHOLD {
// Split off the lowest `k` digits: after the division `rem` holds them
// and `num` holds the rest.
let half = len / 2 - 1;
let k = digits_per_word * half;
let mega_digit = BigInt::from_u64(digit).powi(k as u64);
let mut rem = num.inplace_div(&mega_digit);
// Low digits first, so the output stays least-significant-first.
let tail = Self::to_digits_impl::<DIGIT>(&mut rem, k, output);
let hd = Self::to_digits_impl::<DIGIT>(num, num_digits - k, output);
debug_assert_eq!(tail, k);
debug_assert_eq!(hd, num_digits - k);
return num_digits;
}
let mut extracted = 0;
// Peel off `digits_per_word` digits at a time with one big division, then
// split each chunk into single digits.
let divisor = BigInt::from_u64(digit.pow(digits_per_word as u32));
for _ in 0..(num_digits / digits_per_word) {
let mut rem = num.inplace_div(&divisor);
extracted += digits_per_word;
Self::extract_digits::<DIGIT>(digits_per_word, &mut rem, output);
}
// The digits that do not fill a whole chunk are taken from `num` directly.
let iters = num_digits % digits_per_word;
Self::extract_digits::<DIGIT>(iters, num, output);
extracted += iters;
extracted
}
/// Divides `num` by `DIGIT` exactly `iter` times, pushing each remainder
/// (one digit, least significant first) onto `vec`.
fn extract_digits<const DIGIT: u8>(
iter: usize,
num: &mut BigInt,
vec: &mut Vec<u8>,
) {
let digit = BigInt::from_u64(DIGIT as u64);
for _ in 0..iter {
let d = num.inplace_div(&digit).as_u64();
vec.push(d as u8);
}
}
/// Returns the digits of the number in base `DIGIT`, most significant digit
/// first and without leading zeros. A zero value yields an empty vector.
pub(crate) fn to_digits<const DIGIT: u8>(&self) -> Vec<u8> {
let mut num = self.clone();
num.shrink();
let mut output: Vec<u8> = Vec::new();
// Keep extracting until nothing is left; one pass may not be enough when
// the digit estimate below is too small for the chosen base.
while !num.is_zero() {
let len = num.len();
// Estimated digit count for `len` words. NOTE(review): 59/196 is about
// 0.30102, presumably an approximation of log10(2) - confirm.
let digits = (len * 64 * 59) / 196;
Self::to_digits_impl::<DIGIT>(&mut num, digits, &mut output);
}
// `output` is least-significant-first here, so zeros at the end are
// leading zeros of the number; keep at least one digit.
while output.len() > 1 && output[output.len() - 1] == 0 {
output.pop();
}
output.reverse();
output
}
}
/// Checks `to_digits` in base 2 and base 10 by rendering the digits as text.
#[test]
pub fn test_bigint_to_digits() {
    use alloc::string::String;
    use core::primitive::char;

    /// Renders a digit vector as a string in the given base.
    fn vec_to_string(vec: Vec<u8>, base: u32) -> String {
        vec.into_iter()
            .map(|d| char::from_digit(d as u32, base).unwrap())
            .collect()
    }

    let mut shifted = BigInt::from_u64(0b111000111000101010);
    shifted.shift_left(64);
    let binary = shifted.to_digits::<2>();
    assert_eq!(
        vec_to_string(binary, 2),
        "1110001110001010100000000000000\
         0000000000000000000000000000000\
         00000000000000000000"
    );

    let small = BigInt::from_u64(90210);
    let decimal = small.to_digits::<10>();
    assert_eq!(vec_to_string(decimal, 10), "90210");

    let large = BigInt::from_u128(123_456_123_456_987_654_987_654u128);
    let decimal = large.to_digits::<10>();
    assert_eq!(vec_to_string(decimal, 10), "123456123456987654987654");
}
/// Operand length, in 64-bit words, that selects between schoolbook and
/// Karatsuba multiplication: `inplace_mul` switches to Karatsuba when either
/// operand is longer than this, and `mul_karatsuba` falls back to schoolbook
/// when the shorter operand is below it.
const KARATSUBA_SIZE_THRESHOLD: usize = 64;
impl BigInt {
/// Multiplies two numbers given as word slices (lowest word first) using
/// Karatsuba's algorithm and returns the product.
fn mul_karatsuba(lhs: &[u64], rhs: &[u64]) -> BigInt {
// Base case: the shorter operand is small, use schoolbook multiplication.
if lhs.len().min(rhs.len()) < KARATSUBA_SIZE_THRESHOLD {
if lhs.is_empty() || rhs.is_empty() {
return BigInt::zero();
}
let mut lhs = BigInt::from_parts(lhs);
lhs.inplace_mul_slice(rhs);
return lhs;
}
// Split both operands at word `mid`:
//   lhs = a + b * 2^(64 * mid),  rhs = c + d * 2^(64 * mid)
let mid = lhs.len().max(rhs.len()) / 2;
let a = &lhs[0..mid.min(lhs.len())];
let b = &lhs[mid.min(lhs.len())..];
let c = &rhs[0..mid.min(rhs.len())];
let d = &rhs[mid.min(rhs.len())..];
let ac = Self::mul_karatsuba(a, c);
let mut bd = Self::mul_karatsuba(b, d);
let mut a_b = BigInt::from_parts(a);
a_b.inplace_add_slice(b);
let mut c_d = BigInt::from_parts(c);
c_d.inplace_add_slice(d);
// (a + b) * (c + d) - ac - bd == ad + bc, so the middle term costs only
// one more recursive multiplication.
let mut ad_plus_bc = Self::mul_karatsuba(&a_b, &c_d);
ad_plus_bc.inplace_sub_slice(&ac, 0);
ad_plus_bc.inplace_sub_slice(&bd, 0);
// Recombine: bd * 2^(128 * mid) + (ad + bc) * 2^(64 * mid) + ac.
bd.shift_left(64 * mid * 2);
ad_plus_bc.shift_left(64 * mid);
bd.inplace_add(&ad_plus_bc);
bd.inplace_add(&ac);
bd
}
}
/// Checks that Karatsuba multiplication agrees with schoolbook
/// multiplication for a range of operand sizes.
#[test]
fn test_mul_karatsuba() {
    use crate::utils::Lfsr;

    /// Multiplies two pseudo-random numbers of `l` and `r` words both ways
    /// and asserts that the products are equal.
    fn test_sizes(l: usize, r: usize, ll: &mut Lfsr) {
        let mut lhs = BigInt::from_iter(ll, l);
        let rhs = BigInt::from_iter(ll, r);
        let karatsuba = BigInt::mul_karatsuba(&lhs, &rhs);
        lhs.inplace_mul_slice(&rhs);
        assert_eq!(karatsuba, lhs);
    }

    let mut generator = Lfsr::new();
    test_sizes(1, 1, &mut generator);
    test_sizes(100, 1, &mut generator);
    test_sizes(1, 100, &mut generator);
    test_sizes(100, 100, &mut generator);
    test_sizes(1000, 1000, &mut generator);
    test_sizes(1000, 1001, &mut generator);
    for left_len in 64..90 {
        for right_len in 1..128 {
            test_sizes(left_len, right_len, &mut generator);
        }
    }
}
use core::ops::Deref;
/// Lets a `&BigInt` be used wherever a `&[u64]` word slice (lowest word
/// first) is expected.
impl Deref for BigInt {
    type Target = [u64];

    fn deref(&self) -> &Self::Target {
        self.parts.as_slice()
    }
}