dashu_base/ring/
gcd.rs

1use super::{ExtendedGcd, Gcd};
2use core::mem::replace;
3
/// Internal GCD entry point that performs no validation of its operands.
trait UncheckedGcd<Rhs = Self> {
    /// Type of the computed greatest common divisor.
    type Output;

    /// Computes the GCD of `self` and `rhs`, relying on the caller to uphold
    /// the following preconditions:
    ///
    /// 1. at least one of the operands is non-zero;
    /// 2. the two operands are relatively close to each other;
    /// 3. the factor 2 has already been removed from the operands.
    ///
    /// For internal use only.
    fn unchecked_gcd(self, rhs: Rhs) -> Self::Output;
}
12
/// Internal extended GCD entry point that performs no validation of its operands.
trait UncheckedExtendedGcd<Rhs = Self> {
    /// Type of the computed greatest common divisor.
    type OutputGcd;
    /// Type of the two cofactors returned next to the divisor.
    type OutputCoeff;

    /// Computes `(g, x, y)` for `self` and `rhs`, relying on the caller to
    /// uphold the following preconditions:
    ///
    /// 1. at least one of the operands is non-zero;
    /// 2. the first operand is not smaller than the second one.
    ///
    /// For internal use only.
    fn unchecked_gcd_ext(
        self,
        rhs: Rhs,
    ) -> (Self::OutputGcd, Self::OutputCoeff, Self::OutputCoeff);
}
22
/// Implements [`UncheckedGcd`] and [`UncheckedExtendedGcd`] for primitive
/// unsigned integers.
///
/// * `U | I;` entries run the algorithms directly on `U`, with cofactors of
///   the signed type `I`.
/// * `U | I => HU | HI;` entries treat `U` as two halves of type `HU`: the
///   work starts on the full width and is handed over to the `HU`
///   implementation as soon as the intermediate values fit into `HU`.
macro_rules! impl_unchecked_gcd_ops_prim {
    ($($U:ty | $I:ty;)*) => {$(
        impl UncheckedGcd for $U {
            type Output = $U;

            #[inline]
            fn unchecked_gcd(self, rhs: Self) -> Self::Output {
                debug_assert!((self | rhs) != 0);
                debug_assert!((self & rhs & 1) != 0);

                let (mut x, mut y) = (self, rhs);

                // Binary GCD: both values are odd, so the difference of the
                // larger and the smaller one is even; subtract and strip the
                // resulting factors of two until the values meet.
                while x != y {
                    let (larger, smaller) = if x > y { (&mut x, y) } else { (&mut y, x) };
                    *larger -= smaller;
                    let twos = larger.trailing_zeros();
                    *larger >>= twos;
                }
                x
            }
        }

        impl UncheckedExtendedGcd for $U {
            type OutputGcd = $U;
            type OutputCoeff = $I;

            #[inline]
            fn unchecked_gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
                debug_assert!((self | rhs) != 0);
                debug_assert!(self >= rhs);

                // Invariants of the Euclidean iteration:
                //   prev_rem = self * prev_s + rhs * prev_t
                //   rem      = self * cur_s  + rhs * cur_t
                let (mut prev_rem, mut rem) = (self, rhs);
                let (mut prev_s, mut cur_s): ($I, $I) = (1, 0);
                let (mut prev_t, mut cur_t): ($I, $I) = (0, 1);

                loop {
                    let q = prev_rem / rem;
                    let next_rem = prev_rem - q * rem;
                    if next_rem == 0 {
                        // `rem` divides `prev_rem`, hence it is the GCD
                        break (rem, cur_s, cur_t);
                    }
                    prev_rem = replace(&mut rem, next_rem);

                    let q = q as $I;
                    let next_s = prev_s - q * cur_s;
                    prev_s = replace(&mut cur_s, next_s);
                    let next_t = prev_t - q * cur_t;
                    prev_t = replace(&mut cur_t, next_t);
                }
            }
        }
    )*};
    ($($U:ty | $I:ty => $HU:ty | $HI:ty;)*) => {$( // treat the integers as two parts
        impl UncheckedGcd for $U {
            type Output = $U;

            fn unchecked_gcd(self, rhs: Self) -> Self::Output {
                debug_assert!((self | rhs) != 0);
                debug_assert!((self & rhs & 1) != 0);

                let (mut x, mut y) = (self, rhs);

                // Binary GCD on the full width, leaving for the half-width
                // implementation once neither value has high bits set.
                while x != y {
                    let fits_in_half = (x | y) >> <$HU>::BITS == 0;
                    if fits_in_half {
                        return (x as $HU).unchecked_gcd(y as $HU) as $U;
                    }

                    let (larger, smaller) = if x > y { (&mut x, y) } else { (&mut y, x) };
                    *larger -= smaller;
                    let twos = larger.trailing_zeros();
                    *larger >>= twos;
                }
                x
            }
        }

        impl UncheckedExtendedGcd for $U {
            type OutputGcd = $U;
            type OutputCoeff = $I;

            fn unchecked_gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
                debug_assert!((self | rhs) != 0);
                debug_assert!(self >= rhs);

                // Invariants of the Euclidean iteration:
                //   prev_rem = self * prev_s + rhs * prev_t
                //   rem      = self * cur_s  + rhs * cur_t
                let (mut prev_rem, mut rem) = (self, rhs);
                let (mut prev_s, mut cur_s): ($I, $I) = (1, 0);
                let (mut prev_t, mut cur_t): ($I, $I) = (0, 1);

                // Step 1: Euclidean steps on the full width while the current
                // remainder still occupies the upper half.
                while rem >> <$HU>::BITS != 0 {
                    let q = prev_rem / rem;
                    let next_rem = prev_rem - q * rem;
                    if next_rem == 0 {
                        return (rem, cur_s, cur_t);
                    }
                    prev_rem = replace(&mut rem, next_rem);

                    let q = q as $I;
                    let next_s = prev_s - q * cur_s;
                    prev_s = replace(&mut cur_s, next_s);
                    let next_t = prev_t - q * cur_t;
                    prev_t = replace(&mut cur_t, next_t);
                }

                // Step 2: one mixed step. `rem` fits into the half width now,
                // `prev_rem` possibly does not; their remainder is below `rem`
                // and therefore fits into the half width as well.
                let half_rem = rem as $HU;
                let q = prev_rem / rem;
                let half_next = (prev_rem - q * rem) as $HU;
                if half_next == 0 {
                    return (rem, cur_s, cur_t);
                }
                let q = q as $I;
                let next_s = prev_s - q * cur_s;
                let next_t = prev_t - q * cur_t;

                // Step 3: finish on the half width, which yields
                //   g = half_rem * cx + half_next * cy,
                // then substitute the expressions of both remainders in terms
                // of `self` and `rhs` to obtain the final cofactors.
                let (g, cx, cy) = half_rem.unchecked_gcd_ext(half_next);
                let (cx, cy) = (cx as $I, cy as $I);
                (g as $U, cx * cur_s + cy * next_s, cx * cur_t + cy * next_t)
            }
        }
    )*}
}
149impl_unchecked_gcd_ops_prim!(u8 | i8; u16 | i16; usize | isize;);
150#[cfg(target_pointer_width = "16")]
151impl_unchecked_gcd_ops_prim!(u32 | i32 => u16 | i16; u64 | i64 => u32 | i32; u128 | i128 => u64 | i64;);
152#[cfg(target_pointer_width = "32")]
153impl_unchecked_gcd_ops_prim!(u32 | i32;);
154#[cfg(target_pointer_width = "32")]
155impl_unchecked_gcd_ops_prim!(u64 | i64 => u32 | i32; u128 | i128 => u64 | u64;);
156#[cfg(target_pointer_width = "64")]
157impl_unchecked_gcd_ops_prim!(u32 | i32; u64 | i64;);
158#[cfg(target_pointer_width = "64")]
159impl_unchecked_gcd_ops_prim!(u128 | i128 => u64 | i64;);
160
161macro_rules! impl_gcd_ops_prim {
162    ($($U:ty | $I:ty;)*) => {$(
163        impl Gcd for $U {
164            type Output = $U;
165
166            #[inline]
167            fn gcd(self, rhs: Self) -> Self::Output {
168                let (mut a, mut b) = (self, rhs);
169                if a == 0 || b == 0 {
170                    if a == 0 && b == 0 {
171                        panic_gcd_0_0();
172                    }
173                    return a | b;
174                }
175
176                // find common factors of 2
177                let shift = (a | b).trailing_zeros();
178                a >>= a.trailing_zeros();
179                b >>= b.trailing_zeros();
180
181                // reduce by division if the difference between operands is large
182                let (za, zb) = (a.leading_zeros(), b.leading_zeros());
183                const GCD_BIT_DIFF_THRESHOLD: u32 = 3;
184                if za > zb.wrapping_add(GCD_BIT_DIFF_THRESHOLD) {
185                    let r = b % a;
186                    if r == 0 {
187                        return a << shift;
188                    } else {
189                        b = r >> r.trailing_zeros();
190                    }
191                } else if zb > za.wrapping_add(4) {
192                    let r = a % b;
193                    if r == 0 {
194                        return b << shift;
195                    } else {
196                        a = r >> r.trailing_zeros();
197                    }
198                }
199
200                // forward to the gcd algorithm
201                a.unchecked_gcd(b) << shift
202            }
203        }
204
205        impl ExtendedGcd for $U {
206            type OutputGcd = $U;
207            type OutputCoeff = $I;
208
209            #[inline]
210            fn gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
211                let (mut a, mut b) = (self, rhs);
212
213                // check if zero inputs
214                match (a == 0, b == 0) {
215                    (true, true) => panic_gcd_0_0(),
216                    (true, false) => return (b, 0, 1),
217                    (false, true) => return (a, 1, 0),
218                    _ => {}
219                }
220
221                // find common factors of 2
222                let shift = (a | b).trailing_zeros();
223                a >>= shift;
224                b >>= shift;
225
226                // make sure a is larger than b
227                if a >= b {
228                    if b == 1 {
229                        // this shortcut eliminates the overflow when a = <$T>::MAX and b = 1
230                        (1 << shift, 0, 1)
231                    } else {
232                        // forward to the gcd algorithm
233                        let (g, ca, cb) = a.unchecked_gcd_ext(b);
234                        (g << shift, ca, cb)
235                    }
236                } else {
237                    if a == 1 {
238                        (1 << shift, 1, 0)
239                    } else {
240                        let (g, cb, ca) = b.unchecked_gcd_ext(a);
241                        (g << shift, ca, cb)
242                    }
243                }
244            }
245        }
246    )*}
247}
248impl_gcd_ops_prim!(u8 | i8; u16 | i16; u32 | i32; u64 | i64; u128 | i128; usize | isize;);
249
/// Aborts a GCD computation whose operands are both zero.
///
/// Shared by `gcd` and `gcd_ext` so that both report the same message.
fn panic_gcd_0_0() -> ! {
    panic!("the greatest common divisor is not defined between zeros!");
}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot checks of `gcd` and `gcd_ext` on every operand width, including
    /// inputs with large powers of two and full 128-bit operands.
    #[test]
    fn test_simple() {
        const POW2_127: u128 = 0x80000000000000000000000000000000;
        const ODD_127: u128 = 0x6f32f1ef8b18a2bc3cea59789c79d441;

        // plain gcd
        assert_eq!(12u8.gcd(18), 6);
        assert_eq!(16u16.gcd(2032), 16);
        assert_eq!(0x40000000u32.gcd(0xcfd41b91), 1);
        assert_eq!(POW2_127.gcd(ODD_127), 1);
        assert_eq!(
            79901280795560547607793891992771245827u128.gcd(27442821378946980402542540754159585749),
            1
        );

        // extended gcd: (divisor, cofactor of lhs, cofactor of rhs)
        assert_eq!(12u8.gcd_ext(18), (6, -1, 1));
        assert_eq!(16u16.gcd_ext(2032), (16, 1, 0));
        assert_eq!(0x40000000u32.gcd_ext(0xcfd41b91), (1, -569926925, 175506801));
        assert_eq!(
            POW2_127.gcd_ext(ODD_127),
            (
                1,
                59127885930508821681098646892310825630,
                -68061485417298041807799738471800882239
            )
        );
    }
}