use rand::distr::uniform::{self, SampleBorrow, SampleUniform, UniformSampler};
use rand::prelude::*;
use rand::Rng;
use crate::BigInt;
use crate::BigUint;
use crate::Sign::*;
use crate::big_digit::BigDigit;
use crate::bigint::{into_magnitude, magnitude};
use crate::integer::Integer;
#[cfg(feature = "prime")]
use num_iter::range_step;
use num_traits::Zero;
#[cfg(feature = "prime")]
use num_traits::{FromPrimitive, ToPrimitive};
#[cfg(feature = "prime")]
use once_cell::sync::Lazy;
#[cfg(feature = "prime")]
use crate::prime::probably_prime;
/// Extension methods for drawing random big integers from a random number generator.
///
/// A blanket implementation is provided in this module for every type that implements
/// [`Rng`], so these methods are available on any generator once the trait is in scope.
pub trait RandBigInt {
    /// Generates a random [`BigUint`] that fits in `bit_size` bits.
    fn gen_biguint(&mut self, bit_size: usize) -> BigUint;

    /// Generates a random [`BigInt`] whose magnitude fits in `bit_size` bits.
    fn gen_bigint(&mut self, bit_size: usize) -> BigInt;

    /// Generates a random [`BigUint`] strictly less than `bound`.
    ///
    /// # Panics
    ///
    /// The blanket implementation panics when `bound` is zero.
    fn gen_biguint_below(&mut self, bound: &BigUint) -> BigUint;

    /// Generates a random [`BigUint`] in the half-open range `[lbound, ubound)`.
    ///
    /// # Panics
    ///
    /// The blanket implementation panics when `lbound >= ubound`.
    fn gen_biguint_range(&mut self, lbound: &BigUint, ubound: &BigUint) -> BigUint;

    /// Generates a random [`BigInt`] in the half-open range `[lbound, ubound)`.
    ///
    /// # Panics
    ///
    /// The blanket implementation panics when `lbound >= ubound`.
    fn gen_bigint_range(&mut self, lbound: &BigInt, ubound: &BigInt) -> BigInt;
}
/// Blanket implementation of [`RandBigInt`] for every random number generator.
impl<R: Rng + ?Sized> RandBigInt for R {
    /// Fills enough digits to hold `bit_size` bits with random data and shifts the
    /// surplus bits out of the most significant digit.
    fn gen_biguint(&mut self, bit_size: usize) -> BigUint {
        use super::big_digit::BITS;

        // Number of completely filled digits, and how many bits spill into one more digit.
        let (full_digits, extra_bits) = bit_size.div_rem(&BITS);
        let has_partial = extra_bits > 0;

        let mut digits = smallvec![BigDigit::default(); full_digits + has_partial as usize];
        self.fill(digits.as_mut_slice());

        if has_partial {
            // Keep only `extra_bits` random bits in the top digit.
            digits[full_digits] >>= BITS - extra_bits;
        }
        BigUint::new_native(digits)
    }

    /// Draws a magnitude with [`gen_biguint`](RandBigInt::gen_biguint) and then a random sign.
    ///
    /// A zero magnitude is rejected (and the draw retried) half of the time; otherwise
    /// zero would be produced twice as often as any other value, once per sign.
    fn gen_bigint(&mut self, bit_size: usize) -> BigInt {
        loop {
            let mag = self.gen_biguint(bit_size);
            let coin: bool = self.random();
            let sign = match (mag.is_zero(), coin) {
                (true, true) => continue,
                (true, false) => NoSign,
                (false, true) => Plus,
                (false, false) => Minus,
            };
            return BigInt::from_biguint(sign, mag);
        }
    }

    /// Rejection-samples values of `bound.bits()` bits until one is below `bound`.
    ///
    /// # Panics
    ///
    /// Panics when `bound` is zero.
    fn gen_biguint_below(&mut self, bound: &BigUint) -> BigUint {
        assert!(!bound.is_zero());
        let bit_len = bound.bits();
        loop {
            let candidate = self.gen_biguint(bit_len);
            if candidate < *bound {
                return candidate;
            }
        }
    }

    /// Samples an offset below `ubound - lbound` and adds it to `lbound`.
    ///
    /// # Panics
    ///
    /// Panics when `lbound >= ubound`.
    fn gen_biguint_range(&mut self, lbound: &BigUint, ubound: &BigUint) -> BigUint {
        assert!(*lbound < *ubound);
        if lbound.is_zero() {
            // No offset to apply, so skip the subtraction and addition.
            return self.gen_biguint_below(ubound);
        }
        let span = ubound - lbound;
        lbound + self.gen_biguint_below(&span)
    }

    /// Samples an unsigned offset below the width of the range and adds it to `lbound`.
    ///
    /// # Panics
    ///
    /// Panics when `lbound >= ubound`.
    fn gen_bigint_range(&mut self, lbound: &BigInt, ubound: &BigInt) -> BigInt {
        assert!(*lbound < *ubound);
        if lbound.is_zero() {
            // Range is [0, ubound): the width is simply |ubound|.
            return BigInt::from(self.gen_biguint_below(magnitude(ubound)));
        }
        if ubound.is_zero() {
            // Range is [lbound, 0): the width is |lbound|.
            let offset = self.gen_biguint_below(magnitude(lbound));
            return lbound + BigInt::from(offset);
        }
        let delta = ubound - lbound;
        let offset = self.gen_biguint_below(magnitude(&delta));
        lbound + BigInt::from(offset)
    }
}
/// Back-end for uniformly sampling [`BigUint`] values from a range.
///
/// Samples are produced as `base + x`, where `x` is drawn uniformly from `[0, len)`.
#[derive(Clone, Debug)]
pub struct UniformBigUint {
    /// Inclusive lower bound of the range.
    base: BigUint,
    /// Number of values in the range.
    len: BigUint,
}
impl UniformSampler for UniformBigUint {
type X = BigUint;
#[inline]
fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = low_b.borrow();
let high = high_b.borrow();
assert!(low < high);
Ok(UniformBigUint {
len: high - low,
base: low.clone(),
})
}
#[inline]
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
Self::new(low_b, high_b.borrow() + 1u32)
}
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
&self.base + rng.gen_biguint_below(&self.len)
}
#[inline]
fn sample_single<R: Rng + ?Sized, B1, B2>(
low_b: B1,
high_b: B2,
rng: &mut R,
) -> Result<Self::X, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = low_b.borrow();
let high = high_b.borrow();
Ok(rng.gen_biguint_range(low, high))
}
}
/// Registers [`UniformBigUint`] as the uniform range sampler for [`BigUint`].
impl SampleUniform for BigUint {
    type Sampler = UniformBigUint;
}
/// Back-end for uniformly sampling [`BigInt`] values from a range.
///
/// Samples are produced as `base + x`, where `x` is an unsigned offset drawn uniformly
/// from `[0, len)`.
#[derive(Clone, Debug)]
pub struct UniformBigInt {
    /// Inclusive lower bound of the range.
    base: BigInt,
    /// Number of values in the range.
    len: BigUint,
}
impl UniformSampler for UniformBigInt {
type X = BigInt;
#[inline]
fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = low_b.borrow();
let high = high_b.borrow();
assert!(low < high);
Ok(UniformBigInt {
len: into_magnitude(high - low),
base: low.clone(),
})
}
#[inline]
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = low_b.borrow();
let high = high_b.borrow();
assert!(low <= high);
Self::new(low, high + 1u32)
}
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
&self.base + BigInt::from(rng.gen_biguint_below(&self.len))
}
#[inline]
fn sample_single<R: Rng + ?Sized, B1, B2>(
low_b: B1,
high_b: B2,
rng: &mut R,
) -> Result<Self::X, uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = low_b.borrow();
let high = high_b.borrow();
Ok(rng.gen_bigint_range(low, high))
}
}
/// Registers [`UniformBigInt`] as the uniform range sampler for [`BigInt`].
impl SampleUniform for BigInt {
    type Sampler = UniformBigInt;
}
/// A distribution that produces big integers of a fixed bit size.
///
/// The value stored here is the `bit_size` argument forwarded to the generator methods
/// when the distribution is sampled.
#[derive(Clone, Copy, Debug)]
pub struct RandomBits {
    /// Bit size requested for every sampled value.
    bits: usize,
}

impl RandomBits {
    /// Creates a distribution that samples values of `bits` bits.
    #[inline]
    pub fn new(bits: usize) -> Self {
        Self { bits }
    }
}
impl Distribution<BigUint> for RandomBits {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> BigUint {
rng.gen_biguint(self.bits)
}
}
impl Distribution<BigInt> for RandomBits {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> BigInt {
rng.gen_bigint(self.bits)
}
}
/// A generic trait for generating random primes.
///
/// A blanket implementation is provided in this module for every type that implements
/// [`Rng`]. The quality of the generated primes depends directly on the quality of the
/// random number generator that is used.
///
/// # Example
///
/// ```ignore
/// // `RandPrime` must be in scope for `gen_prime` to be callable on the generator.
/// let mut rng = rand::rng();
/// let p = rng.gen_prime(1024);
/// assert_eq!(p.bits(), 1024);
/// ```
#[cfg(feature = "prime")]
pub trait RandPrime {
    /// Generates a random prime number that is exactly `bits` bits long.
    ///
    /// # Panics
    ///
    /// The blanket implementation panics when `bits` is less than 2.
    fn gen_prime(&mut self, bits: usize) -> BigUint;
}
/// The first fifteen odd primes, used to sieve prime candidates cheaply before the
/// full primality test is run.
#[cfg(feature = "prime")]
const SMALL_PRIMES: [u8; 15] = [
    3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
];

/// The product of all entries of `SMALL_PRIMES`.
///
/// The value fits in a `u64`, so reducing a candidate modulo this product yields a
/// machine-word residue that can be tested against every small prime.
#[cfg(feature = "prime")]
static SMALL_PRIMES_PRODUCT: Lazy<BigUint> = Lazy::new(|| {
    let product: u64 = 16_294_579_238_595_022_365;
    BigUint::from_u64(product).unwrap()
});
/// Blanket implementation of [`RandPrime`] for every random number generator.
#[cfg(feature = "prime")]
impl<R: Rng + ?Sized> RandPrime for R {
    /// Generates a random prime of exactly `bit_size` bits.
    ///
    /// Random candidates with the top two bits and the lowest bit set are drawn, nudged
    /// forward by a small even offset past multiples of the small primes, and returned
    /// once one has the requested length and is accepted by `probably_prime`.
    ///
    /// # Panics
    ///
    /// Panics when `bit_size` is less than 2.
    fn gen_prime(&mut self, bit_size: usize) -> BigUint {
        assert!(bit_size >= 2, "prime size must be at least 2-bit");

        // Number of significant bits in the most significant byte (1..=8).
        let top_bits = match bit_size % 8 {
            0 => 8,
            n => n,
        };
        let byte_count = (bit_size + 7) / 8;
        let mut buf = alloc::vec![0u8; byte_count];

        loop {
            self.fill_bytes(&mut buf);

            // Clear the surplus high bits so the candidate has at most `bit_size` bits.
            buf[0] &= ((1u32 << (top_bits as u32)) - 1) as u8;

            // Force the two most significant bits on so the candidate is not too small.
            if top_bits >= 2 {
                buf[0] |= 3u8.wrapping_shl(top_bits as u32 - 2);
            } else {
                // Only one bit lives in the first byte; the second-highest bit is the
                // top bit of the following byte.
                buf[0] |= 1;
                if byte_count > 1 {
                    buf[1] |= 0x80;
                }
            }

            // Force the candidate to be odd.
            buf[byte_count - 1] |= 1u8;

            let mut candidate = BigUint::from_bytes_be(&buf);
            // The modulus fits in a u64, so the remainder does too.
            let residue = (&candidate % &*SMALL_PRIMES_PRODUCT).to_u64().unwrap();

            // First even offset for which `candidate + offset` has no small prime factor.
            // For tiny sizes (`bit_size <= 6`) the candidate may itself be one of the
            // small primes, which must not be rejected.
            let offset = range_step(0u64, 1 << 20, 2).find(|&delta| {
                let m = residue + delta;
                !SMALL_PRIMES.iter().any(|&small| {
                    let small = u64::from(small);
                    m % small == 0 && (bit_size > 6 || m != small)
                })
            });
            if let Some(delta) = offset {
                if delta > 0 {
                    candidate += BigUint::from_u64(delta).unwrap();
                }
            }

            // Adding the offset may have carried into an extra bit, so re-check the length
            // before running the primality test.
            if candidate.bits() == bit_size && probably_prime(&candidate, 20) {
                return candidate;
            }
        }
    }
}