Skip to main content

lib_modulo/
prime.rs

1use crate::Modulus64;
2
3/// Performs deterministic Miller-Rabin primality test.
4///
5/// # Time complexity
6///
7/// *O*(log *x*)
8pub const fn primality_test(x: u64) -> bool {
9    if x < 64 {
10        return (super::PRIME_LT_64 >> x) & 1 == 1;
11    } else if (super::COPRIME_2_3_5 >> (x % 30)) & 1 == 0 || x % 7 == 0 {
12        return false;
13    }
14
15    let (s, d) = {
16        let x = x - 1;
17        let s = x.trailing_zeros();
18        (s - 1, x >> s)
19    };
20
21    let modulus = Modulus64::new(x);
22    let one = modulus.residue(1).x;
23    // (a - a) r = 0 (mod x), r % x != 0
24    let neg_one = x - one;
25
26    // from <https://miller-rabin.appspot.com/>
27    let witness = if x < 350_269_456_337 {
28        static SET3: [u64; 3] = [0x3AB4F88FF0CC7C80, 0xCBEE4CDF120C10AA, 0xE6F1343B0EDCA8E7];
29        SET3.as_slice()
30    } else if x < 7_999_252_175_582_851 {
31        static SET5: [u64; 5] = [
32            2,
33            0x3C1C7396F6D,
34            0x2142E2E3F22DE5C,
35            0x297105B6B7B29DD,
36            0x370EB221A5F176DD,
37        ];
38        SET5.as_slice()
39    } else {
40        static SET7: [u64; 7] = [2, 325, 9375, 28178, 450775, 9780504, 1795265022];
41        SET7.as_slice()
42    };
43
44    let mut i = 0;
45    'test: while i < witness.len() {
46        let mut mint = modulus.residue(witness[i]);
47        i += 1;
48
49        if mint.is_zero() {
50            continue;
51        }
52
53        mint = mint.pow(d);
54        if mint.x == one || mint.x == neg_one {
55            continue;
56        }
57
58        let mut s = s;
59        while s > 0 {
60            s -= 1;
61
62            mint.x = mint.modulus.mul(mint.x, mint.x);
63            if mint.x == neg_one {
64                continue 'test;
65            }
66        }
67
68        return false;
69    }
70
71    true
72}
73
74// #[test]
75// fn f() {
76//     use std::io::Write;
77
78//     let mut f = std::fs::File::create("./src/small_prime_context_u16_raw.rs").unwrap();
79
80//     let _ = f.write(b"[\n");
81//     for n in (3..1 << 16).step_by(2) {
82//         if primality_test(n) {
83//             let modulus = Context64::new(n);
84
85//             let _ = f.write(format!("({}, {}, {}),\n", modulus.n, modulus.inv_n, modulus.r2_mod_n).as_bytes());
86//         }
87//     }
88//     let _ = f.write(b"]");
89// }
90
#[cfg(test)]
mod tests {
    use rand::{rng, Rng};

    use super::*;

    /// Reference primality check by trial division over every `d` in
    /// `2..=isqrt(x)`.
    ///
    /// This needs about `sqrt(x)` divisions for a prime input, so it is only
    /// practical up to roughly 2^40.
    const fn primality_test_naive(x: u64) -> bool {
        if x < 2 {
            return false;
        }

        // The bound is loop-invariant. Computing it once avoids re-running
        // `isqrt` on every iteration, which dominates in unoptimised test builds.
        let limit = x.isqrt();
        let mut d = 2;
        while d <= limit {
            if x % d == 0 {
                return false;
            }
            d += 1;
        }

        true
    }

    #[test]
    fn small() {
        for x in 0..500_000 {
            assert_eq!(primality_test(x), primality_test_naive(x), "{x}")
        }
    }

    #[test]
    fn intermediate() {
        let mut rng = rng();

        for x in std::iter::repeat_with(|| rng.random_range(1 << 30..1 << 40)).take(100) {
            assert_eq!(primality_test(x), primality_test_naive(x), "{x}")
        }
    }

    /// Compares exhaustively with trial division around the point where
    /// `primality_test` switches from the 3-base to the 5-base witness set.
    #[test]
    fn witness_set_boundary() {
        const BOUNDARY: u64 = 350_269_456_337;

        for x in (BOUNDARY - 200)..(BOUNDARY + 200) {
            assert_eq!(primality_test(x), primality_test_naive(x), "{x}")
        }
    }

    /// Inputs far beyond trial-division range, which exercise the 7-base
    /// witness set used for the largest `u64` values.
    #[test]
    fn large() {
        // Well-known primes: 2^61 - 1 (a Mersenne prime), 2^64 - 59 (the
        // largest 64-bit prime) and 2^32 - 5 (the largest 32-bit prime).
        let primes: [u64; 3] = [(1 << 61) - 1, u64::MAX - 58, 4_294_967_291];
        for p in primes {
            assert!(primality_test(p), "{p} should be prime");
        }

        // Composites built as explicit products, so they are composite by
        // construction. The last one is a classic strong pseudoprime to the
        // bases 2, 3, 5, ..., 23.
        let composites: [u64; 4] = [
            4_294_967_291 * 4_294_967_279,
            4_294_967_291 * 4_294_967_291,
            ((1 << 31) - 1) * ((1 << 31) - 1),
            149_491 * 747_451 * 34_233_211,
        ];
        for c in composites {
            assert!(!primality_test(c), "{c} should be composite");
        }
    }
}