crypto_primes/hazmat/
sieve.rs

1//! An iterator for weeding out multiples of small primes,
2//! before proceeding with slower tests.
3
4use alloc::{vec, vec::Vec};
5use core::marker::PhantomData;
6use core::num::{NonZero, NonZeroU32};
7
8use crypto_bigint::{Integer, Odd, RandomBits, RandomBitsError};
9use rand_core::CryptoRngCore;
10
11use crate::hazmat::precomputed::{SmallPrime, LAST_SMALL_PRIME, RECIPROCALS, SMALL_PRIMES};
12use crate::traits::SieveFactory;
13
/// Decide how prime candidates are manipulated by setting certain bits before primality testing,
/// influencing the range of the prime.
///
/// In the ranges below, `MAX` denotes `2^bit_length - 1`, the largest number of the requested
/// bit length.
#[derive(Debug, Clone, Copy)]
pub enum SetBits {
    /// Set the most significant bit, thus limiting the range to `[MAX/2 + 1, MAX]`.
    ///
    /// In other words, all candidates will have the same bit size.
    Msb,
    /// Set two most significant bits, limiting the range to `[MAX - MAX/4, MAX]`.
    ///
    /// This is useful in the RSA case because a product of two such numbers will have a guaranteed bit size.
    TwoMsb,
    /// No additional bits set; uses the full range `[1, MAX]`.
    None,
}
29
30/// Returns a random odd integer up to the given bit length.
31///
32/// The `set_bits` parameter decides which extra bits are set, which decides the range of the number.
33///
34/// Returns an error variant if `bit_length` is greater than the maximum allowed for `T`
35/// (applies to fixed-length types).
36pub fn random_odd_integer<T: Integer + RandomBits>(
37    rng: &mut (impl CryptoRngCore + ?Sized),
38    bit_length: NonZeroU32,
39    set_bits: SetBits,
40) -> Result<Odd<T>, RandomBitsError> {
41    let bit_length = bit_length.get();
42
43    let mut random = T::try_random_bits(rng, bit_length)?;
44
45    // Make it odd
46    // `bit_length` is non-zero, so the 0-th bit exists.
47    random.set_bit_vartime(0, true);
48
49    // Will not overflow since `bit_length` is ensured to be within the size of the integer
50    // (checked within the `T::try_random_bits()` call).
51    // `bit_length - 1`-th bit exists since `bit_length` is non-zero.
52    match set_bits {
53        SetBits::None => {}
54        SetBits::Msb => random.set_bit_vartime(bit_length - 1, true),
55        SetBits::TwoMsb => {
56            random.set_bit_vartime(bit_length - 1, true);
57            // We could panic here, but since the primary purpose of `TwoMsb` is to ensure the bit length
58            // of the product of two numbers, ignoring this for `bit_length = 1` leads to the desired result.
59            if bit_length > 1 {
60                random.set_bit_vartime(bit_length - 2, true);
61            }
62        }
63    }
64
65    Ok(Odd::new(random).expect("the number is odd by construction"))
66}
67
68// The type we use to calculate incremental residues.
69// Should be >= `SmallPrime` in size.
70type Residue = u32;
71
72// The maximum increment that won't overflow the type we use to calculate residues of increments:
73// we need `(max_prime - 1) + max_incr <= Type::MAX`.
74const INCR_LIMIT: Residue = Residue::MAX - LAST_SMALL_PRIME as Residue + 1;
75
76/// An iterator returning numbers with up to and including given bit length,
77/// starting from a given number, that are not multiples of the first 2048 small primes.
78#[derive(Clone, Debug, PartialEq, Eq)]
79pub struct SmallPrimesSieve<T: Integer> {
80    // Instead of dividing a big integer by small primes every time (which is slow),
81    // we keep a "base" and a small increment separately,
82    // so that we can only calculate the residues of the increment.
83    base: T,
84    incr: Residue,
85    incr_limit: Residue,
86    safe_primes: bool,
87    residues: Vec<SmallPrime>,
88    max_bit_length: u32,
89    produces_nothing: bool,
90    starts_from_exception: bool,
91    last_round: bool,
92}
93
94impl<T: Integer> SmallPrimesSieve<T> {
95    /// Creates a new sieve, iterating from `start` and until the last number with `max_bit_length`
96    /// bits, producing numbers that are not non-trivial multiples of a list of small primes in the
97    /// range `[2, start)` (`safe_primes = false`) or `[2, start/2)` (`safe_primes = true`).
98    ///
99    /// Note that `start` is adjusted to `2`, or the next `1 mod 2` number (`safe_primes = false`);
100    /// and `5`, or `3 mod 4` number (`safe_primes = true`).
101    ///
102    /// Panics if `max_bit_length` greater than the precision of `start`.
103    ///
104    /// If `safe_primes` is `true`, both the returned `n` and `n/2` are sieved.
105    pub fn new(start: T, max_bit_length: NonZeroU32, safe_primes: bool) -> Self {
106        let max_bit_length = max_bit_length.get();
107
108        if max_bit_length > start.bits_precision() {
109            panic!(
110                "The requested bit length ({}) is larger than the precision of `start`",
111                max_bit_length
112            );
113        }
114
115        // If we are targeting safe primes, iterate over the corresponding
116        // possible Germain primes (`n/2`), reducing the task to that with `safe_primes = false`.
117        let (max_bit_length, mut start) = if safe_primes {
118            (max_bit_length - 1, start.wrapping_shr_vartime(1))
119        } else {
120            (max_bit_length, start)
121        };
122
123        // This is easier than making all the methods generic enough to handle these corner cases.
124        let produces_nothing = max_bit_length < start.bits_vartime() || max_bit_length < 2;
125
126        // Add the exception to the produced candidates - the only one that doesn't fit
127        // the general pattern of incrementing the base by 2.
128        let mut starts_from_exception = false;
129        if start <= T::from(2u32) {
130            starts_from_exception = true;
131            start = T::from(3u32);
132        } else {
133            // Adjust the start so that we hit odd numbers when incrementing it by 2.
134            start |= T::one();
135        }
136
137        // Only calculate residues by primes up to and not including `start`, because when we only
138        // have the resiude, we cannot distinguish between a prime itself and a multiple of that
139        // prime.
140        let residues_len = if T::from(LAST_SMALL_PRIME) <= start {
141            SMALL_PRIMES.len()
142        } else {
143            // `start` is smaller than the last prime in the list so casting `start` to a `u16` is
144            // safe. We need to find out how many residues we can use.
145            let start_small = start.as_ref()[0].0 as SmallPrime;
146            SMALL_PRIMES.partition_point(|x| *x < start_small)
147        };
148
149        Self {
150            base: start,
151            incr: 0, // This will ensure that `update_residues()` is called right away.
152            incr_limit: 0,
153            safe_primes,
154            residues: vec![0; residues_len],
155            max_bit_length,
156            produces_nothing,
157            starts_from_exception,
158            last_round: false,
159        }
160    }
161
162    fn update_residues(&mut self) -> bool {
163        if self.incr_limit != 0 && self.incr <= self.incr_limit {
164            return true;
165        }
166
167        if self.last_round {
168            return false;
169        }
170
171        // Set the new base.
172        // Should not overflow since `incr` is never greater than `incr_limit`,
173        // and the latter is chosen such that it doesn't overflow when added to `base`
174        // (see the rest of this method).
175        self.base = self
176            .base
177            .checked_add(&self.incr.into())
178            .expect("Does not overflow by construction");
179
180        self.incr = 0;
181
182        // Re-calculate residues. This is taking up most of the sieving time.
183        for (i, rec) in RECIPROCALS.iter().enumerate().take(self.residues.len()) {
184            let rem = self.base.rem_limb_with_reciprocal(rec);
185            self.residues[i] = rem.0 as SmallPrime;
186        }
187
188        // Find the increment limit.
189        let max_value = match T::one_like(&self.base)
190            .overflowing_shl_vartime(self.max_bit_length)
191            .into()
192        {
193            Some(val) => val,
194            None => T::one_like(&self.base),
195        };
196        let incr_limit = max_value.wrapping_sub(&self.base);
197        self.incr_limit = if incr_limit > T::from(INCR_LIMIT) {
198            INCR_LIMIT
199        } else {
200            // We are close to `2^max_bit_length - 1`.
201            // Mark this round as the last.
202            self.last_round = true;
203            // Can unwrap here since we just checked above that `incr_limit <= INCR_LIMIT`,
204            // and `INCR_LIMIT` fits into `Residue`.
205            let incr_limit_small: Residue = incr_limit.as_ref()[0]
206                .0
207                .try_into()
208                .expect("the increment limit should fit within `Residue`");
209            incr_limit_small
210        };
211
212        true
213    }
214
215    // Returns `true` if the current `base + incr` is divisible by any of the small primes.
216    fn current_is_composite(&self) -> bool {
217        self.residues.iter().enumerate().any(|(i, m)| {
218            let d = SMALL_PRIMES[i] as Residue;
219            let r = (*m as Residue + self.incr) % d;
220
221            // A trick from "Safe Prime Generation with a Combined Sieve" by Michael J. Wiener
222            // (https://eprint.iacr.org/2003/186).
223            // Remember that the check above was for the `(n - 1)/2`;
224            // If `(n - 1)/2 mod d == (d - 1)/2`, it means that `n mod d == 0`.
225            // In other words, we are checking the remainder of `n mod d`
226            // for virtually no additional cost.
227            r == 0 || (self.safe_primes && r == (d - 1) >> 1)
228        })
229    }
230
231    // Returns the restored `base + incr` if it is not composite (wrt the small primes),
232    // and bumps the increment unconditionally.
233    fn maybe_next(&mut self) -> Option<T> {
234        let result = if self.current_is_composite() {
235            None
236        } else {
237            match self.base.checked_add(&self.incr.into()).into_option() {
238                Some(mut num) => {
239                    if self.safe_primes {
240                        // Divide by 2 and ensure it's odd with an OR.
241                        num = num.wrapping_shl_vartime(1) | T::one_like(&self.base);
242                    }
243                    Some(num)
244                }
245                None => None,
246            }
247        };
248
249        self.incr += 2;
250        result
251    }
252
253    fn next(&mut self) -> Option<T> {
254        // Corner cases handled here
255        if self.produces_nothing {
256            return None;
257        }
258
259        if self.starts_from_exception {
260            self.starts_from_exception = false;
261            return Some(T::from(if self.safe_primes { 5u32 } else { 2u32 }));
262        }
263
264        // Main loop
265
266        while self.update_residues() {
267            match self.maybe_next() {
268                Some(x) => return Some(x),
269                None => continue,
270            };
271        }
272        None
273    }
274}
275
276impl<T: Integer> Iterator for SmallPrimesSieve<T> {
277    type Item = T;
278
279    fn next(&mut self) -> Option<Self::Item> {
280        Self::next(self)
281    }
282}
283
284/// A sieve returning numbers that are not multiples of a set of small factors.
285#[derive(Debug, Clone, Copy)]
286pub struct SmallPrimesSieveFactory<T> {
287    max_bit_length: NonZeroU32,
288    safe_primes: bool,
289    set_bits: SetBits,
290    phantom: PhantomData<T>,
291}
292
293impl<T: Integer + RandomBits> SmallPrimesSieveFactory<T> {
294    fn new_impl(max_bit_length: u32, set_bits: SetBits, safe_primes: bool) -> Self {
295        if !safe_primes && max_bit_length < 2 {
296            panic!("`bit_length` must be 2 or greater.");
297        }
298        if safe_primes && max_bit_length < 3 {
299            panic!("`bit_length` must be 3 or greater.");
300        }
301        let max_bit_length = NonZero::new(max_bit_length).expect("`bit_length` should be non-zero");
302        Self {
303            max_bit_length,
304            safe_primes,
305            set_bits,
306            phantom: PhantomData,
307        }
308    }
309
310    /// Creates a factory that produces sieves returning numbers of `max_bit_length` bits (with the top bit set)
311    /// that are not divisible by a number of small factors.
312    pub fn new(max_bit_length: u32, set_bits: SetBits) -> Self {
313        Self::new_impl(max_bit_length, set_bits, false)
314    }
315
316    /// Creates a factory that produces sieves returning numbers `n` of `max_bit_length` bits (with the top bit set)
317    /// such that neither `n` nor `(n - 1) / 2` are divisible by a number of small factors.
318    pub fn new_safe_primes(max_bit_length: u32, set_bits: SetBits) -> Self {
319        Self::new_impl(max_bit_length, set_bits, true)
320    }
321}
322
323impl<T: Integer + RandomBits> SieveFactory for SmallPrimesSieveFactory<T> {
324    type Item = T;
325    type Sieve = SmallPrimesSieve<T>;
326    fn make_sieve(
327        &mut self,
328        rng: &mut (impl CryptoRngCore + ?Sized),
329        _previous_sieve: Option<&Self::Sieve>,
330    ) -> Option<Self::Sieve> {
331        let start =
332            random_odd_integer::<T>(rng, self.max_bit_length, self.set_bits).expect("random_odd_integer() failed");
333        Some(SmallPrimesSieve::new(
334            start.get(),
335            self.max_bit_length,
336            self.safe_primes,
337        ))
338    }
339}
340
#[cfg(test)]
mod tests {

    use alloc::format;
    use alloc::vec::Vec;
    use core::num::NonZero;

    use crypto_bigint::U64;
    use num_prime::nt_funcs::factorize64;
    use rand_chacha::ChaCha8Rng;
    use rand_core::{OsRng, SeedableRng};

    use super::{random_odd_integer, SetBits, SmallPrimesSieve, SmallPrimesSieveFactory};
    use crate::hazmat::precomputed::SMALL_PRIMES;

    /// Asserts that `num` is exactly 32 bits long and has no prime factor from `SMALL_PRIMES`.
    fn assert_sieved_32_bit(num: u64) {
        let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1];
        assert_eq!(num.leading_zeros(), 32);

        // `factorize64()` returns an ordered map, so its first key is the smallest factor.
        let smallest_factor = factorize64(num)
            .into_keys()
            .next()
            .expect("a number greater than 1 has at least one prime factor");
        assert!(smallest_factor > u64::from(max_prime));
    }

    #[test]
    fn random() {
        let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
        let start = random_odd_integer::<U64>(&mut rng, NonZero::new(32).unwrap(), SetBits::Msb)
            .unwrap()
            .get();
        for num in SmallPrimesSieve::new(start, NonZero::new(32).unwrap(), false).take(100) {
            assert_sieved_32_bit(u64::from(num));
        }
    }

    #[test]
    fn random_boxed() {
        let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
        let start = random_odd_integer::<crypto_bigint::BoxedUint>(&mut rng, NonZero::new(32).unwrap(), SetBits::Msb)
            .unwrap()
            .get();

        for num in SmallPrimesSieve::new(start, NonZero::new(32).unwrap(), false).take(100) {
            // For 32-bit targets
            #[allow(clippy::useless_conversion)]
            let num_u64: u64 = num.as_words()[0].into();
            assert_sieved_32_bit(num_u64);
        }
    }

    fn check_sieve(start: u32, bit_length: u32, safe_prime: bool, reference: &[u32]) {
        let test =
            SmallPrimesSieve::new(U64::from(start), NonZero::new(bit_length).unwrap(), safe_prime).collect::<Vec<_>>();
        assert_eq!(test.len(), reference.len());
        for (x, y) in test.iter().zip(reference.iter()) {
            assert_eq!(x, &U64::from(*y));
        }
    }

    #[test]
    fn empty_sequence() {
        check_sieve(1, 1, false, &[]); // no primes of 1 bits
        check_sieve(1, 2, true, &[]); // no safe primes of 2 bits
        check_sieve(64, 6, true, &[]); // 64 is 7 bits long
    }

    #[test]
    fn small_range() {
        check_sieve(1, 2, false, &[2, 3]);
        check_sieve(2, 2, false, &[2, 3]);
        check_sieve(3, 2, false, &[3]);

        check_sieve(1, 3, false, &[2, 3, 5, 7]);
        check_sieve(3, 3, false, &[3, 5, 7]);
        check_sieve(5, 3, false, &[5, 7]);
        check_sieve(7, 3, false, &[7]);

        check_sieve(1, 4, false, &[2, 3, 5, 7, 9, 11, 13, 15]);
        check_sieve(3, 4, false, &[3, 5, 7, 9, 11, 13, 15]);
        check_sieve(5, 4, false, &[5, 7, 11, 13]);
        check_sieve(7, 4, false, &[7, 11, 13]);
        check_sieve(9, 4, false, &[11, 13]);
        check_sieve(13, 4, false, &[13]);
        check_sieve(15, 4, false, &[]);

        check_sieve(1, 3, true, &[5, 7]);
        check_sieve(3, 3, true, &[5, 7]);
        check_sieve(5, 3, true, &[5, 7]);
        check_sieve(7, 3, true, &[7]);

        // In the following three cases, the "half-start" would be set to 3,
        // and since every small divisor equal or greater than the start is not tested
        // (because we can't distinguish between the remainder being 0
        // and the number being actually equal to the divisor),
        // no divisors will actually be tested at all, so 15 (a composite)
        // is included in the output.
        check_sieve(1, 4, true, &[5, 7, 11, 15]);
        check_sieve(5, 4, true, &[5, 7, 11, 15]);
        check_sieve(7, 4, true, &[7, 11, 15]);

        check_sieve(9, 4, true, &[11]);
        check_sieve(13, 4, true, &[]);
    }

    #[test]
    #[should_panic(expected = "The requested bit length (65) is larger than the precision of `start`")]
    fn sieve_too_many_bits() {
        let _sieve = SmallPrimesSieve::new(U64::ONE, NonZero::new(65).unwrap(), false);
    }

    #[test]
    fn random_below_max_length() {
        for _ in 0..10 {
            let r = random_odd_integer::<U64>(&mut OsRng, NonZero::new(50).unwrap(), SetBits::Msb)
                .unwrap()
                .get();
            assert_eq!(r.bits(), 50);
        }
    }

    #[test]
    fn random_odd_uint_too_many_bits() {
        assert!(random_odd_integer::<U64>(&mut OsRng, NonZero::new(65).unwrap(), SetBits::Msb).is_err());
    }

    #[test]
    fn sieve_derived_traits() {
        let s = SmallPrimesSieve::new(U64::ONE, NonZero::new(10).unwrap(), false);
        // Debug
        assert!(format!("{s:?}").starts_with("SmallPrimesSieve"));
        // Clone
        assert_eq!(s.clone(), s);

        // PartialEq
        let s2 = SmallPrimesSieve::new(U64::ONE, NonZero::new(10).unwrap(), false);
        assert_eq!(s, s2);
        let s3 = SmallPrimesSieve::new(U64::ONE, NonZero::new(12).unwrap(), false);
        assert_ne!(s, s3);
    }

    #[test]
    fn sieve_with_max_start() {
        let start = U64::MAX;
        let mut sieve = SmallPrimesSieve::new(start, NonZero::new(U64::BITS).unwrap(), false);
        assert!(sieve.next().is_none());
    }

    #[test]
    #[should_panic(expected = "`bit_length` must be 2 or greater")]
    fn too_few_bits_regular_primes() {
        let _fac = SmallPrimesSieveFactory::<U64>::new(1, SetBits::Msb);
    }

    #[test]
    #[should_panic(expected = "`bit_length` must be 3 or greater")]
    fn too_few_bits_safe_primes() {
        let _fac = SmallPrimesSieveFactory::<U64>::new_safe_primes(2, SetBits::Msb);
    }

    #[test]
    fn set_bits() {
        for _ in 0..10 {
            let x = random_odd_integer::<U64>(&mut OsRng, NonZero::new(64).unwrap(), SetBits::Msb).unwrap();
            assert!(bool::from(x.bit(63)));
        }

        for _ in 0..10 {
            let x = random_odd_integer::<U64>(&mut OsRng, NonZero::new(64).unwrap(), SetBits::TwoMsb).unwrap();
            assert!(bool::from(x.bit(63)));
            assert!(bool::from(x.bit(62)));
        }

        // 1 in 2^30 chance of spurious failure... good enough?
        assert!((0..30)
            .map(|_| { random_odd_integer::<U64>(&mut OsRng, NonZero::new(64).unwrap(), SetBits::None).unwrap() })
            .any(|x| !bool::from(x.bit(63))));
    }

    #[test]
    fn set_two_msb_small_bit_length() {
        // Check that when technically there isn't a second most significant bit,
        // `random_odd_integer()` still returns a number.
        let x = random_odd_integer::<U64>(&mut OsRng, NonZero::new(1).unwrap(), SetBits::TwoMsb)
            .unwrap()
            .get();
        assert_eq!(x, U64::ONE);

        // With two bits, both top bits are set, so the only possible value is `0b11`.
        let x = random_odd_integer::<U64>(&mut OsRng, NonZero::new(2).unwrap(), SetBits::TwoMsb)
            .unwrap()
            .get();
        assert_eq!(x, U64::from(3u32));
    }
}