1#![doc = include_str!("../README.md")]
2use num_traits::One;
5use rug::Complete;
6use std::ops::Mul;
7type Z = rug::Integer;
8
9#[non_exhaustive]
10#[derive(Debug, thiserror::Error)]
11pub enum Error {
13 #[error("Input value is out of domain")]
15 OutOfDomain,
16}
17
18fn continued_fraction_of_sqrt_small(d: i64) -> Vec<Z> {
19 let sd = d.isqrt();
20 let mut r = vec![Z::from(sd)];
21 if sd * sd == d {
22 return r;
23 }
24 let mut p = -sd;
25 let mut q = 1;
26 let norm = d - p * p;
27 debug_assert_eq!(norm % q, 0);
28 q = norm / q;
29 p = -p;
30 loop {
31 let flag = q == 1;
32 let v = (sd + p) / q;
33 p -= v * q;
34 let norm = d - p * p;
35 debug_assert_eq!(norm % q, 0);
36 q = norm / q;
37 p = -p;
38 r.push(Z::from(v));
39 if flag {
40 return r;
41 }
42 }
43}
44
45fn continued_fraction_of_sqrt_large(d: Z) -> Vec<Z> {
46 let sd = d.sqrt_ref().complete();
47 let mut r = vec![sd.clone()];
48 if sd.square_ref().complete() == d {
49 return r;
50 }
51 let mut p = -sd.clone();
52 let mut q = Z::ONE.clone();
53 let norm = &d - p.square_ref().complete();
54 debug_assert!(norm.is_divisible(&q));
55 q = norm.div_exact(&q);
56 p *= -1;
57 loop {
58 let flag = q == *Z::ONE;
59 let v = (&sd + &p).complete() / &q;
60 p -= &v * &q;
61 let norm = &d - p.square_ref().complete();
62 debug_assert!(norm.is_divisible(&q));
63 q = norm.div_exact(&q);
64 p *= -1;
65 r.push(v);
66 if flag {
67 return r;
68 }
69 }
70}
71
72pub fn continued_fraction_of_sqrt(d: Z) -> Result<Vec<Z>, Error> {
84 if d.is_negative() {
85 Err(Error::OutOfDomain)
86 } else if let Some(d) = d.to_i64() {
87 Ok(continued_fraction_of_sqrt_small(d))
88 } else {
89 Ok(continued_fraction_of_sqrt_large(d))
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq)]
95pub enum Solution {
96 Negative(Z, Z),
98 Positive(Z, Z),
100 NotExist,
102}
103
104#[derive(Debug, Clone, PartialEq, Eq)]
105struct Matrix2x2 {
106 a: Z,
107 b: Z,
108 c: Z,
109 d: Z,
110}
111impl Matrix2x2 {
112 fn new(a: Z) -> Self {
113 Self {
114 a,
115 b: Z::ONE.clone(),
116 c: Z::ONE.clone(),
117 d: Z::ZERO,
118 }
119 }
120}
121impl num_traits::One for Matrix2x2 {
122 fn one() -> Self {
123 Self {
124 a: Z::ONE.clone(),
125 b: Z::ZERO,
126 c: Z::ZERO,
127 d: Z::ONE.clone(),
128 }
129 }
130}
131#[auto_impl_ops::auto_ops]
132impl std::ops::Mul<&Matrix2x2> for &Matrix2x2 {
133 type Output = Matrix2x2;
134 fn mul(self, rhs: &Matrix2x2) -> Self::Output {
135 let a = self.a.clone() * &rhs.a + &self.b * &rhs.c;
136 let b = self.a.clone() * &rhs.b + &self.b * &rhs.d;
137 let c = self.c.clone() * &rhs.a + &self.d * &rhs.c;
138 let d = self.c.clone() * &rhs.b + &self.d * &rhs.d;
139 Matrix2x2 { a, b, c, d }
140 }
141}
142fn tree_product(a: &[Z]) -> Matrix2x2 {
143 let n = (a.len().ilog2() + 1) as usize;
144 let mut v = vec![Matrix2x2::one(); n];
145 for (i, a) in a.iter().rev().enumerate() {
146 let a = Matrix2x2::new(a.clone());
147 v[0] *= a;
148 let mut i = i + 1;
149 let mut j = 0;
150 while i % 2 == 0 {
151 let mut t = Matrix2x2::one();
152 std::mem::swap(&mut t, &mut v[j]);
153 v[j + 1] *= t;
154 i >>= 1;
155 j += 1;
156 }
157 }
158 let mut t = Matrix2x2::one();
159 for i in (0..n).rev() {
160 t *= &v[i];
161 }
162 t
163}
164
165fn solve_pell_aux(mut a: Vec<Z>, d: Z) -> Solution {
166 let n = a.len() - 1;
167 if n == 0 {
168 return Solution::NotExist;
169 }
170 let (p_now, q_now) = if n > 8192 {
171 let m = tree_product(&a[1..n]);
172 let init = Matrix2x2 {
173 a: a[0].clone(),
174 b: Z::ONE.clone(),
175 c: Z::ONE.clone(),
176 d: Z::ZERO,
177 };
178 let Matrix2x2 { a, b, c: _, d: _ } = m * init;
179 (a, b)
180 } else {
181 let _ = a.pop();
182 let mut p_old = Z::ONE.clone();
183 let mut q_old = Z::ZERO;
184 let mut p_now = a[0].clone();
185 let mut q_now = Z::ONE.clone();
186 for ai in a.into_iter().skip(1) {
189 p_old += &ai * &p_now;
190 q_old += &ai * &q_now;
191 std::mem::swap(&mut p_old, &mut p_now);
192 std::mem::swap(&mut q_old, &mut q_now);
193 }
195 (p_now, q_now)
196 };
197 if n % 2 == 0 {
198 debug_assert_eq!(
199 p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
200 *Z::ONE
201 );
202 Solution::Positive(p_now, q_now)
203 } else {
204 debug_assert_eq!(
205 p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
206 -Z::ONE.clone()
207 );
208 Solution::Negative(p_now, q_now)
209 }
210}
211
212pub fn solve_pell(d: Z) -> Solution {
225 let Ok(a) = continued_fraction_of_sqrt(d.clone()) else {
226 return Solution::NotExist;
227 };
228 solve_pell_aux(a, d)
229}
230
231pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
243 let a = continued_fraction_of_sqrt(d.clone()).ok()?;
244 if (a.len() - 1) % 2 == 0 {
245 return None;
246 }
247 let Solution::Negative(x, y) = solve_pell_aux(a, d) else {
248 unreachable!()
249 };
250 Some((x, y))
251}
252
253pub fn solve_pell_positive(d: Z) -> Option<(Z, Z)> {
265 match solve_pell(d.clone()) {
266 Solution::NotExist => None,
267 Solution::Positive(x, y) => Some((x, y)),
268 Solution::Negative(x, y) => {
269 let y2 = 2 * (&x * &y).complete();
270 let x2 = x.square() + y.square() * d;
271 Some((x2, y2))
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts small integers into the crate's big-integer type.
    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().map(|&x| Z::from(x)).collect()
    }

    /// Asserts that the expansion of `sqrt(d)` equals `expected`.
    fn assert_cf(d: i32, expected: &[i32]) {
        let got = continued_fraction_of_sqrt(Z::from(d)).unwrap();
        assert_eq!(got, to_z(expected));
    }

    #[test]
    fn test_continued_fraction_of_sqrt2() {
        assert_cf(2, &[1, 2]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt3() {
        assert_cf(3, &[1, 1, 2]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt5() {
        assert_cf(5, &[2, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt6() {
        assert_cf(6, &[2, 2, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt7() {
        assert_cf(7, &[2, 1, 1, 1, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt8() {
        assert_cf(8, &[2, 1, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt10() {
        assert_cf(10, &[3, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt11() {
        assert_cf(11, &[3, 3, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt12() {
        assert_cf(12, &[3, 2, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt13() {
        assert_cf(13, &[3, 1, 1, 1, 1, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt31() {
        assert_cf(31, &[5, 1, 1, 3, 5, 3, 1, 1, 10]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt94() {
        assert_cf(
            94,
            &[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18],
        );
    }
    #[test]
    fn test_continued_fraction_of_sqrt338() {
        assert_cf(338, &[18, 2, 1, 1, 2, 36]);
    }

    #[test]
    fn test_solve_pell() {
        assert_eq!(
            solve_pell(Z::from(653)),
            Solution::Negative(Z::from(2291286382u64), Z::from(89664965))
        );
    }
    #[test]
    fn test_solve_pell2() {
        assert_eq!(
            solve_pell(Z::from(115)),
            Solution::Positive(Z::from(1126), Z::from(105))
        );
    }
    #[test]
    fn test_solve_pell3() {
        assert_eq!(
            solve_pell(Z::from(114)),
            Solution::Positive(Z::from(1025), Z::from(96))
        );
    }
    #[test]
    fn test_solve_pell4() {
        assert_eq!(
            solve_pell(Z::from(641)),
            Solution::Negative(Z::from(36120833468u64), Z::from(1426687145))
        );
    }
    #[test]
    fn test_solve_pell5() {
        let parse = |s: &str| Z::from_str_radix(s, 10).unwrap();
        match solve_pell(Z::from(1021)) {
            Solution::Negative(x, y) => {
                assert_eq!(x, parse("315217280372584882515030"));
                assert_eq!(y, parse("9865001296666956406909"));
            }
            _ => panic!("not negative"),
        }
    }
}