1use super::{ExtendedGcd, Gcd};
2use core::mem::replace;
3
/// GCD kernel without the operand checks that [`Gcd::gcd`] performs.
///
/// The implementations in this file expect both operands to be odd (hence
/// non-zero) and only verify that with `debug_assert!`.
trait UncheckedGcd<Rhs = Self> {
    /// Type of the resulting greatest common divisor.
    type Output;

    /// Returns the greatest common divisor of `self` and `rhs`.
    fn unchecked_gcd(self, rhs: Rhs) -> Self::Output;
}
12
/// Extended-GCD kernel without the operand checks that
/// [`ExtendedGcd::gcd_ext`] performs.
///
/// The implementations in this file expect `self >= rhs` and a non-zero `rhs`
/// (it is the first divisor); the ordering is only checked with `debug_assert!`.
trait UncheckedExtendedGcd<Rhs = Self> {
    /// Type of the resulting greatest common divisor.
    type OutputGcd;
    /// Type of the Bézout coefficients (a signed integer in this file's impls).
    type OutputCoeff;

    /// Returns `(g, s, t)` where `g` is the greatest common divisor and
    /// `g == s * self + t * rhs`.
    fn unchecked_gcd_ext(
        self,
        rhs: Rhs,
    ) -> (Self::OutputGcd, Self::OutputCoeff, Self::OutputCoeff);
}
22
/// Implements the odd-operand binary GCD kernel ([`UncheckedGcd`]) and the
/// Euclidean extended-GCD kernel ([`UncheckedExtendedGcd`]) for primitive
/// unsigned integers.
///
/// Two forms are accepted:
///
/// * `$U | $I;` implements both kernels for `$U` directly, with `$I` as the
///   signed Bézout-coefficient type.
/// * `$U | $I => $HU | $HI;` implements them for `$U` by finishing the
///   computation in the half-width type `$HU` (which must already implement
///   both kernels) once the operands fit in it. `$HI` names the signed type
///   paired with `$HU`; it is accepted for symmetry but not referenced by the
///   expansion.
macro_rules! impl_unchecked_gcd_ops_prim {
    ($($U:ty | $I:ty;)*) => {$(
        impl UncheckedGcd for $U {
            type Output = $U;

            /// Binary (subtract-and-shift) GCD of two odd operands.
            #[inline]
            fn unchecked_gcd(self, rhs: Self) -> Self::Output {
                debug_assert!(self | rhs > 0);
                // Both operands must be odd: the difference of two odd values is
                // even, so stripping its trailing zeros keeps both operands odd.
                debug_assert!(self & rhs & 1 > 0);

                let (mut a, mut b) = (self, rhs);
                while a != b {
                    if a > b {
                        a -= b;
                        a >>= a.trailing_zeros();
                    } else {
                        b -= a;
                        b >>= b.trailing_zeros();
                    }
                }
                a
            }
        }

        impl UncheckedExtendedGcd for $U {
            type OutputGcd = $U;
            type OutputCoeff = $I;

            /// Extended Euclidean algorithm for `self >= rhs > 0`.
            ///
            /// Returns `(g, s, t)` with `g == s * self + t * rhs`.
            #[inline]
            fn unchecked_gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
                // `rhs` is the first divisor, so it must be non-zero.
                debug_assert!(rhs > 0);
                debug_assert!(self >= rhs);

                // Invariants: last_r == last_s * self + last_t * rhs
                //             r      == s      * self + t      * rhs
                let (mut last_r, mut r) = (self, rhs);
                let (mut last_s, mut s) = (1, 0);
                let (mut last_t, mut t) = (0, 1);

                loop {
                    let quo = last_r / r;
                    let new_r = last_r - quo * r;
                    if new_r == 0 {
                        return (r, s, t);
                    }
                    // `new_r != 0` implies `r >= 2`, so `quo <= <$U>::MAX / 2` and the
                    // casts to `$I` below are lossless.
                    last_r = replace(&mut r, new_r);
                    let new_s = last_s - quo as $I * s;
                    last_s = replace(&mut s, new_s);
                    let new_t = last_t - quo as $I * t;
                    last_t = replace(&mut t, new_t);
                }
            }
        }
    )*};
    ($($U:ty | $I:ty => $HU:ty | $HI:ty;)*) => {$(
        impl UncheckedGcd for $U {
            type Output = $U;

            /// Binary GCD of two odd operands, finished in `$HU` once both fit.
            fn unchecked_gcd(self, rhs: Self) -> Self::Output {
                debug_assert!(self | rhs > 0);
                debug_assert!(self & rhs & 1 > 0);

                let (mut a, mut b) = (self, rhs);
                while a != b {
                    // Both operands fit in the half-width type: hand over to it.
                    if (a | b) >> <$HU>::BITS == 0 {
                        return (a as $HU).unchecked_gcd(b as $HU) as $U;
                    }
                    if a > b {
                        a -= b;
                        a >>= a.trailing_zeros();
                    } else {
                        b -= a;
                        b >>= b.trailing_zeros();
                    }
                }
                a
            }
        }

        impl UncheckedExtendedGcd for $U {
            type OutputGcd = $U;
            type OutputCoeff = $I;

            /// Extended Euclidean algorithm for `self >= rhs > 0`, finished in `$HU`.
            ///
            /// Returns `(g, s, t)` with `g == s * self + t * rhs`.
            fn unchecked_gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
                // `rhs` is the first divisor, so it must be non-zero.
                debug_assert!(rhs > 0);
                debug_assert!(self >= rhs);

                // Invariants: last_r == last_s * self + last_t * rhs
                //             r      == s      * self + t      * rhs
                let (mut last_r, mut r) = (self, rhs);
                let (mut last_s, mut s) = (1, 0);
                let (mut last_t, mut t) = (0, 1);

                // Full-width Euclidean steps while the divisor does not fit in `$HU`.
                while r >> <$HU>::BITS > 0 {
                    let quo = last_r / r;
                    let new_r = last_r - quo * r;
                    if new_r == 0 {
                        return (r, s, t);
                    }
                    last_r = replace(&mut r, new_r);
                    let new_s = last_s - quo as $I * s;
                    last_s = replace(&mut s, new_s);
                    let new_t = last_t - quo as $I * t;
                    last_t = replace(&mut t, new_t);
                }

                // `r` now fits in `$HU` but the dividend `last_r` may not; one more
                // full-width step yields `new_r < r`, after which both fit.
                let r = r as $HU;
                let quo = last_r / r as $U;
                let new_r = (last_r - quo * r as $U) as $HU;
                if new_r == 0 {
                    return (r as $U, s, t);
                }
                let new_s = last_s - quo as $I * s;
                let new_t = last_t - quo as $I * t;

                // Half-width kernel gives g == cx * r + cy * new_r; substituting the
                // known expressions of r and new_r in terms of (self, rhs) yields
                // the full-width coefficients.
                let (g, cx, cy) = r.unchecked_gcd_ext(new_r);
                let (cx, cy) = (cx as $I, cy as $I);
                (g as $U, cx * s + cy * new_s, cx * t + cy * new_t)
            }
        }
    )*};
}
149impl_unchecked_gcd_ops_prim!(u8 | i8; u16 | i16; usize | isize;);
150#[cfg(target_pointer_width = "16")]
151impl_unchecked_gcd_ops_prim!(u32 | i32 => u16 | i16; u64 | i64 => u32 | i32; u128 | i128 => u64 | i64;);
152#[cfg(target_pointer_width = "32")]
153impl_unchecked_gcd_ops_prim!(u32 | i32;);
154#[cfg(target_pointer_width = "32")]
155impl_unchecked_gcd_ops_prim!(u64 | i64 => u32 | i32; u128 | i128 => u64 | u64;);
156#[cfg(target_pointer_width = "64")]
157impl_unchecked_gcd_ops_prim!(u32 | i32; u64 | i64;);
158#[cfg(target_pointer_width = "64")]
159impl_unchecked_gcd_ops_prim!(u128 | i128 => u64 | i64;);
160
/// Implements [`Gcd`] and [`ExtendedGcd`] for primitive unsigned integers,
/// with `$I` as the signed Bézout-coefficient type.
///
/// Both entry points answer zero operands directly (panicking when both are
/// zero), divide out the common power of two, and then delegate to the
/// unchecked kernels.
macro_rules! impl_gcd_ops_prim {
    ($($U:ty | $I:ty;)*) => {$(
        impl Gcd for $U {
            type Output = $U;

            #[inline]
            fn gcd(self, rhs: Self) -> Self::Output {
                // If one operand has more than this many extra significant bits,
                // reduce it with a single `%` before running the kernel.
                const GCD_BIT_DIFF_THRESHOLD: u32 = 3;

                let (mut a, mut b) = (self, rhs);
                if a == 0 || b == 0 {
                    if a == 0 && b == 0 {
                        panic_gcd_0_0();
                    }
                    return a | b;
                }

                // gcd(2^i * x, 2^j * y) == 2^min(i, j) * gcd(x, y) for odd x, y;
                // the kernel requires both operands to be odd.
                let shift = (a | b).trailing_zeros();
                a >>= a.trailing_zeros();
                b >>= b.trailing_zeros();

                // When the magnitudes differ a lot, one `%` shrinks the larger
                // operand at once instead of through many kernel iterations.
                let (za, zb) = (a.leading_zeros(), b.leading_zeros());
                if za > zb + GCD_BIT_DIFF_THRESHOLD {
                    let r = b % a;
                    if r == 0 {
                        return a << shift;
                    }
                    b = r >> r.trailing_zeros();
                } else if zb > za + GCD_BIT_DIFF_THRESHOLD {
                    let r = a % b;
                    if r == 0 {
                        return b << shift;
                    }
                    a = r >> r.trailing_zeros();
                }

                a.unchecked_gcd(b) << shift
            }
        }

        impl ExtendedGcd for $U {
            type OutputGcd = $U;
            type OutputCoeff = $I;

            #[inline]
            fn gcd_ext(self, rhs: $U) -> ($U, $I, $I) {
                let (mut a, mut b) = (self, rhs);

                match (a == 0, b == 0) {
                    (true, true) => panic_gcd_0_0(),
                    (true, false) => return (b, 0, 1),
                    (false, true) => return (a, 1, 0),
                    (false, false) => {}
                }

                // Dividing out the common power of two scales the gcd by 2^shift
                // but leaves the Bézout coefficients unchanged.
                let shift = (a | b).trailing_zeros();
                a >>= shift;
                b >>= shift;

                // The kernel wants the larger operand as its receiver; a unit
                // operand is answered directly.
                if a >= b {
                    if b == 1 {
                        (1 << shift, 0, 1)
                    } else {
                        let (g, ca, cb) = a.unchecked_gcd_ext(b);
                        (g << shift, ca, cb)
                    }
                } else if a == 1 {
                    (1 << shift, 1, 0)
                } else {
                    let (g, cb, ca) = b.unchecked_gcd_ext(a);
                    (g << shift, ca, cb)
                }
            }
        }
    )*}
}
248impl_gcd_ops_prim!(u8 | i8; u16 | i16; u32 | i32; u64 | i64; u128 | i128; usize | isize;);
249
/// Reports a GCD request whose operands are both zero.
///
/// # Panics
///
/// Always panics: `gcd(0, 0)` has no greatest common divisor.
// Kept out of line and marked cold so the panic path does not bloat or
// mispredict the hot gcd code that calls it.
#[cold]
#[inline(never)]
fn panic_gcd_0_0() -> ! {
    panic!("the greatest common divisor is not defined between zeros!")
}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain Euclidean algorithm, used as an independent oracle.
    fn reference_gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            let r = a % b;
            a = b;
            b = r;
        }
        a
    }

    /// Checks `gcd`/`gcd_ext` for `u64` against the oracle. The Bézout identity
    /// is verified exactly in `i128`: `a < 2^64` and `|x| <= 2^63`, so each
    /// product stays below `2^127` in magnitude.
    fn check_u64(a: u64, b: u64) {
        let g = reference_gcd(a.into(), b.into()) as u64;
        assert_eq!(a.gcd(b), g, "gcd({}, {})", a, b);
        let (g_ext, x, y) = a.gcd_ext(b);
        assert_eq!(g_ext, g, "gcd_ext({}, {})", a, b);
        assert_eq!(
            i128::from(a) * i128::from(x) + i128::from(b) * i128::from(y),
            i128::from(g),
            "Bezout identity for ({}, {})",
            a,
            b
        );
    }

    /// Checks `gcd`/`gcd_ext` for `u128` against the oracle. The Bézout identity
    /// holds over the integers, so it must also hold modulo 2^128, which is what
    /// the wrapping arithmetic verifies.
    fn check_u128(a: u128, b: u128) {
        let g = reference_gcd(a, b);
        assert_eq!(a.gcd(b), g, "gcd({}, {})", a, b);
        let (g_ext, x, y) = a.gcd_ext(b);
        assert_eq!(g_ext, g, "gcd_ext({}, {})", a, b);
        let combo = a.wrapping_mul(x as u128).wrapping_add(b.wrapping_mul(y as u128));
        assert_eq!(combo, g, "Bezout identity (mod 2^128) for ({}, {})", a, b);
    }

    #[test]
    fn test_simple() {
        assert_eq!(12u8.gcd(18), 6);
        assert_eq!(16u16.gcd(2032), 16);
        assert_eq!(0x40000000u32.gcd(0xcfd41b91), 1);
        assert_eq!(
            0x80000000000000000000000000000000u128.gcd(0x6f32f1ef8b18a2bc3cea59789c79d441),
            1
        );
        assert_eq!(
            79901280795560547607793891992771245827u128.gcd(27442821378946980402542540754159585749),
            1
        );

        let result = 12u8.gcd_ext(18);
        assert_eq!(result, (6, -1, 1));
        let result = 16u16.gcd_ext(2032);
        assert_eq!(result, (16, 1, 0));
        let result = 0x40000000u32.gcd_ext(0xcfd41b91);
        assert_eq!(result, (1, -569926925, 175506801));
        let result =
            0x80000000000000000000000000000000u128.gcd_ext(0x6f32f1ef8b18a2bc3cea59789c79d441);
        assert_eq!(
            result,
            (
                1,
                59127885930508821681098646892310825630,
                -68061485417298041807799738471800882239
            )
        );
    }

    #[test]
    fn test_zero_operands() {
        assert_eq!(0u8.gcd(7), 7);
        assert_eq!(7u8.gcd(0), 7);
        assert_eq!(0u64.gcd_ext(9), (9, 0, 1));
        assert_eq!(9u64.gcd_ext(0), (9, 1, 0));
    }

    #[test]
    #[should_panic(expected = "not defined between zeros")]
    fn test_gcd_zero_zero_panics() {
        let _ = 0u32.gcd(0);
    }

    #[test]
    #[should_panic(expected = "not defined between zeros")]
    fn test_gcd_ext_zero_zero_panics() {
        let _ = 0u128.gcd_ext(0);
    }

    #[test]
    fn test_exhaustive_u8() {
        for a in 0..=u8::MAX {
            for b in 0..=u8::MAX {
                if a == 0 && b == 0 {
                    continue;
                }
                let g = reference_gcd(a.into(), b.into()) as u8;
                assert_eq!(a.gcd(b), g, "gcd({}, {})", a, b);
                let (g_ext, x, y) = a.gcd_ext(b);
                assert_eq!(g_ext, g, "gcd_ext({}, {})", a, b);
                assert_eq!(
                    i32::from(a) * i32::from(x) + i32::from(b) * i32::from(y),
                    i32::from(g),
                    "Bezout identity for ({}, {})",
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn test_pseudo_random_wide() {
        // Fixed-seed 64-bit LCG (Knuth's MMIX constants) keeps the test deterministic.
        let mut state = 0x243f_6a88_85a3_08d3_u64;
        let mut next = move || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            state
        };

        for i in 0..256u32 {
            // Varying shifts give operands of very different bit lengths, which
            // exercises the bit-difference shortcut in `gcd`.
            let (a, b) = (next() >> (i % 64), next());
            if a != 0 || b != 0 {
                check_u64(a, b);
            }

            let (hi, lo) = (next(), next());
            let a = ((u128::from(hi) << 64) | u128::from(lo)) >> (i % 128);
            let (hi, lo) = (next(), next());
            let b = (u128::from(hi) << 64) | u128::from(lo);
            if a != 0 || b != 0 {
                check_u128(a, b);
            }

            // A shared factor (< 2^10 on operands < 2^108) makes the gcd non-trivial
            // without overflowing.
            let k = u128::from(next() % 1000 + 1);
            let (a, b) = ((a >> 20) * k, (b >> 20) * k);
            if a != 0 || b != 0 {
                check_u128(a, b);
            }
        }
    }
}