use alloc::{vec, vec::Vec};
use core::marker::PhantomData;
use core::num::{NonZero, NonZeroU32};
use crypto_bigint::{Limb, Odd, RandomBits, RandomBitsError, Unsigned, Word};
use rand_core::CryptoRng;
use super::{
Primality,
precomputed::{LAST_SMALL_PRIME, RECIPROCALS, SMALL_PRIMES, SmallPrime},
};
use crate::{error::Error, presets::Flavor};
/// Selects which of the most significant bits [`random_odd_integer`] forces to `1`.
#[derive(Debug, Clone, Copy)]
pub enum SetBits {
/// Force the most significant bit (bit `bit_length - 1`) to `1`.
Msb,
/// Force the two most significant bits to `1`;
/// the second one is skipped when `bit_length` is 1.
TwoMsb,
/// Leave the high bits exactly as they were sampled.
None,
}
pub fn random_odd_integer<T, R>(rng: &mut R, bit_length: NonZeroU32, set_bits: SetBits) -> Result<Odd<T>, Error>
where
T: Unsigned + RandomBits,
R: CryptoRng + ?Sized,
{
let bit_length = bit_length.get();
let mut random = T::try_random_bits(rng, bit_length).map_err(|err| match err {
RandomBitsError::RandCore(_) => unreachable!("`rng` impls `CryptoRng` and therefore is infallible"),
RandomBitsError::BitsPrecisionMismatch { .. } => {
unreachable!("we are not requesting a specific `bits_precision`")
}
RandomBitsError::BitLengthTooLarge {
bit_length,
bits_precision,
} => Error::BitLengthTooLarge {
bit_length,
bits_precision,
},
})?;
random.set_bit_vartime(0, true);
match set_bits {
SetBits::None => {}
SetBits::Msb => random.set_bit_vartime(bit_length - 1, true),
SetBits::TwoMsb => {
random.set_bit_vartime(bit_length - 1, true);
if bit_length > 1 {
random.set_bit_vartime(bit_length - 2, true);
}
}
}
Ok(Odd::new(random).expect("the number is odd by construction"))
}
/// Returns `true` if `num` is numerically equal to the single-word value `primitive`.
///
/// Variable-time with respect to `num` (uses `bits_vartime`).
pub(crate) fn equals_primitive<T>(num: &T, primitive: Word) -> bool
where
    T: Unsigned,
{
    // Anything wider than one word cannot be equal to a single-word value.
    if num.bits_vartime() > Word::BITS {
        return false;
    }
    let lowest_word = num.as_limbs()[0].0;
    lowest_word == primitive
}
// Integer type in which the sieve evaluates `residue + increment`.
type Residue = u32;
// The largest increment for which `residue + increment` cannot overflow `Residue`,
// assuming every stored residue is below `LAST_SMALL_PRIME`:
// `(LAST_SMALL_PRIME - 1) + INCR_LIMIT == Residue::MAX`.
const INCR_LIMIT: Residue = Residue::MAX - LAST_SMALL_PRIME as Residue + 1;
/// An iterator over integers starting near a given value that survive trial division
/// by a list of small primes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmallFactorsSieve<T: Unsigned> {
// The value the entries of `residues` were last computed for.
// In safe-prime mode this walks over the halved values.
base: T,
// Current offset from `base`; the candidate under test is `base + incr`. Advanced in steps of 2.
incr: Residue,
// Largest `incr` usable before the residues must be recomputed.
// Zero means "not computed yet".
incr_limit: Residue,
// Whether the sieve was created in safe-prime mode (yielded values are `2 * x + 1`).
safe_primes: bool,
// `base` reduced modulo each small prime taken into account.
residues: Vec<SmallPrime>,
// Bit length bound for `base + incr` (already reduced by one in safe-prime mode).
max_bit_length: u32,
// Set when the requested range is empty; the iterator then yields nothing.
produces_nothing: bool,
// Set when the first item is the special value 2 (or 5 in safe-prime mode),
// which does not fit the "odd base plus even increment" pattern.
starts_from_exception: bool,
// Set once the current increment range reaches the end of the requested bit length;
// no further residue recomputation happens after it is exhausted.
last_round: bool,
}
impl<T> SmallFactorsSieve<T>
where
    T: Unsigned,
{
    /// Creates a sieve that starts at `start` and runs until the largest number of
    /// `max_bit_length` bits, skipping numbers divisible by the small primes selected
    /// by `trial_primes_num`.
    ///
    /// `start` is adjusted: values up to 2 make the sieve begin with the special item
    /// 2 (or 5 when `safe_primes` is set) followed by odd numbers from 3; any other
    /// value is rounded up to an odd number.
    ///
    /// With `safe_primes` set, the sieve internally walks over `n / 2` with a bit length
    /// reduced by one, and yields `2 * x + 1` for every surviving `x`; both `x` and
    /// `2 * x + 1` are checked against the small primes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::BitLengthTooLarge`] if `max_bit_length` exceeds the precision of `start`.
    pub fn new(start: T, max_bit_length: NonZeroU32, safe_primes: bool) -> Result<Self, Error> {
        let max_bit_length_nz = max_bit_length;
        let max_bit_length = max_bit_length.get();

        if max_bit_length > start.bits_precision() {
            return Err(Error::BitLengthTooLarge {
                bit_length: max_bit_length,
                bits_precision: start.bits_precision(),
            });
        }

        // For safe primes, iterate over the halved values instead, which reduces the
        // problem to the regular one with one bit less.
        let (max_bit_length, mut start) = if safe_primes {
            (max_bit_length - 1, start.wrapping_shr_vartime(1))
        } else {
            (max_bit_length, start)
        };

        // Empty ranges are flagged up front instead of being handled by every method.
        let produces_nothing = max_bit_length < start.bits_vartime() || max_bit_length < 2;

        // 2 is the only item that is not "odd base + even increment", so it is emitted
        // separately and the regular walk starts from 3.
        let mut starts_from_exception = false;
        if start <= T::from(2u32) {
            starts_from_exception = true;
            // Build 3 = (1 << 1) | 1 with `one_like` so that it keeps the precision of
            // `start`; `T::from` is not tied to the caller's precision for integer types
            // whose precision is a runtime property.
            start = T::one_like(&start).wrapping_shl_vartime(1) | T::one_like(&start);
        } else {
            // Make the start odd so that incrementing by 2 only visits odd numbers.
            let one = T::one_like(&start);
            start |= one;
        }

        // The number of primes is derived from the original (non-halved) bit length.
        let residues_len = trial_primes_num(&start, max_bit_length_nz);

        Ok(Self {
            base: start,
            incr: 0,
            // Zero forces `update_residues()` to compute residues on the first call.
            incr_limit: 0,
            safe_primes,
            residues: vec![0; residues_len],
            max_bit_length,
            produces_nothing,
            starts_from_exception,
            last_round: false,
        })
    }

    /// Makes sure the residues are valid for the current increment.
    ///
    /// Returns `true` if there is a candidate to test, `false` once the range is exhausted.
    fn update_residues(&mut self) -> bool {
        // Still inside the range covered by the current residues.
        if self.incr_limit != 0 && self.incr <= self.incr_limit {
            return true;
        }

        // The previous range already reached the end of the requested bit length.
        if self.last_round {
            return false;
        }

        // Move the base forward by the consumed increment and start over from it.
        self.base = self
            .base
            .checked_add(&self.incr.into())
            .expect("Does not overflow by construction");
        self.incr = 0;

        // Recompute the residues of the new base.
        for (residue, rec) in self.residues.iter_mut().zip(RECIPROCALS.iter()) {
            *residue = self.base.rem_limb_with_reciprocal(rec).0 as SmallPrime;
        }

        // `2^max_bit_length`, or 1 when that value does not fit into `T`; the wrapping
        // subtraction below then yields `2^precision + 1 - base`.
        let max_value = match T::one_like(&self.base).overflowing_shl_vartime(self.max_bit_length) {
            Some(val) => val,
            None => T::one_like(&self.base),
        };
        let incr_limit = max_value.wrapping_sub(&self.base);

        self.incr_limit = if incr_limit > T::from(INCR_LIMIT) {
            INCR_LIMIT
        } else {
            // Close to the end of the range: this is the last batch of increments.
            self.last_round = true;
            // `incr_limit <= INCR_LIMIT` was just established, so the low limb holds it all.
            incr_limit.as_limbs()[0]
                .0
                .try_into()
                .expect("the increment limit should fit within `Residue`")
        };

        true
    }

    /// Returns `true` if `base + incr` (or, in safe-prime mode, `2 * (base + incr) + 1`)
    /// is divisible by one of the tracked small primes.
    fn current_is_composite(&self) -> bool {
        self.residues.iter().zip(SMALL_PRIMES.iter()).any(|(residue, prime)| {
            let d = *prime as Residue;
            let r = (*residue as Residue + self.incr) % d;

            // If `x mod d == (d - 1) / 2` then `(2 * x + 1) mod d == 0`, so the safe-prime
            // candidate is checked with the same residue.
            r == 0 || (self.safe_primes && r == (d - 1) >> 1)
        })
    }

    /// Tests the current candidate and advances the increment by 2.
    ///
    /// Returns the candidate (mapped to `2 * x + 1` in safe-prime mode) if it survived
    /// the sieve and `base + incr` did not overflow, `None` otherwise.
    fn maybe_next(&mut self) -> Option<T> {
        let result = if self.current_is_composite() {
            None
        } else {
            self.base.checked_add(&self.incr.into()).into_option().map(|num| {
                if self.safe_primes {
                    num.wrapping_shl_vartime(1) | T::one_like(&self.base)
                } else {
                    num
                }
            })
        };

        self.incr += 2;
        result
    }

    /// Produces the next number that passed the sieve, or `None` when the range is exhausted.
    fn next(&mut self) -> Option<T> {
        if self.produces_nothing {
            return None;
        }

        if self.starts_from_exception {
            self.starts_from_exception = false;
            // Built from `one_like` so that the special item has the same precision as
            // every other item produced from `base`.
            let one = T::one_like(&self.base);
            let exception = if self.safe_primes {
                // 5 = (1 << 2) | 1
                one.wrapping_shl_vartime(2) | T::one_like(&self.base)
            } else {
                // 2 = 1 << 1
                one.wrapping_shl_vartime(1)
            };
            return Some(exception);
        }

        while self.update_residues() {
            if let Some(candidate) = self.maybe_next() {
                return Some(candidate);
            }
        }
        None
    }
}
/// Iterator adapter: every item is taken from the sieve's own `next` method.
impl<T> Iterator for SmallFactorsSieve<T>
where
T: Unsigned,
{
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
// Delegates to the `next` defined in `impl SmallFactorsSieve`; the type's own
// associated function is picked over this trait method, so this is not a self-call.
Self::next(self)
}
}
/// A source of sieves, i.e. of iterators over candidate numbers.
pub trait SieveFactory {
/// The type of the candidates the sieves yield.
type Item;
/// The sieve type this factory creates.
type Sieve: Iterator<Item = Self::Item>;
/// Creates a new sieve using `rng`.
///
/// `previous_sieve` is presumably the sieve returned by the preceding call (`None` on the
/// first one), and an `Ok(None)` result presumably signals that no more sieves are
/// available — TODO confirm against the callers; neither is visible in this file.
fn make_sieve<R>(
&mut self,
rng: &mut R,
previous_sieve: Option<&Self::Sieve>,
) -> Result<Option<Self::Sieve>, Error>
where
R: CryptoRng + ?Sized;
}
/// A [`SieveFactory`] that creates [`SmallFactorsSieve`]s starting from random odd integers.
#[derive(Debug, Clone, Copy)]
pub struct SmallFactorsSieveFactory<T> {
// Bit length used both for the random start and as the sieve's upper bound.
max_bit_length: NonZeroU32,
// Whether the created sieves run in safe-prime mode.
safe_primes: bool,
// Which high bits of the random start are forced to 1.
set_bits: SetBits,
// Ties the factory to the integer type `T` without storing a value of it.
phantom: PhantomData<T>,
}
impl<T> SmallFactorsSieveFactory<T>
where
    T: Unsigned + RandomBits,
{
    /// Creates a factory producing sieves for numbers of at most `max_bit_length` bits.
    ///
    /// `set_bits` is forwarded to [`random_odd_integer`] when a sieve start is sampled.
    ///
    /// # Errors
    ///
    /// Returns [`Error::BitLengthTooSmall`] if `max_bit_length` is below 2 for
    /// [`Flavor::Any`], or below 3 for [`Flavor::Safe`].
    pub fn new(flavor: Flavor, max_bit_length: u32, set_bits: SetBits) -> Result<Self, Error> {
        // Smallest admissible bit length and the sieve mode for each flavor.
        let (min_bit_length, safe_primes) = match flavor {
            Flavor::Any => (2, false),
            Flavor::Safe => (3, true),
        };

        if max_bit_length < min_bit_length {
            return Err(Error::BitLengthTooSmall {
                bit_length: max_bit_length,
                flavor,
            });
        }

        // The check above guarantees a value of at least 2.
        let max_bit_length = NonZero::new(max_bit_length).expect("`bit_length` should be non-zero");

        Ok(Self {
            max_bit_length,
            safe_primes,
            set_bits,
            phantom: PhantomData,
        })
    }
}
impl<T> SieveFactory for SmallFactorsSieveFactory<T>
where
    T: Unsigned + RandomBits,
{
    type Item = T;
    type Sieve = SmallFactorsSieve<T>;

    /// Samples a fresh random odd start and wraps it into a new [`SmallFactorsSieve`].
    ///
    /// The previous sieve is ignored, and the result is always `Ok(Some(..))` unless
    /// sampling the start or constructing the sieve fails.
    fn make_sieve<R>(
        &mut self,
        rng: &mut R,
        _previous_sieve: Option<&Self::Sieve>,
    ) -> Result<Option<Self::Sieve>, Error>
    where
        R: CryptoRng + ?Sized,
    {
        let odd_start = random_odd_integer::<T, _>(rng, self.max_bit_length, self.set_bits)?;
        let sieve = SmallFactorsSieve::new(odd_start.get(), self.max_bit_length, self.safe_primes)?;
        Ok(Some(sieve))
    }
}
/// Returns how many entries from the front of `SMALL_PRIMES` should be used for trial division.
///
/// A prime is used if it does not exceed `2^ceil(max_bit_length / 2)` and is strictly
/// smaller than `start`. Either bound is treated as unlimited when it needs more bits
/// than the largest small prime has.
fn trial_primes_num<T>(start: &T, max_bit_length: NonZeroU32) -> usize
where
    T: Unsigned,
{
    let half_bits = max_bit_length.get().div_ceil(2);
    let start_bits = start.bits_vartime();
    // Number of significant bits in the largest small prime.
    let table_bits = SmallPrime::BITS - LAST_SMALL_PRIME.leading_zeros();

    let end_in_range = half_bits <= table_bits;
    let start_in_range = start_bits <= table_bits;

    // Neither bound can exclude anything: use the whole table.
    if !end_in_range && !start_in_range {
        return SMALL_PRIMES.len();
    }

    let end_limit: SmallPrime = if end_in_range {
        1 << half_bits
    } else {
        SmallPrime::MAX
    };

    let start_limit: SmallPrime = if start_in_range {
        let low_word = start.as_limbs()[0].0;
        low_word.try_into().expect("The number is in range")
    } else {
        SmallPrime::MAX
    };

    SMALL_PRIMES.partition_point(|prime| *prime <= end_limit && *prime < start_limit)
}
/// The outcome of [`conventions_test`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ConventionsTestResult<T> {
/// The candidate is 2.
Prime,
/// The candidate is 1 or an even number other than 2 (including 0).
Composite,
/// The candidate is an odd number other than 1; further tests are needed.
Undecided { odd_candidate: Odd<T> },
}
/// Settles the trivial cases of a primality check.
///
/// 1 and all even numbers other than 2 are reported as composite, 2 as prime, and any
/// other (odd) number is returned as [`ConventionsTestResult::Undecided`].
pub(crate) fn conventions_test<T>(candidate: T) -> ConventionsTestResult<T>
where
    T: Unsigned,
{
    if equals_primitive(&candidate, 1) {
        return ConventionsTestResult::Composite;
    }

    if equals_primitive(&candidate, 2) {
        return ConventionsTestResult::Prime;
    }

    // Whatever is left is composite if even, and undecided if odd.
    let maybe_odd: Option<Odd<T>> = Odd::new(candidate).into();
    match maybe_odd {
        Some(odd_candidate) => ConventionsTestResult::Undecided { odd_candidate },
        None => ConventionsTestResult::Composite,
    }
}
/// Trial-divides `candidate` by the small primes selected by `trial_primes_num`.
///
/// Returns [`Primality::Composite`] if one of them divides the candidate,
/// and [`Primality::ProbablyPrime`] otherwise.
pub(crate) fn small_factors_test<T>(candidate: &Odd<T>) -> Primality
where
    T: Unsigned,
{
    let bit_length = NonZeroU32::new(candidate.bits_vartime()).expect("an odd integer is non-zero");
    let primes_to_try = trial_primes_num(candidate.as_ref(), bit_length);

    let has_small_factor = RECIPROCALS
        .iter()
        .take(primes_to_try)
        .any(|rec| candidate.rem_limb_with_reciprocal(rec) == Limb::ZERO);

    if has_small_factor {
        Primality::Composite
    } else {
        Primality::ProbablyPrime
    }
}
// Unit tests for the random start generation, the small-factors sieve and its helpers.
#[cfg(test)]
mod tests {
use alloc::format;
use alloc::vec::Vec;
use core::num::NonZero;
use crypto_bigint::{Odd, U64, U256};
use num_prime::nt_funcs::factorize64;
use rand::rngs::ChaCha8Rng;
use rand_core::SeedableRng;
use super::{
ConventionsTestResult, SetBits, SmallFactorsSieve, SmallFactorsSieveFactory, conventions_test,
random_odd_integer, small_factors_test, trial_primes_num,
};
use crate::{
Error, Flavor,
hazmat::{
Primality,
precomputed::{LAST_SMALL_PRIME, SMALL_PRIMES},
},
};
#[test]
fn trial_primes_num_corner_cases() {
// Both the start and the square-root bound exceed the table: every prime is used.
let len = trial_primes_num(&U64::from(0x123456789abcdef0u64), 64.try_into().unwrap());
assert_eq!(len, SMALL_PRIMES.len());
// The square-root bound (`2^7` for 14 bits) is the limiting factor.
let len = trial_primes_num(&U64::from(1u64 << 13), 14.try_into().unwrap());
assert_eq!(len, SMALL_PRIMES.partition_point(|x| *x < (1 << 7)));
// The start is the limiting factor: only primes strictly below it are used.
let len = trial_primes_num(&U64::from(LAST_SMALL_PRIME as u64 - 1), 64.try_into().unwrap());
assert_eq!(len, SMALL_PRIMES.len() - 1);
}
#[test]
fn conventions() {
// 0 and 1 are composite by convention, 2 is prime, 3 needs further testing.
assert_eq!(conventions_test(U64::ZERO), ConventionsTestResult::Composite);
assert_eq!(conventions_test(U64::ONE), ConventionsTestResult::Composite);
assert_eq!(conventions_test(U64::from(2u64)), ConventionsTestResult::Prime);
assert_eq!(
conventions_test(U64::from(3u64)),
ConventionsTestResult::Undecided {
odd_candidate: Odd::new(U64::from(3u64)).unwrap()
}
);
}
#[test]
fn small_factors() {
assert_eq!(
small_factors_test(&Odd::new(U64::from(5u64)).unwrap()),
Primality::ProbablyPrime
);
assert_eq!(
small_factors_test(&Odd::new(U64::from(9u64)).unwrap()),
Primality::Composite
);
}
// Every sieved number must keep the requested bit length and have no factor
// from the small primes table.
#[test]
fn random() {
let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1];
let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
let start = random_odd_integer::<U64, _>(&mut rng, NonZero::new(32).unwrap(), SetBits::Msb)
.unwrap()
.get();
for num in SmallFactorsSieve::new(start, NonZero::new(32).unwrap(), false)
.unwrap()
.take(100)
{
let num_u64 = u64::from(num);
// Exactly 32 significant bits in a 64-bit word.
assert!(num_u64.leading_zeros() == 32);
let factors_and_powers = factorize64(num_u64);
// Presumably the keys come out in ascending order, making `factors[0]` the
// smallest prime factor — TODO confirm against `factorize64`'s return type.
let factors = factors_and_powers.into_keys().collect::<Vec<_>>();
assert!(factors[0] > max_prime as u64);
}
}
// Same as `random`, but for the heap-allocated integer type.
#[test]
fn random_boxed() {
let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1];
let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
let start =
random_odd_integer::<crypto_bigint::BoxedUint, _>(&mut rng, NonZero::new(32).unwrap(), SetBits::Msb)
.unwrap()
.get();
for num in SmallFactorsSieve::new(start, NonZero::new(32).unwrap(), false)
.unwrap()
.take(100)
{
#[allow(clippy::useless_conversion)]
let num_u64: u64 = num.as_words()[0].into();
assert!(num_u64.leading_zeros() == 32);
let factors_and_powers = factorize64(num_u64);
let factors = factors_and_powers.into_keys().collect::<Vec<_>>();
assert!(factors[0] > max_prime as u64);
}
}
// Runs a sieve to exhaustion and compares the produced sequence with `reference`.
fn check_sieve(start: u32, bit_length: u32, safe_prime: bool, reference: &[u32]) {
let test = SmallFactorsSieve::new(U64::from(start), NonZero::new(bit_length).unwrap(), safe_prime)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(test.len(), reference.len());
for (x, y) in test.iter().zip(reference.iter()) {
assert_eq!(x, &U64::from(*y));
}
}
#[test]
fn empty_sequence() {
// In order: a 1-bit range is below the 2-bit minimum; a 2-bit safe-prime range is
// reduced to 1 bit; 64 halves to 32, which needs 6 bits while only 5 are allowed.
check_sieve(1, 1, false, &[]); check_sieve(1, 2, true, &[]); check_sieve(64, 6, true, &[]); }
#[test]
fn small_range() {
check_sieve(1, 2, false, &[2, 3]);
check_sieve(2, 2, false, &[2, 3]);
check_sieve(3, 2, false, &[3]);
check_sieve(1, 3, false, &[2, 3, 5, 7]);
check_sieve(3, 3, false, &[3, 5, 7]);
check_sieve(5, 3, false, &[5, 7]);
check_sieve(7, 3, false, &[7]);
// With the walk starting at 3 only primes strictly below 3 would be used, and the
// expected output shows that nothing is sieved out (9 and 15 are included).
check_sieve(1, 4, false, &[2, 3, 5, 7, 9, 11, 13, 15]);
check_sieve(3, 4, false, &[3, 5, 7, 9, 11, 13, 15]);
check_sieve(5, 4, false, &[5, 7, 11, 13]);
check_sieve(7, 4, false, &[7, 11, 13]);
check_sieve(9, 4, false, &[11, 13]);
check_sieve(13, 4, false, &[13]);
check_sieve(15, 4, false, &[]);
check_sieve(1, 3, true, &[5, 7]);
check_sieve(3, 3, true, &[5, 7]);
check_sieve(5, 3, true, &[5, 7]);
check_sieve(7, 3, true, &[7]);
// In the next three cases the halved start is adjusted to 3, so again no divisor is
// tested and the composite 15 is part of the output.
check_sieve(1, 4, true, &[5, 7, 11, 15]);
check_sieve(5, 4, true, &[5, 7, 11, 15]);
check_sieve(7, 4, true, &[7, 11, 15]);
check_sieve(9, 4, true, &[11]);
check_sieve(13, 4, true, &[]);
}
#[test]
fn sieve_too_many_bits() {
// The requested bit length exceeds the 64-bit precision of the start.
assert_eq!(
SmallFactorsSieve::new(U64::ONE, NonZero::new(65).unwrap(), false).unwrap_err(),
Error::BitLengthTooLarge {
bit_length: 65,
bits_precision: 64
}
);
}
#[test]
fn random_below_max_length() {
let mut rng = rand::rng();
for _ in 0..10 {
let r = random_odd_integer::<U64, _>(&mut rng, NonZero::new(50).unwrap(), SetBits::Msb)
.unwrap()
.get();
// With the MSB forced, the number has exactly the requested bit length.
assert_eq!(r.bits(), 50);
}
}
#[test]
fn random_odd_uint_too_many_bits() {
let mut rng = rand::rng();
assert!(random_odd_integer::<U64, _>(&mut rng, NonZero::new(65).unwrap(), SetBits::Msb).is_err());
}
#[test]
fn sieve_derived_traits() {
let s = SmallFactorsSieve::new(U64::ONE, NonZero::new(10).unwrap(), false).unwrap();
assert!(format!("{s:?}").starts_with("SmallFactorsSieve"));
assert_eq!(s.clone(), s);
// Sieves built from the same parameters compare equal, different bit lengths do not.
let s2 = SmallFactorsSieve::new(U64::ONE, NonZero::new(10).unwrap(), false).unwrap();
assert_eq!(s, s2);
let s3 = SmallFactorsSieve::new(U64::ONE, NonZero::new(12).unwrap(), false).unwrap();
assert_ne!(s, s3);
}
#[test]
fn sieve_with_max_start() {
// Starting from the largest representable value must end without yielding anything
// (and without panicking on overflow).
let start = U64::MAX;
let mut sieve = SmallFactorsSieve::new(start, NonZero::new(U64::BITS).unwrap(), false).unwrap();
assert!(sieve.next().is_none());
}
#[test]
fn too_few_bits_regular_primes() {
assert_eq!(
SmallFactorsSieveFactory::<U64>::new(Flavor::Any, 1, SetBits::Msb).unwrap_err(),
Error::BitLengthTooSmall {
bit_length: 1,
flavor: Flavor::Any
}
);
}
#[test]
fn too_few_bits_safe_primes() {
assert_eq!(
SmallFactorsSieveFactory::<U64>::new(Flavor::Safe, 2, SetBits::Msb).unwrap_err(),
Error::BitLengthTooSmall {
bit_length: 2,
flavor: Flavor::Safe
}
);
}
#[test]
fn set_bits() {
let mut rng = rand::rng();
for _ in 0..10 {
let x = random_odd_integer::<U64, _>(&mut rng, NonZero::new(64).unwrap(), SetBits::Msb).unwrap();
assert!(bool::from(x.bit(63)));
}
for _ in 0..10 {
let x = random_odd_integer::<U64, _>(&mut rng, NonZero::new(64).unwrap(), SetBits::TwoMsb).unwrap();
assert!(bool::from(x.bit(63)));
assert!(bool::from(x.bit(62)));
}
// Without forced bits, at least one of 30 samples is expected to have a clear top bit.
assert!(
(0..30)
.map(|_| { random_odd_integer::<U64, _>(&mut rng, NonZero::new(64).unwrap(), SetBits::None).unwrap() })
.any(|x| !bool::from(x.bit(63)))
);
}
#[test]
fn set_two_msb_small_bit_length() {
// With a single bit there is no second-highest bit to set; the result is just 1.
let mut rng = rand::rng();
let x = random_odd_integer::<U64, _>(&mut rng, NonZero::new(1).unwrap(), SetBits::TwoMsb)
.unwrap()
.get();
assert_eq!(x, U64::ONE);
}
#[test]
fn platform_independence() {
// A fixed seed must produce these exact values; judging by the test name, this is
// meant to hold on every target.
let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
let x = random_odd_integer::<U256, _>(&mut rng, NonZero::new(200).unwrap(), SetBits::TwoMsb)
.unwrap()
.get();
assert_eq!(
x,
U256::from_be_hex("00000000000000E94A74F9D90C0982D7D4F5378BDA8143E6391EBC3F59CBD0E5")
);
let x = random_odd_integer::<U256, _>(&mut rng, NonZero::new(220).unwrap(), SetBits::TwoMsb)
.unwrap()
.get();
assert_eq!(
x,
U256::from_be_hex("000000000E28CE6059E357411C67F6539AEF56F2B4653F0583D6A2195A9897BB")
);
}
}