Skip to main content

lib_modulo/
factorize.rs

1use std::num::NonZero;
2
3use crate::{prime::primality_test, Context64};
4
5// 12 bytes * 6541 ~ 75 KiB
6static SMALL_ODD_PRIME_CONTEXT_16: [(u16, u64, u16); 6541] =
7    include!("./small_prime_context_u16_raw.rs");
8
9/// Factorize integer and writes prime factors to `factor` in any order.
10///
11/// This function is probabilistic and may fail.
12///
13/// # Time complexity
14///
15/// O(`x`^0.25) expected
16///
17/// # Example
18///
19/// ```
20/// use lib_modulo::factorize::*;
21///
22/// let mut factor = Vec::new();
23/// // panics if factorization fails
24/// assert!(factorize(998_244_353 * 1_000_000_007, &mut factor).is_ok());
25///
26/// factor.sort_unstable();
27/// assert_eq!(factor, vec![998_244_353, 1_000_000_007])
28/// ```
29pub fn factorize(mut x: u64, factor: &mut Vec<u64>) -> Result<(), ()> {
30    if x < 2 {
31        return Ok(());
32    }
33    factor.reserve(64);
34
35    // trial division by small primes less than 2^16
36    {
37        factor.extend(std::iter::repeat_n(2, x.trailing_zeros() as usize));
38        x >>= x.trailing_zeros();
39    }
40    for &(n, inv_n, r2_mod_n) in SMALL_ODD_PRIME_CONTEXT_16.iter() {
41        let ctx = Context64 {
42            n: n as u64,
43            inv_n,
44            r2_mod_n: r2_mod_n as u64,
45        };
46
47        while ctx.can_divide(x) {
48            x /= ctx.n;
49            factor.push(ctx.n);
50        }
51
52        if x == 1 {
53            return Ok(());
54        }
55    }
56
57    // find large prime factors (up to 3) by Pollard's rho
58    while x > 1 {
59        if primality_test(x) {
60            factor.push(x);
61            return Ok(());
62        }
63
64        if let Some(d) = pollard_rho(x) {
65            let d = d.get();
66            while x % d == 0 {
67                x /= d;
68                factor.push(d);
69            }
70        } else {
71            return Err(());
72        }
73    }
74
75    Ok(())
76}
77
78/// Find prime factor of `x`.
79///
80/// This function is probabilistic and may fail.
81///
82/// # Time complexity
83///
84/// *O*(p^0.25) where p is a prime factor of `x`
85fn pollard_rho(x: u64) -> Option<NonZero<u64>> {
86    let ctx = Context64::new(x);
87    let one = ctx.modulo(1);
88
89    for c in 1..100 {
90        // a = b (mod x) => f(a) = f(b) (mod x)
91        let f = |x: u64| ctx.mul_add(x, x, c);
92
93        let mut y0 = ctx.modulo(1);
94        let mut y1 = y0;
95
96        let mut prod = one;
97        let mut step = 0;
98        let mut memo = [[0, 0, one.value]; 1 << 5];
99
100        'cycle_detection: while !prod.is_zero() {
101            y0.value = f(y0.value);
102            y1.value = f(f(y1.value));
103            prod *= y1 - y0;
104            step += 1;
105
106            if step % (1 << 5) == 0 {
107                memo[(step >> 5) % (1 << 5)] = [y0.value, y1.value, prod.value];
108            }
109            if step % (1 << 10) == 0 || prod.is_zero() {
110                let g = binary_gcd(prod.value, x);
111
112                if g == 1 {
113                    continue 'cycle_detection;
114                } else if primality_test(g) {
115                    return NonZero::new(g);
116                }
117
118                for i in 0..memo.len() {
119                    let g = binary_gcd(memo[i][2], x);
120
121                    if g != 1 {
122                        if primality_test(g) {
123                            return NonZero::new(g);
124                        }
125
126                        y0.value = memo[i][0];
127                        y1.value = memo[i][1];
128                        for _ in 0..1 << 5 {
129                            let g = binary_gcd((y0 - y1).value, x);
130
131                            if g != 1 {
132                                if primality_test(g) {
133                                    return NonZero::new(g);
134                                } else if g != x {
135                                    // FIXME: `x` is composed of at most 3 primes, so return `x/g`
136                                    return pollard_rho(g);
137                                } else {
138                                    break 'cycle_detection;
139                                }
140                            }
141
142                            y0.value = f(y0.value);
143                            y1.value = f(f(y1.value));
144                        }
145                    }
146                }
147            }
148        }
149    }
150
151    None
152}
153
/// Greatest common divisor of `a` and `b` using Stein's binary algorithm.
///
/// Only shifts and subtractions are used; no division. `binary_gcd(a, 0) == a`,
/// `binary_gcd(0, b) == b`, and `binary_gcd(0, 0) == 0`.
#[inline(always)]
fn binary_gcd(a: u64, b: u64) -> u64 {
    if b == 0 {
        return a;
    }

    // largest power of two dividing both inputs
    let common_twos = (a | b).trailing_zeros();

    // `odd` is always odd from here on; `rest` shrinks until it reaches zero
    let mut odd = b >> b.trailing_zeros();
    let mut rest = a;
    while rest != 0 {
        rest >>= rest.trailing_zeros();
        let (hi, lo) = if rest >= odd { (rest, odd) } else { (odd, rest) };
        odd = lo;
        rest = hi - lo;
    }

    odd << common_twos
}
174
#[cfg(test)]
mod tests {
    use rand::{rng, seq::SliceRandom, Rng};

    use super::*;

    /// Factorizes `n` and checks that the output is a multiset of primes whose product
    /// is `n`. For `n < 2`, checks that nothing is reported.
    fn check(n: u64) {
        let mut factor = Vec::new();

        assert!(factorize(n, &mut factor).is_ok(), "factorization of {n} failed");
        if n < 2 {
            assert!(factor.is_empty(), "{n} has no prime factors, got {factor:?}");
            return;
        }
        assert!(
            factor.iter().all(|&p| primality_test(p)),
            "non-prime factor of {n}: {factor:?}"
        );
        assert_eq!(n, factor.iter().product::<u64>());
    }

    /// Reference factorization by plain trial division, in ascending order.
    fn naive_factorize(mut n: u64) -> Vec<u64> {
        let mut factor = Vec::new();
        let mut p = 2;
        while p * p <= n {
            while n % p == 0 {
                n /= p;
                factor.push(p);
            }
            p += 1;
        }
        if n > 1 {
            factor.push(n);
        }
        factor
    }

    #[test]
    fn small_exhaustive() {
        for n in 0..1 << 12 {
            let mut factor = Vec::new();

            assert!(factorize(n, &mut factor).is_ok());
            factor.sort_unstable();
            assert_eq!(factor, naive_factorize(n), "n = {n}");
        }
    }

    #[test]
    fn edge_cases() {
        let cases: [u64; 9] = [
            0,
            1,
            2,
            1 << 63,
            u64::MAX,
            u64::MAX - 58,
            // around the 2^16 trial-division boundary
            65_521 * 65_537,
            65_537 * 65_537 * 65_537,
            // two primes just below 2^32
            4_294_967_291 * 4_294_967_279,
        ];
        for n in cases {
            check(n);
        }
    }

    #[test]
    fn random() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1 << 55..=u64::MAX)).take(10_000) {
            check(n);
        }
    }

    #[test]
    fn random_square() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1 << 20..1 << 32)).take(5000) {
            check(n * n);
        }
    }

    #[test]
    fn prime_square() {
        for n in (0..1 << 32)
            .rev()
            .step_by(2)
            .filter(|n| primality_test(*n))
            .take(500)
        {
            check(n * n);
        }
    }

    // fast since p is relatively small
    #[test]
    fn prime_cube() {
        let p = Vec::from_iter(
            (0..1 << 21)
                .rev()
                .step_by(2)
                .filter(|n| primality_test(*n))
                .take(500),
        );

        for p in p {
            check(p.pow(3));
        }
    }

    #[test]
    fn prime_double() {
        let mut p: Vec<u64> = (0..1 << 32)
            .rev()
            .step_by(2)
            .filter(|n| primality_test(*n))
            .take(500)
            .collect();
        p.shuffle(&mut rng());

        for p in p.windows(2) {
            check(p[0] * p[1]);
        }
    }

    #[test]
    fn prime_triple() {
        let mut p: Vec<u64> = (0..1 << 21)
            .rev()
            .step_by(2)
            .filter(|n| primality_test(*n))
            .take(500)
            .collect();
        p.shuffle(&mut rng());

        for p in p.windows(3) {
            check(p[0] * p[1] * p[2]);
        }
    }
}