1use alloc::{vec, vec::Vec};
5use core::marker::PhantomData;
6use core::num::{NonZero, NonZeroU32};
7
8use crypto_bigint::{Integer, Odd, RandomBits, RandomBitsError};
9use rand_core::CryptoRngCore;
10
11use crate::hazmat::precomputed::{SmallPrime, LAST_SMALL_PRIME, RECIPROCALS, SMALL_PRIMES};
12use crate::traits::SieveFactory;
13
/// Selects which high bits of a random candidate are forced to `1`.
///
/// Used by [`random_odd_integer`] after the random bits have been drawn.
#[derive(Clone, Copy, Debug)]
pub enum SetBits {
    /// Force the most significant bit (bit `bit_length - 1`), so every candidate has exactly
    /// `bit_length` bits.
    Msb,
    /// Force the two most significant bits (bits `bit_length - 1` and `bit_length - 2`).
    ///
    /// When `bit_length` is 1 there is no second bit, and only the single bit is set.
    TwoMsb,
    /// Force no additional high bits.
    None,
}
29
30pub fn random_odd_integer<T: Integer + RandomBits>(
37 rng: &mut (impl CryptoRngCore + ?Sized),
38 bit_length: NonZeroU32,
39 set_bits: SetBits,
40) -> Result<Odd<T>, RandomBitsError> {
41 let bit_length = bit_length.get();
42
43 let mut random = T::try_random_bits(rng, bit_length)?;
44
45 random.set_bit_vartime(0, true);
48
49 match set_bits {
53 SetBits::None => {}
54 SetBits::Msb => random.set_bit_vartime(bit_length - 1, true),
55 SetBits::TwoMsb => {
56 random.set_bit_vartime(bit_length - 1, true);
57 if bit_length > 1 {
60 random.set_bit_vartime(bit_length - 2, true);
61 }
62 }
63 }
64
65 Ok(Odd::new(random).expect("the number is odd by construction"))
66}
67
68type Residue = u32;
71
72const INCR_LIMIT: Residue = Residue::MAX - LAST_SMALL_PRIME as Residue + 1;
75
76#[derive(Clone, Debug, PartialEq, Eq)]
79pub struct SmallPrimesSieve<T: Integer> {
80 base: T,
84 incr: Residue,
85 incr_limit: Residue,
86 safe_primes: bool,
87 residues: Vec<SmallPrime>,
88 max_bit_length: u32,
89 produces_nothing: bool,
90 starts_from_exception: bool,
91 last_round: bool,
92}
93
94impl<T: Integer> SmallPrimesSieve<T> {
95 pub fn new(start: T, max_bit_length: NonZeroU32, safe_primes: bool) -> Self {
106 let max_bit_length = max_bit_length.get();
107
108 if max_bit_length > start.bits_precision() {
109 panic!(
110 "The requested bit length ({}) is larger than the precision of `start`",
111 max_bit_length
112 );
113 }
114
115 let (max_bit_length, mut start) = if safe_primes {
118 (max_bit_length - 1, start.wrapping_shr_vartime(1))
119 } else {
120 (max_bit_length, start)
121 };
122
123 let produces_nothing = max_bit_length < start.bits_vartime() || max_bit_length < 2;
125
126 let mut starts_from_exception = false;
129 if start <= T::from(2u32) {
130 starts_from_exception = true;
131 start = T::from(3u32);
132 } else {
133 start |= T::one();
135 }
136
137 let residues_len = if T::from(LAST_SMALL_PRIME) <= start {
141 SMALL_PRIMES.len()
142 } else {
143 let start_small = start.as_ref()[0].0 as SmallPrime;
146 SMALL_PRIMES.partition_point(|x| *x < start_small)
147 };
148
149 Self {
150 base: start,
151 incr: 0, incr_limit: 0,
153 safe_primes,
154 residues: vec![0; residues_len],
155 max_bit_length,
156 produces_nothing,
157 starts_from_exception,
158 last_round: false,
159 }
160 }
161
162 fn update_residues(&mut self) -> bool {
163 if self.incr_limit != 0 && self.incr <= self.incr_limit {
164 return true;
165 }
166
167 if self.last_round {
168 return false;
169 }
170
171 self.base = self
176 .base
177 .checked_add(&self.incr.into())
178 .expect("Does not overflow by construction");
179
180 self.incr = 0;
181
182 for (i, rec) in RECIPROCALS.iter().enumerate().take(self.residues.len()) {
184 let rem = self.base.rem_limb_with_reciprocal(rec);
185 self.residues[i] = rem.0 as SmallPrime;
186 }
187
188 let max_value = match T::one_like(&self.base)
190 .overflowing_shl_vartime(self.max_bit_length)
191 .into()
192 {
193 Some(val) => val,
194 None => T::one_like(&self.base),
195 };
196 let incr_limit = max_value.wrapping_sub(&self.base);
197 self.incr_limit = if incr_limit > T::from(INCR_LIMIT) {
198 INCR_LIMIT
199 } else {
200 self.last_round = true;
203 let incr_limit_small: Residue = incr_limit.as_ref()[0]
206 .0
207 .try_into()
208 .expect("the increment limit should fit within `Residue`");
209 incr_limit_small
210 };
211
212 true
213 }
214
215 fn current_is_composite(&self) -> bool {
217 self.residues.iter().enumerate().any(|(i, m)| {
218 let d = SMALL_PRIMES[i] as Residue;
219 let r = (*m as Residue + self.incr) % d;
220
221 r == 0 || (self.safe_primes && r == (d - 1) >> 1)
228 })
229 }
230
231 fn maybe_next(&mut self) -> Option<T> {
234 let result = if self.current_is_composite() {
235 None
236 } else {
237 match self.base.checked_add(&self.incr.into()).into_option() {
238 Some(mut num) => {
239 if self.safe_primes {
240 num = num.wrapping_shl_vartime(1) | T::one_like(&self.base);
242 }
243 Some(num)
244 }
245 None => None,
246 }
247 };
248
249 self.incr += 2;
250 result
251 }
252
253 fn next(&mut self) -> Option<T> {
254 if self.produces_nothing {
256 return None;
257 }
258
259 if self.starts_from_exception {
260 self.starts_from_exception = false;
261 return Some(T::from(if self.safe_primes { 5u32 } else { 2u32 }));
262 }
263
264 while self.update_residues() {
267 match self.maybe_next() {
268 Some(x) => return Some(x),
269 None => continue,
270 };
271 }
272 None
273 }
274}
275
276impl<T: Integer> Iterator for SmallPrimesSieve<T> {
277 type Item = T;
278
279 fn next(&mut self) -> Option<Self::Item> {
280 Self::next(self)
281 }
282}
283
284#[derive(Debug, Clone, Copy)]
286pub struct SmallPrimesSieveFactory<T> {
287 max_bit_length: NonZeroU32,
288 safe_primes: bool,
289 set_bits: SetBits,
290 phantom: PhantomData<T>,
291}
292
293impl<T: Integer + RandomBits> SmallPrimesSieveFactory<T> {
294 fn new_impl(max_bit_length: u32, set_bits: SetBits, safe_primes: bool) -> Self {
295 if !safe_primes && max_bit_length < 2 {
296 panic!("`bit_length` must be 2 or greater.");
297 }
298 if safe_primes && max_bit_length < 3 {
299 panic!("`bit_length` must be 3 or greater.");
300 }
301 let max_bit_length = NonZero::new(max_bit_length).expect("`bit_length` should be non-zero");
302 Self {
303 max_bit_length,
304 safe_primes,
305 set_bits,
306 phantom: PhantomData,
307 }
308 }
309
310 pub fn new(max_bit_length: u32, set_bits: SetBits) -> Self {
313 Self::new_impl(max_bit_length, set_bits, false)
314 }
315
316 pub fn new_safe_primes(max_bit_length: u32, set_bits: SetBits) -> Self {
319 Self::new_impl(max_bit_length, set_bits, true)
320 }
321}
322
323impl<T: Integer + RandomBits> SieveFactory for SmallPrimesSieveFactory<T> {
324 type Item = T;
325 type Sieve = SmallPrimesSieve<T>;
326 fn make_sieve(
327 &mut self,
328 rng: &mut (impl CryptoRngCore + ?Sized),
329 _previous_sieve: Option<&Self::Sieve>,
330 ) -> Option<Self::Sieve> {
331 let start =
332 random_odd_integer::<T>(rng, self.max_bit_length, self.set_bits).expect("random_odd_integer() failed");
333 Some(SmallPrimesSieve::new(
334 start.get(),
335 self.max_bit_length,
336 self.safe_primes,
337 ))
338 }
339}
340
#[cfg(test)]
mod tests {

    use alloc::format;
    use alloc::vec::Vec;
    use core::num::NonZero;

    use crypto_bigint::U64;
    use num_prime::nt_funcs::factorize64;
    use rand_chacha::ChaCha8Rng;
    use rand_core::{OsRng, SeedableRng};

    use super::{random_odd_integer, SetBits, SmallPrimesSieve, SmallPrimesSieveFactory};
    use crate::hazmat::precomputed::SMALL_PRIMES;

    /// Shorthand for the non-zero bit lengths the API expects.
    fn nz(bit_length: u32) -> NonZero<u32> {
        NonZero::new(bit_length).unwrap()
    }

    /// Every sieved 32-bit number must have no factor from the small primes table.
    #[test]
    fn random() {
        let largest_small_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1] as u64;

        let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
        let start = random_odd_integer::<U64>(&mut rng, nz(32), SetBits::Msb)
            .unwrap()
            .get();

        for num in SmallPrimesSieve::new(start, nz(32), false).take(100) {
            let num_u64 = u64::from(num);
            assert!(num_u64.leading_zeros() == 32);

            let smallest_factor = factorize64(num_u64).into_keys().next().unwrap();
            assert!(smallest_factor > largest_small_prime);
        }
    }

    /// Same as `random`, but for the heap-allocated integer type.
    #[test]
    fn random_boxed() {
        let largest_small_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1] as u64;

        let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
        let start = random_odd_integer::<crypto_bigint::BoxedUint>(&mut rng, nz(32), SetBits::Msb)
            .unwrap()
            .get();

        for num in SmallPrimesSieve::new(start, nz(32), false).take(100) {
            // NOTE(review): the conversion is presumably only a no-op on targets
            // where the word type is already `u64` - confirm.
            #[allow(clippy::useless_conversion)]
            let num_u64: u64 = num.as_words()[0].into();
            assert!(num_u64.leading_zeros() == 32);

            let smallest_factor = factorize64(num_u64).into_keys().next().unwrap();
            assert!(smallest_factor > largest_small_prime);
        }
    }

    /// Runs a sieve to completion and compares its whole output with `reference`.
    fn check_sieve(start: u32, bit_length: u32, safe_prime: bool, reference: &[u32]) {
        let produced = SmallPrimesSieve::new(U64::from(start), nz(bit_length), safe_prime).collect::<Vec<_>>();
        let expected = reference.iter().map(|x| U64::from(*x)).collect::<Vec<_>>();
        assert_eq!(produced, expected);
    }

    #[test]
    fn empty_sequence() {
        // Fewer than 2 bits.
        check_sieve(1, 1, false, &[]);
        // One bit is taken away for safe primes, leaving fewer than 2.
        check_sieve(1, 2, true, &[]);
        // The halved start (32) needs 6 bits, but only 5 are left.
        check_sieve(64, 6, true, &[]);
    }

    #[test]
    fn small_range() {
        // Composites like 9 and 15 appear when the (adjusted) start is 3,
        // since only small primes below the start are used for sieving.
        let regular: &[(u32, u32, &[u32])] = &[
            (1, 2, &[2, 3]),
            (2, 2, &[2, 3]),
            (3, 2, &[3]),
            (1, 3, &[2, 3, 5, 7]),
            (3, 3, &[3, 5, 7]),
            (5, 3, &[5, 7]),
            (7, 3, &[7]),
            (1, 4, &[2, 3, 5, 7, 9, 11, 13, 15]),
            (3, 4, &[3, 5, 7, 9, 11, 13, 15]),
            (5, 4, &[5, 7, 11, 13]),
            (7, 4, &[7, 11, 13]),
            (9, 4, &[11, 13]),
            (13, 4, &[13]),
            (15, 4, &[]),
        ];
        for &(start, bit_length, reference) in regular {
            check_sieve(start, bit_length, false, reference);
        }

        let safe: &[(u32, u32, &[u32])] = &[
            (1, 3, &[5, 7]),
            (3, 3, &[5, 7]),
            (5, 3, &[5, 7]),
            (7, 3, &[7]),
            (1, 4, &[5, 7, 11, 15]),
            (5, 4, &[5, 7, 11, 15]),
            (7, 4, &[7, 11, 15]),
            (9, 4, &[11]),
            (13, 4, &[]),
        ];
        for &(start, bit_length, reference) in safe {
            check_sieve(start, bit_length, true, reference);
        }
    }

    #[test]
    #[should_panic(expected = "The requested bit length (65) is larger than the precision of `start`")]
    fn sieve_too_many_bits() {
        let _sieve = SmallPrimesSieve::new(U64::ONE, nz(65), false);
    }

    #[test]
    fn random_below_max_length() {
        for _ in 0..10 {
            let value = random_odd_integer::<U64>(&mut OsRng, nz(50), SetBits::Msb)
                .unwrap()
                .get();
            assert_eq!(value.bits(), 50);
        }
    }

    #[test]
    fn random_odd_uint_too_many_bits() {
        let outcome = random_odd_integer::<U64>(&mut OsRng, nz(65), SetBits::Msb);
        assert!(outcome.is_err());
    }

    #[test]
    fn sieve_derived_traits() {
        let sieve = SmallPrimesSieve::new(U64::ONE, nz(10), false);
        // Debug
        assert!(format!("{sieve:?}").starts_with("SmallPrimesSieve"));
        // Clone
        assert_eq!(sieve.clone(), sieve);

        // PartialEq
        let same = SmallPrimesSieve::new(U64::ONE, nz(10), false);
        assert_eq!(sieve, same);
        let different = SmallPrimesSieve::new(U64::ONE, nz(12), false);
        assert_ne!(sieve, different);
    }

    #[test]
    fn sieve_with_max_start() {
        let mut sieve = SmallPrimesSieve::new(U64::MAX, nz(U64::BITS), false);
        assert!(sieve.next().is_none());
    }

    #[test]
    #[should_panic(expected = "`bit_length` must be 2 or greater")]
    fn too_few_bits_regular_primes() {
        let _fac = SmallPrimesSieveFactory::<U64>::new(1, SetBits::Msb);
    }

    #[test]
    #[should_panic(expected = "`bit_length` must be 3 or greater")]
    fn too_few_bits_safe_primes() {
        let _fac = SmallPrimesSieveFactory::<U64>::new_safe_primes(2, SetBits::Msb);
    }

    #[test]
    fn set_bits() {
        for _ in 0..10 {
            let x = random_odd_integer::<U64>(&mut OsRng, nz(64), SetBits::Msb).unwrap();
            assert!(bool::from(x.bit(63)));
        }

        for _ in 0..10 {
            let x = random_odd_integer::<U64>(&mut OsRng, nz(64), SetBits::TwoMsb).unwrap();
            assert!(bool::from(x.bit(63)));
            assert!(bool::from(x.bit(62)));
        }

        // Without forced bits, at least one of 30 samples should have a clear top bit.
        let saw_clear_msb = (0..30)
            .map(|_| random_odd_integer::<U64>(&mut OsRng, nz(64), SetBits::None).unwrap())
            .any(|x| !bool::from(x.bit(63)));
        assert!(saw_clear_msb);
    }

    #[test]
    fn set_two_msb_small_bit_length() {
        // A one-bit number has no second bit to set; the result is just `1`.
        let x = random_odd_integer::<U64>(&mut OsRng, nz(1), SetBits::TwoMsb)
            .unwrap()
            .get();
        assert_eq!(x, U64::ONE);
    }
}