Skip to main content

lib_modulo/
factorize.rs

1use std::num::NonZero;
2
3use crate::{prime::primality_test, Context64};
4
// Per-prime `Context64` parts `(n, inv_n, r2_mod_n)` for the 6541 odd primes below 2^16,
// consumed by the trial-division loop in `factorize`. Loaded at compile time from
// `small_prime_context_u16_raw.rs`.
//
// Size: `(u16, u64, u16)` is padded to 16 bytes by the alignment of `u64`, so the
// table occupies 16 bytes * 6541 = 104_656 bytes ~ 102 KiB.
static SMALL_ODD_PRIME_CONTEXT_16: [(u16, u64, u16); 6541] =
    include!("./small_prime_context_u16_raw.rs");
8
9/// Factorize integer and writes prime factors to `factor` in any order.
10///
11/// This function is probabilistic and may fail.
12///
13/// # Time complexity
14///
15/// O(`x`^0.25) expected
16///
17/// # Example
18///
19/// ```
20/// use lib_modulo::factorize::*;
21///
22/// let mut factor = Vec::new();
23/// // panics if factorization fails
24/// assert!(factorize(998_244_353 * 1_000_000_007, &mut factor).is_ok());
25///
26/// factor.sort_unstable();
27/// assert_eq!(factor, vec![998_244_353, 1_000_000_007])
28/// ```
29pub fn factorize(mut x: u64, factor: &mut Vec<u64>) -> Result<(), ()> {
30    if x < 2 {
31        return Ok(());
32    }
33    factor.reserve(64);
34
35    // trial division by small primes less than 2^16
36    {
37        factor.extend(std::iter::repeat_n(2, x.trailing_zeros() as usize));
38        x >>= x.trailing_zeros();
39    }
40    for &(n, inv_n, r2_mod_n) in SMALL_ODD_PRIME_CONTEXT_16.iter() {
41        let ctx = Context64 {
42            n: n as u64,
43            inv_n,
44            r2_mod_n: r2_mod_n as u64,
45        };
46
47        while ctx.can_divide(x) {
48            x /= ctx.n;
49            factor.push(ctx.n);
50        }
51
52        if x == 1 {
53            return Ok(());
54        }
55    }
56
57    // find large prime factors (up to 3) by Pollard's rho
58    while x > 1 {
59        if primality_test(x) {
60            factor.push(x);
61            return Ok(());
62        }
63
64        if let Some(d) = pollard_rho(x) {
65            let d = d.get();
66            while x % d == 0 {
67                x /= d;
68                factor.push(d);
69            }
70        } else {
71            return Err(());
72        }
73    }
74
75    Ok(())
76}
77
78/// Find prime factor of `x`.
79///
80/// This function is probabilistic and may fail.
81///
82/// # Time complexity
83///
84/// *O*(p^0.25) where p is a prime factor of `x`
85fn pollard_rho(x: u64) -> Option<NonZero<u64>> {
86    let ctx = Context64::new(x);
87    let one = ctx.modulo(1);
88
89    for c in 1..100 {
90        // a = b (mod x) => f(a) = f(b) (mod x)
91        let f = |x: u64| ctx.mul_add(x, x, c);
92
93        let mut y0 = ctx.modulo(1);
94        let mut y1 = y0;
95
96        let mut prod = one;
97        let mut step = 0;
98        let mut memo = [[0, 0, one.value]; 1 << 5];
99
100        'a: while !prod.is_zero() {
101            y0.value = f(y0.value);
102            y1.value = f(f(y1.value));
103            prod *= y1 - y0;
104            step += 1;
105
106            if step % (1 << 5) == 0 {
107                memo[(step >> 5) % (1 << 5)] = [y0.value, y1.value, prod.value];
108            }
109            if step % (1 << 10) == 0 {
110                let g = binary_gcd(prod.value, x);
111
112                if g == 1 {
113                    continue 'a;
114                } else if primality_test(g) {
115                    return NonZero::new(g);
116                }
117
118                for i in 0..memo.len() {
119                    let g = binary_gcd(memo[i][2], x);
120
121                    if g != 1 {
122                        if primality_test(g) {
123                            return NonZero::new(g);
124                        }
125
126                        y0.value = memo[i][0];
127                        y1.value = memo[i][1];
128                        for _ in 0..1 << 5 {
129                            let g = binary_gcd((y0 - y1).value, x);
130
131                            if g != 1 {
132                                if primality_test(g) {
133                                    return NonZero::new(g);
134                                } else if g != x {
135                                    // FIXME: `x` is composed of at most 3 primes, so return `x/g`
136                                    return pollard_rho(g);
137                                } else {
138                                    break 'a;
139                                }
140                            }
141
142                            y0.value = f(y0.value);
143                            y1.value = f(f(y1.value));
144                        }
145                    }
146                }
147            }
148        }
149
150        'a: for i in 0..(step % (1 << 10)) >> 5 {
151            let g = binary_gcd(memo[i][2], x);
152
153            if g != 1 {
154                if primality_test(g) {
155                    return NonZero::new(g);
156                }
157
158                y0.value = memo[i][0];
159                y1.value = memo[i][1];
160                for _ in 0..1 << 5 {
161                    let g = binary_gcd((y0 - y1).value, x);
162
163                    if g != 1 {
164                        if primality_test(g) {
165                            return NonZero::new(g);
166                        } else if g != x {
167                            // FIXME: `x` is composed of at most 3 primes, so return `x/g`
168                            return pollard_rho(g);
169                        } else {
170                            break 'a;
171                        }
172                    }
173
174                    y0.value = f(y0.value);
175                    y1.value = f(f(y1.value));
176                }
177            }
178        }
179    }
180
181    None
182}
183
/// Greatest common divisor of `a` and `b`, computed with Stein's binary algorithm.
///
/// A zero argument returns the other one unchanged, so `binary_gcd(0, 0) == 0`.
#[inline(always)]
fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
    if b == 0 {
        return a;
    }

    // Exponent of the largest power of two dividing both inputs; re-applied at the end.
    let common_twos = (a | b).trailing_zeros();
    // From here on `b` is always odd.
    b >>= b.trailing_zeros();

    while a != 0 {
        a >>= a.trailing_zeros();
        // Both are odd now: keep the smaller one, replace the other by the (even) difference.
        let smaller = a.min(b);
        a = a.abs_diff(b);
        b = smaller;
    }

    b << common_twos
}
204
#[cfg(test)]
mod tests {
    use rand::{rng, seq::SliceRandom, Rng};

    use super::*;

    /// Asserts that `factorize(n)` succeeds and that the returned factors multiply back to `n`.
    fn assert_factorizes(n: u64) {
        let mut factor = Vec::new();
        assert!(factorize(n, &mut factor).is_ok(), "factorize({n}) failed");
        assert_eq!(n, factor.iter().product::<u64>(), "factors of {n}: {factor:?}");
    }

    /// Returns the `count` largest primes below `bound`, in descending order.
    ///
    /// Only the candidates `bound - 1, bound - 3, ...` are tested, so `bound` must be even.
    fn largest_primes_below(bound: u64, count: usize) -> Vec<u64> {
        (0..bound)
            .rev()
            .step_by(2)
            .filter(|&n| primality_test(n))
            .take(count)
            .collect()
    }

    #[test]
    fn trivial_inputs() {
        for n in [0, 1] {
            let mut factor = Vec::new();
            assert!(factorize(n, &mut factor).is_ok());
            assert!(factor.is_empty(), "{n} has no prime factors, got {factor:?}");
        }
    }

    #[test]
    fn known_factorizations() {
        let cases: [(u64, &[u64]); 4] = [
            (2, &[2]),
            (1 << 63, &[2; 63]),
            (998_244_353 * 1_000_000_007, &[998_244_353, 1_000_000_007]),
            // 2^64 - 1 = 3 * 5 * 17 * 257 * 641 * 65537 * 6700417; the last two factors
            // survive trial division and must be split by Pollard's rho.
            (u64::MAX, &[3, 5, 17, 257, 641, 65_537, 6_700_417]),
        ];

        for (n, expected) in cases {
            let mut factor = Vec::new();
            assert!(factorize(n, &mut factor).is_ok(), "factorize({n}) failed");
            factor.sort_unstable();
            assert_eq!(factor, expected, "factors of {n}");
        }
    }

    #[test]
    fn random_square() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1u64 << 20..1 << 32)).take(5000) {
            assert_factorizes(n * n);
        }
    }

    #[test]
    fn random_cube() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1u64 << 16..1 << 21)).take(5000) {
            // n < 2^21, so n^3 < 2^63 fits in u64
            assert_factorizes(n * n * n);
        }
    }

    #[test]
    fn prime_square() {
        for p in largest_primes_below(1 << 32, 500) {
            assert_factorizes(p * p);
        }
    }

    // fast since p is relatively small
    #[test]
    fn prime_cube() {
        for p in largest_primes_below(1 << 21, 500) {
            assert_factorizes(p.pow(3));
        }
    }

    #[test]
    fn prime_double() {
        let mut p = largest_primes_below(1 << 32, 500);
        p.shuffle(&mut rng());

        for pair in p.windows(2) {
            assert_factorizes(pair[0] * pair[1]);
        }
    }

    #[test]
    fn prime_triple() {
        let mut p = largest_primes_below(1 << 21, 500);
        p.shuffle(&mut rng());

        for triple in p.windows(3) {
            assert_factorizes(triple[0] * triple[1] * triple[2]);
        }
    }
}