Skip to main content

lib_modulo/
factorize.rs

1use std::num::NonZero;
2
3use crate::{Modulus64, prime::primality_test};
4
5/// Factorize integer and writes prime factors to `factor` in any order.
6///
7/// This function is probabilistic and may fail.
8///
9/// # Time complexity
10///
11/// O(`x`^0.25) expected
12///
13/// # Example
14///
15/// ```
16/// use lib_modulo::factorize::*;
17///
18/// let mut factor = Vec::new();
19/// // panics if factorization fails
20/// assert!(factorize(998_244_353 * 1_000_000_007, &mut factor).is_ok());
21///
22/// factor.sort_unstable();
23/// assert_eq!(factor, vec![998_244_353, 1_000_000_007])
24/// ```
25pub fn factorize(mut x: u64, factor: &mut Vec<u64>) -> Result<(), ()> {
26    if x < 2 {
27        return Ok(());
28    }
29    factor.reserve(64);
30
31    // trial division by small primes less than 2^10
32    {
33        factor.extend(std::iter::repeat_n(2, x.trailing_zeros() as usize));
34        x >>= x.trailing_zeros();
35    }
36    for &(n, inv_n, r2_mod_n) in SMALL_ODD_PRIME_CONTEXT_16.iter() {
37        let ctx = Modulus64 {
38            n: n as u64,
39            inv_n,
40            r2_mod_n: r2_mod_n as u64,
41        };
42
43        while ctx.can_divide(x) {
44            x /= ctx.n;
45            factor.push(ctx.n);
46        }
47
48        if x == 1 {
49            return Ok(());
50        }
51    }
52
53    // find large prime factors (up to 3) by Pollard's rho
54    while x > 1 {
55        if primality_test(x) {
56            factor.push(x);
57            return Ok(());
58        }
59
60        if let Some(d) = pollard_rho(x) {
61            let d = d.get();
62            while x % d == 0 {
63                x /= d;
64                factor.push(d);
65            }
66        } else {
67            return Err(());
68        }
69    }
70
71    Ok(())
72}
73
74/// Find prime factor of `x`.
75///
76/// This function is probabilistic and may fail.
77///
78/// # Time complexity
79///
80/// *O*(p^0.25) where p is a prime factor of `x`
81fn pollard_rho(x: u64) -> Option<NonZero<u64>> {
82    let ctx = Modulus64::new(x);
83    let one = ctx.residue(1);
84
85    for c in 1..100 {
86        // a = b (mod x) => f(a) = f(b) (mod x)
87        let f = |x: u64| ctx.mul_add(x, x, c);
88
89        let mut y0 = ctx.residue(1);
90        let mut y1 = y0;
91
92        let mut prod = one;
93        let mut step = 0;
94        let mut memo = [[0, 0, one.x]; 1 << 5];
95
96        'cycle_detection: while !prod.is_zero() {
97            y0.x = f(y0.x);
98            y1.x = f(f(y1.x));
99            prod *= y1 - y0;
100            step += 1;
101
102            if step % (1 << 5) == 0 {
103                memo[(step >> 5) % (1 << 5)] = [y0.x, y1.x, prod.x];
104            }
105            if step % (1 << 10) == 0 || prod.is_zero() {
106                let g = binary_gcd(prod.x, x);
107
108                if g == 1 {
109                    continue 'cycle_detection;
110                } else if primality_test(g) {
111                    return NonZero::new(g);
112                }
113
114                for i in 0..memo.len() {
115                    let g = binary_gcd(memo[i][2], x);
116
117                    if g != 1 {
118                        if primality_test(g) {
119                            return NonZero::new(g);
120                        }
121
122                        y0.x = memo[i][0];
123                        y1.x = memo[i][1];
124                        for _ in 0..1 << 5 {
125                            let g = binary_gcd((y0 - y1).x, x);
126
127                            if g != 1 {
128                                if primality_test(g) {
129                                    return NonZero::new(g);
130                                } else if g != x {
131                                    // FIXME: `x` is composed of at most 3 primes, so return `x/g`
132                                    return pollard_rho(g);
133                                } else {
134                                    break 'cycle_detection;
135                                }
136                            }
137
138                            y0.x = f(y0.x);
139                            y1.x = f(f(y1.x));
140                        }
141                    }
142                }
143            }
144        }
145    }
146
147    None
148}
149
/// Greatest common divisor of `a` and `b` by Stein's binary GCD algorithm.
///
/// A zero argument yields the other one, so `binary_gcd(0, 0) == 0`.
#[inline(always)]
fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
    if a == 0 || b == 0 {
        return a | b;
    }

    // 2^k is the largest power of two dividing both inputs.
    let k = (a | b).trailing_zeros();
    a >>= a.trailing_zeros();
    b >>= b.trailing_zeros();

    // Invariant: `a` and `b` are odd and gcd(a, b) * 2^k equals the gcd of the inputs.
    loop {
        if a > b {
            std::mem::swap(&mut a, &mut b);
        }
        // The difference of two odd numbers is even, so it can be halved right away.
        b -= a;
        if b == 0 {
            return a << k;
        }
        b >>= b.trailing_zeros();
    }
}
170
#[cfg(test)]
mod tests {
    use rand::{rng, seq::SliceRandom, Rng};

    use super::*;

    /// Primality oracle used to check results.
    ///
    /// Values below 2^20 are checked by trial division, so small factors such as 2 do not depend
    /// on `primality_test`. Larger values (always odd here) go through `primality_test`.
    fn is_prime(p: u64) -> bool {
        if p < 1 << 20 {
            p >= 2 && (2u64..).take_while(|d| d * d <= p).all(|d| p % d != 0)
        } else {
            primality_test(p)
        }
    }

    /// Factorizes `n` and asserts three things: the call succeeds, every reported factor is
    /// prime, and the factors multiply back to `n`.
    ///
    /// The primality check matters: a `factorize` that pushed a composite divisor (or `n`
    /// itself) would satisfy the product check alone.
    fn assert_factorization(n: u64) {
        let mut factor = Vec::new();

        assert!(factorize(n, &mut factor).is_ok(), "factorize({n}) failed");
        assert!(
            factor.iter().all(|&p| is_prime(p)),
            "non-prime factor of {n}: {factor:?}"
        );
        assert_eq!(
            n,
            factor.iter().product::<u64>(),
            "factors of {n}: {factor:?}"
        );
    }

    /// Returns the `count` largest odd primes below `bound`, largest first.
    ///
    /// `bound` must be even so that the descending walk `bound - 1, bound - 3, ...` visits only
    /// odd numbers.
    fn largest_odd_primes_below(bound: u64, count: usize) -> Vec<u64> {
        (0..bound)
            .rev()
            .step_by(2)
            .filter(|&n| primality_test(n))
            .take(count)
            .collect()
    }

    #[test]
    fn zero_and_one_have_no_factors() {
        for n in [0, 1] {
            let mut factor = Vec::new();

            assert!(factorize(n, &mut factor).is_ok());
            assert!(factor.is_empty());
        }
    }

    #[test]
    fn small_values() {
        (2..1 << 12).for_each(assert_factorization);
    }

    #[test]
    fn boundary_values() {
        for n in [
            1 << 63,
            // 3 * 5 * 17 * 257 * 641 * 65537 * 6700417
            u64::MAX,
            // largest prime below 2^64
            18_446_744_073_709_551_557,
            // square of the smallest prime not covered by trial division
            1031 * 1031,
        ] {
            assert_factorization(n);
        }
    }

    #[test]
    fn random() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1 << 55..=u64::MAX)).take(10_000) {
            assert_factorization(n);
        }
    }

    #[test]
    fn random_square() {
        let mut rng = rng();
        for n in std::iter::repeat_with(|| rng.random_range(1 << 20..1 << 32)).take(5000) {
            assert_factorization(n * n);
        }
    }

    #[test]
    fn prime_square() {
        for p in largest_odd_primes_below(1 << 32, 500) {
            assert_factorization(p * p);
        }
    }

    // fast since p is relatively small
    #[test]
    fn prime_cube() {
        for p in largest_odd_primes_below(1 << 21, 500) {
            assert_factorization(p.pow(3));
        }
    }

    #[test]
    fn prime_double() {
        let mut p = largest_odd_primes_below(1 << 32, 500);
        p.shuffle(&mut rng());

        for w in p.windows(2) {
            assert_factorization(w[0] * w[1]);
        }
    }

    #[test]
    fn prime_triple() {
        let mut p = largest_odd_primes_below(1 << 21, 500);
        p.shuffle(&mut rng());

        for w in p.windows(3) {
            assert_factorization(w[0] * w[1] * w[2]);
        }
    }
}
272
/// Montgomery contexts `(p, p^-1 mod 2^64, 2^128 mod p)` for the 171 odd primes below 2^10
/// (3, 5, 7, ..., 1021), in increasing order. Used for trial division in [`factorize`].
///
/// The three fields are unpacked into `Modulus64 { n, inv_n, r2_mod_n }`. The table is computed
/// at compile time by [`small_odd_prime_contexts`] instead of being spelled out as literals, so
/// every entry is correct by construction.
///
/// Storage: `(u16, u64, u16)` carries 12 bytes of data. Its size is rounded up to a multiple of
/// the tuple's alignment, which is 16 bytes on targets where `u64` is 8-byte aligned. On such
/// targets the table takes about 2.7 KiB.
static SMALL_ODD_PRIME_CONTEXT_16: [(u16, u64, u16); 171] = small_odd_prime_contexts();

/// Builds the first `N` entries of the odd-prime Montgomery context table at compile time.
///
/// For each odd prime `p`, in increasing order, the entry is
/// `(p, p^-1 mod 2^64, (2^64 mod p)^2 mod p)`.
///
/// `N` must be small enough that every generated prime fits in a `u16`, i.e. at most the number
/// of odd primes below 2^16. Otherwise the `as u16` casts truncate.
const fn small_odd_prime_contexts<const N: usize>() -> [(u16, u64, u16); N] {
    let mut table = [(0, 0, 0); N];
    let mut len = 0;
    let mut p: u64 = 3;

    while len < N {
        // Trial division by odd divisors is plenty for primes this small.
        let mut is_prime = true;
        let mut d: u64 = 3;
        while d * d <= p {
            if p % d == 0 {
                is_prime = false;
                break;
            }
            d += 2;
        }

        if is_prime {
            // Newton iteration for p^-1 mod 2^64. The seed `p` is correct to 3 bits, since
            // p * p = 1 (mod 8) for odd p. Each step doubles the number of correct bits:
            // 3 -> 6 -> 12 -> 24 -> 48 -> 96.
            let mut inv = p;
            let mut step = 0;
            while step < 5 {
                inv = inv.wrapping_mul(2u64.wrapping_sub(p.wrapping_mul(inv)));
                step += 1;
            }

            // R = 2^64 is the Montgomery radix; store R^2 mod p.
            let r_mod_p = (1u128 << 64) % p as u128;
            let r2_mod_p = r_mod_p * r_mod_p % p as u128;

            table[len] = (p as u16, inv, r2_mod_p as u16);
            len += 1;
        }

        p += 2;
    }

    table
}