Skip to main content

primefactor/
candidates.rs

//! Many functions can produce prime-number candidates, but only a few are
//! guaranteed to produce every prime.
//!
//! Implementations of prime wheels for factorization by trial division:
//! <https://en.wikipedia.org/wiki/Wheel_factorization>
//!
//! The wheel iterators omit overflow checks because their callers stop
//! consuming them long before values approach the `u128` limit. In
//! `factorize`, the iterator is only consumed up to √n, which for
//! `u128::MAX` is approximately 1.84e19 (2^64).
//!
/// Wheel factorization algorithm with base {2, 3, 5} (30 spokes).
///
/// Yields 2, 3, 5, 7 and then every integer coprime to 30, in increasing
/// order. This is an infinite iterator; callers must provide a termination
/// condition. It is designed for use in trial division up to √n.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PrimeWheel30 {
    // Most recently yielded candidate (0 before the first call to `next`).
    base: u128,
    // Position in `GAPS` of the step that produces the next candidate.
    index: usize,
}

impl PrimeWheel30 {
    /// Distances between consecutive candidates. Entries `0..4` walk the
    /// initial phase; entries `4..12` form a cycle that sums to 30 and is
    /// repeated forever.
    const GAPS: [u128; 12] = [
        2, 1, 2, 2, // initial phase: 2, 3, 5, 7 (indices 0-3)
        4, 2, 4, 2, 4, 6, 2, 6, // 11, 13, 17, 19, 23, 29, 31, 37 + n * 30 (indices 4-11)
    ];

    /// Index of the first gap belonging to the repeating cycle.
    const CYCLE_START: usize = 4;

    /// Create a wheel whose first candidate is 2.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Iterator for PrimeWheel30 {
    type Item = u128;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let &step = Self::GAPS.get(self.index)?;
        self.base += step;
        // After the last gap, jump back to the start of the cycle.
        let following = self.index + 1;
        self.index = if following == Self::GAPS.len() {
            Self::CYCLE_START
        } else {
            following
        };
        Some(self.base)
    }
}
56
/// Wheel factorization algorithm with base {2, 3, 5, 7} (210 spokes).
///
/// Yields 2, 3, 5, 7, 11 and then every integer coprime to 210, in increasing
/// order (48 candidates per 210 consecutive integers). This is an infinite
/// iterator; callers must provide a termination condition. It is designed for
/// use in trial division up to √n.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PrimeWheel210 {
    // Cursor value: the candidate just behind the cursor (the most recently
    // yielded one when advancing with `next`), or 0 at the very start.
    base: u128,
    // Position in `GAPS` of the step that produces the next candidate.
    index: usize,
}

impl PrimeWheel210 {
    /// Distances between consecutive candidates. Entries `0..5` walk the
    /// initial phase 0 -> 2 -> 3 -> 5 -> 7 -> 11; entries `5..53` form a
    /// 48-spoke cycle that sums to 210 and is repeated forever.
    const GAPS: [u128; 53] = [
        2, 1, 2, 2, 4, // initial phase: 2, 3, 5, 7, 11 (index 0-4)
        2, 4, 2, 4, 6, 2, 6, 4, 2, 4, 6, 6, 2, 6, 4, // 13..71 (index 5, start of cycle)
        2, 6, 4, 6, 8, 4, 2, 4, 2, 4, 8, 6, 4, 6, 2, 4, // 73..143
        6, 2, 6, 6, 4, 2, 4, 6, 2, 6, 4, 2, 4, 2, 10, 2, // 149..211
        10, // 221 + n * 210 (index 52, end of cycle, wraps to index 5)
    ];

    /// Index of the first gap belonging to the repeating cycle.
    const CYCLE_START: usize = 5;

    /// Create a wheel whose first candidate is 2.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a wheel that will yield candidates >= `start`.
    /// The first call to `next()` returns the first candidate at or above `start`.
    pub fn from(start: u128) -> Self {
        match start {
            0..=2 => Self::default(),
            // Small starts: walk the initial phase from zero.
            3..=11 => Self::seek(0, 0, Self::CYCLE_START, start)
                .unwrap_or(Self { base: 7, index: 4 }),
            _ => {
                // Lap k begins at cursor 11 + 210k and yields 13 + 210k ..= 221 + 210k.
                let lap = start.saturating_sub(13) / 210;
                let origin = 11 + lap * 210;
                Self::seek(origin, Self::CYCLE_START, Self::GAPS.len(), start)
                    // Every candidate of this lap is below `start`: begin the next lap.
                    .unwrap_or_else(|| Self {
                        base: origin + 210,
                        index: Self::CYCLE_START,
                    })
            }
        }
    }

    /// Scan `GAPS[first..end]` starting from cursor value `origin` and return
    /// the state whose next candidate is the first one >= `start`, or `None`
    /// if every candidate in that range is below `start`.
    fn seek(origin: u128, first: usize, end: usize, start: u128) -> Option<Self> {
        let mut base = origin;
        for (index, &gap) in Self::GAPS.iter().enumerate().take(end).skip(first) {
            if base + gap >= start {
                return Some(Self { base, index });
            }
            base += gap;
        }
        None
    }

    /// Step the wheel's cursor backward by one candidate.
    ///
    /// Returns the candidate immediately behind the cursor and moves the
    /// cursor before it. For a wheel advanced with `next()`, that is the value
    /// `next()` returned most recently, so a subsequent `next()` yields the
    /// same value again. This mutates the same state the `Iterator` uses.
    /// Returns `None` once the cursor is before 2.
    pub fn prev(&mut self) -> Option<u128> {
        let current = self.base;
        if current < 2 {
            return None;
        }
        // At the start of any lap after the first, step back across the seam
        // onto the final gap of the previous lap.
        let cursor = if self.index == Self::CYCLE_START && current > 11 {
            Self::GAPS.len()
        } else {
            self.index
        };
        match cursor.checked_sub(1) {
            Some(previous) => {
                self.index = previous;
                self.base = current.saturating_sub(Self::GAPS[previous]);
            }
            None => self.base = 0,
        }
        Some(current)
    }
}

impl Iterator for PrimeWheel210 {
    type Item = u128;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let &step = Self::GAPS.get(self.index)?;
        self.base += step;
        // After the last gap, jump back to the start of the cycle.
        self.index = match self.index + 1 {
            wrapped if wrapped == Self::GAPS.len() => Self::CYCLE_START,
            advanced => advanced,
        };
        Some(self.base)
    }
}
150
/// Fast prime-candidate filter based on the 210-spoke wheel.
///
/// Returns `true` for 2, 3, 5 and 7, and for every `n >= 11` that is coprime
/// to 210 = 2·3·5·7; returns `false` otherwise (including 0, 1 and 9).
/// Values >= 11 are classified with a single `% 210` and a bit test, which
/// rejects 162 of every 210 residues (~77% of all integers).
#[inline(always)]
pub(crate) fn is_prime_candidate(n: u128) -> bool {
    // Bit `r % 32` of word `r / 32` is set iff residue `r` is coprime to 210.
    const BITMAP: [u32; 7] = [
        0xa08a_2802, 0x2820_8a20, 0x0208_8288, 0x8202_28a2,
        0x20a0_8a08, 0x8828_2288, 0x0002_00a2,
    ];
    match n {
        0..=10 => matches!(n, 2 | 3 | 5 | 7),
        _ => {
            let residue = (n % 210) as usize;
            (BITMAP[residue >> 5] >> (residue & 0x1F)) & 1 == 1
        }
    }
}
166
/// Modular exponentiation: (base^exp) mod modulus.
///
/// Square-and-multiply over the bits of `exp`. Returns 0 when `modulus == 1`,
/// and 1 when `exp == 0` for any other modulus.
///
/// # Panics
/// Panics if `modulus == 0`.
#[inline]
fn mod_pow(mut base: u128, mut exp: u128, modulus: u128) -> u128 {
    if modulus == 1 {
        return 0;
    }
    let mut result: u128 = 1;
    base %= modulus;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mod_mul(result, base, modulus);
        }
        exp >>= 1;
        // Skip the final squaring, whose result would never be used.
        if exp > 0 {
            base = mod_mul(base, base, modulus);
        }
    }
    result
}

/// Modular addition: (a + b) mod m, without overflow.
/// Requires a < m and b < m.
#[inline]
fn add_mod(a: u128, b: u128, m: u128) -> u128 {
    debug_assert!(a < m);
    debug_assert!(b < m);
    // a + b >= m  <=>  a >= m - b; `m - b` cannot underflow because b < m.
    if a >= m - b {
        a - (m - b)
    } else {
        a + b
    }
}

/// Modular multiplication: (a * b) mod m, without overflow.
///
/// The operands are reduced modulo `m` first. Direct multiplication is used
/// when their product fits in a `u128`; otherwise it falls back to
/// double-and-add ("Russian peasant") multiplication built on [`add_mod`].
#[inline]
fn mod_mul(a: u128, b: u128, m: u128) -> u128 {
    debug_assert!(m > 0);
    // Reduce first so that unreduced inputs with a small modulus also take
    // the fast path below.
    let (a, b) = (a % m, b % m);
    // The product has at most bits(a) + bits(b) bits, so it fits in a u128
    // exactly when the operands' leading zeros sum to at least 128.
    if a.leading_zeros() + b.leading_zeros() >= 128 {
        return a * b % m;
    }
    // Full-range fallback: walk the bits of the smaller operand, doubling the
    // larger one modulo m at each step.
    let (mut addend, mut bits) = if a < b { (b, a) } else { (a, b) };
    let mut result: u128 = 0;
    while bits > 0 {
        if bits & 1 == 1 {
            result = add_mod(result, addend, m);
        }
        bits >>= 1;
        if bits > 0 {
            addend = add_mod(addend, addend, m);
        }
    }
    result
}

/// Test a single Miller-Rabin witness `a` against odd `n`, where
/// `n - 1 = d * 2^r` with `d` odd.
/// Returns true if n passes the test for this witness (strong probable prime).
fn miller_rabin_witness(n: u128, a: u128, d: u128, r: u32) -> bool {
    debug_assert!(n >= 2);
    let n_minus_1 = n - 1;
    let mut x = mod_pow(a, d, n);
    if x == 1 || x == n_minus_1 {
        return true;
    }
    for _ in 1..r {
        x = mod_mul(x, x, n);
        if x == n_minus_1 {
            return true;
        }
        // 1 is a fixed point of squaring, so n - 1 can no longer appear: the
        // previous x was a non-trivial square root of 1, proving n composite.
        if x == 1 {
            return false;
        }
    }
    false
}

/// Miller-Rabin primality test with the first 13 primes as witnesses.
///
/// The witness set {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41} is proven
/// sufficient for every n < 3,317,044,064,679,887,385,961,981 (≈ 3.3e24), so
/// the result is exact in that range. For larger `n` (a `u128` reaches
/// ≈ 3.4e38) this is a strong probable-prime test: `false` always means
/// composite, but `true` is not a proof of primality.
///
/// The twelve witnesses {2, ..., 37} on their own are only sufficient below
/// 318,665,857,834,031,151,167,461, which is itself a composite
/// (399165290221 · 798330580441) that passes all twelve; witness 41 is what
/// extends the deterministic range to the bound above.
///
/// Returns `false` for 0 and 1.
///
/// References: <https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test>,
/// OEIS A014233.
pub(crate) fn miller_rabin(n: u128) -> bool {
    const WITNESSES: [u128; 13] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41];
    if n < 2 {
        return false;
    }
    let n_minus_1 = n - 1;
    let r = n_minus_1.trailing_zeros();
    let d = n_minus_1 >> r;
    // Witnesses >= n are skipped; every composite n has a smaller witness.
    WITNESSES.iter().all(|&a| a >= n || miller_rabin_witness(n, a, d, r))
}
255
#[cfg(test)]
mod tests {
    use reikna::prime::{is_prime, next_prime};
    use super::{
        add_mod, is_prime_candidate, miller_rabin, mod_mul, PrimeWheel210, PrimeWheel30,
    };

    /// Reference double-and-add modular multiplication, used to cross-check
    /// `mod_mul`.
    fn mod_mul_reference(a: u128, b: u128, m: u128) -> u128 {
        let mut result = 0;
        let mut a = a % m;
        let mut b = b % m;
        while b > 0 {
            if b & 1 == 1 {
                result = add_mod(result, a, m);
            }
            b >>= 1;
            if b > 0 {
                a = add_mod(a, a, m);
            }
        }

        result
    }

    #[test]
    fn test_prime_wheel_30_first_1000() {
        let mut wheel = PrimeWheel30::new();
        let mut misses = 0;
        let mut p = 0;
        for _ in 0..1000 {
            p = next_prime(p);
            for n in wheel.by_ref() {
                if n == p as u128 {
                    break;
                }
                assert!(!is_prime(n as u64));
                misses += 1;
            }
        }
        // Assert the exact number of expected misses for the first 1000 primes
        assert_eq!(misses, 1114);
    }

    #[test]
    fn test_prime_wheel_210_first_1000() {
        let mut wheel = PrimeWheel210::new();
        let mut misses = 0;
        let mut p = 0;
        for _ in 0..1000 {
            p = next_prime(p);
            for n in wheel.by_ref() {
                if n == p as u128 {
                    break;
                }
                assert!(!is_prime(n as u64));
                misses += 1;
            }
        }
        // Assert the exact number of expected misses for the first 1000 primes
        assert_eq!(misses, 813);
    }

    #[test]
    fn test_prime_wheel_30_quality() {
        const TOTAL: u128 = 1000000;
        let mut primes: u128 = 0;
        let pw_iter = PrimeWheel30::new();
        for p in pw_iter.take(TOTAL as usize) {
            primes += is_prime(p as u64) as u128;
        }
        let percent = primes as f64 / TOTAL as f64 * 100.0;
        println!("Prime wheel generated {}/{} ({:.3}%) primes",
                primes, TOTAL, percent);
        assert!(percent > 25.0);
    }

    #[test]
    fn test_prime_wheel_210_quality() {
        const TOTAL: u128 = 1000000;
        let mut primes: u128 = 0;
        let pw_iter = PrimeWheel210::new();
        for p in pw_iter.take(TOTAL as usize) {
            primes += is_prime(p as u64) as u128;
        }
        let percent = primes as f64 / TOTAL as f64 * 100.0;
        println!("Prime wheel generated {}/{} ({:.3}%) primes",
                primes, TOTAL, percent);
        assert!(percent > 30.0);
    }

    #[test]
    fn test_add_mod_large_values() {
        let m = u128::MAX - 158;
        let a = m - 1;
        let b = m - 1;
        assert_eq!(add_mod(a, b, m), m - 2);
    }

    #[test]
    fn test_mod_mul_matches_direct_below_2pow127_boundary() {
        let m = (1_u128 << 127) - 1;
        let values = [
            0,
            1,
            2,
            3,
            17,
            (1_u128 << 64) + 13,
            m - 2,
            m - 1,
        ];

        for &a in &values {
            for &b in &values {
                assert_eq!(mod_mul(a, b, m), mod_mul_reference(a, b, m));
            }
        }
    }

    #[test]
    fn test_mod_mul_large_modulus_regression_cases() {
        let m = u128::MAX - 158;
        let cases = [
            (m - 1, m - 1, 1),
            (m - 1, 2, m - 2),
            (m - 2, 2, m - 4),
            (m - 1, m - 2, 2),
        ];

        for (a, b, expected) in cases {
            assert_eq!(mod_mul(a, b, m), expected);
            assert_eq!(mod_mul(b, a, m), expected);
        }
    }

    #[test]
    fn test_prime_wheel_210_from_matches_new() {
        // from(0), from(1), from(2) should all behave like new()
        for start in [0, 1, 2] {
            let from_iter: Vec<u128> = PrimeWheel210::from(start).take(200).collect();
            let new_iter: Vec<u128> = PrimeWheel210::new().take(200).collect();
            assert_eq!(from_iter, new_iter, "from({start}) differs from new()");
        }
    }

    #[test]
    fn test_prime_wheel_210_from_initial_primes() {
        // Starting at each small prime should yield that prime first
        for &p in &[2u128, 3, 5, 7, 11] {
            let first = PrimeWheel210::from(p).next().unwrap();
            assert_eq!(first, p, "from({p}) should yield {p} first");
        }
    }

    #[test]
    fn test_prime_wheel_210_from_never_skips_candidates() {
        // For every starting point 0..=500, verify that from(start)
        // yields a subset of new()'s output, starting at the right place
        let all: Vec<u128> = PrimeWheel210::new().take(2000).collect();
        for start in 0..=500 {
            let first = match PrimeWheel210::from(start).next() {
                Some(v) => v,
                None => continue,
            };
            assert!(first >= start,
                "from({start}) yielded {first} which is below start");
            // The first value must appear in the full sequence
            let pos = all.iter().position(|&v| v == first)
                .unwrap_or_else(|| panic!(
                    "from({start}) yielded {first} not in wheel sequence"));
            // All subsequent values must match the full sequence
            let from_vals: Vec<u128> = PrimeWheel210::from(start).take(50).collect();
            assert_eq!(from_vals, all[pos..pos + 50],
                "from({start}) sequence diverges from new() at offset {pos}");
        }
    }

    #[test]
    fn test_prime_wheel_210_from_at_cycle_boundaries() {
        // Test at exact 210-block boundaries
        let all: Vec<u128> = PrimeWheel210::new().take(5000).collect();
        for block in [1u128, 2, 5, 10, 100, 1000] {
            let boundary = block * 210;
            for offset in [0, 1, 2, 209, 210, 211] {
                let start = boundary + offset;
                let first = PrimeWheel210::from(start).next().unwrap();
                assert!(first >= start,
                    "from({start}) yielded {first} below start");
                if first < *all.last().unwrap() {
                    assert!(all.contains(&first),
                        "from({start}): {first} not a wheel candidate");
                }
            }
        }
    }

    #[test]
    fn test_prime_wheel_210_from_finds_all_primes() {
        // Verify that starting from various points, we don't miss any primes
        for start in [0u128, 1, 13, 100, 210, 211, 420, 1000, 10000] {
            let mut wheel = PrimeWheel210::from(start);
            let mut p = if start <= 2 { 0 } else { (start - 1) as u64 };
            // Check the next 100 primes from this starting point
            for _ in 0..100 {
                p = next_prime(p);
                if (p as u128) < start { continue; }
                for n in wheel.by_ref() {
                    if n == p as u128 {
                        break;
                    }
                    assert!(!is_prime(n as u64),
                        "from({start}): wheel candidate {n} is prime but was skipped");
                }
            }
        }
    }

    #[test]
    fn test_is_prime_candidate_matches_divisibility() {
        for n in 0u128..5000 {
            let expected = if n < 11 {
                matches!(n, 2 | 3 | 5 | 7)
            } else {
                [2u128, 3, 5, 7].iter().all(|&p| n % p != 0)
            };
            assert_eq!(is_prime_candidate(n), expected, "n = {n}");
        }
    }

    #[test]
    fn test_prime_wheel_210_yields_only_candidates() {
        for n in PrimeWheel210::new().take(10_000) {
            assert!(is_prime_candidate(n), "wheel yielded non-candidate {n}");
        }
    }

    #[test]
    fn test_prime_wheel_210_prev_reverses_next() {
        // Walking back with prev() must replay next()'s output in reverse,
        // across many 210-cycle seams, and stop after 2.
        let mut wheel = PrimeWheel210::new();
        let mut forward: Vec<u128> = wheel.by_ref().take(1000).collect();
        let backward: Vec<u128> = std::iter::from_fn(|| wheel.prev()).collect();
        forward.reverse();
        assert_eq!(backward, forward);
        assert_eq!(wheel.prev(), None);
        assert_eq!(wheel.next(), Some(2));
    }

    #[test]
    fn test_prime_wheel_210_prev_then_next_repeats() {
        // prev() returns the value behind the cursor; next() then re-yields it.
        for start in [13u128, 222, 433, 10_000] {
            let mut wheel = PrimeWheel210::from(start);
            let behind = wheel.prev().unwrap();
            assert!(behind < start, "from({start}).prev() = {behind}");
            assert_eq!(wheel.next(), Some(behind));
            assert!(wheel.next().unwrap() >= start);
        }
    }

    #[test]
    fn test_miller_rabin_matches_reference() {
        for n in 2u128..20_000 {
            assert_eq!(miller_rabin(n), is_prime(n as u64), "n = {n}");
        }
    }

    #[test]
    fn test_miller_rabin_rejects_strong_pseudoprimes() {
        // Smallest strong pseudoprimes to bases {2, 3, 5, 7} and {2, ..., 31}
        // respectively (OEIS A014233); both must be exposed by later witnesses.
        for n in [3_215_031_751u128, 3_825_123_056_546_413_051] {
            assert!(!miller_rabin(n), "{n} is composite");
        }
        assert!(miller_rabin((1u128 << 127) - 1));
    }
}