1#![doc = include_str!("../README.md")]
2use rug::Complete;
5use std::ops::Mul;
/// Arbitrary-precision integer type used throughout this crate.
type Z = rug::Integer;
7
/// Errors reported by this crate's fallible functions, such as [`continued_fraction_of_sqrt`].
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The argument lies outside the function's domain
    /// (e.g. a negative radicand passed to [`continued_fraction_of_sqrt`]).
    #[error("Input value is out of domain")]
    OutOfDomain,
}
16
17fn continued_fraction_of_sqrt_small(d: i64) -> Vec<Z> {
18 let sd = d.isqrt();
19 let mut r = vec![Z::from(sd)];
20 if sd * sd == d {
21 return r;
22 }
23 let mut p = -sd;
24 let mut q = 1;
25 let norm = d - p * p;
26 debug_assert_eq!(norm % q, 0);
27 q = norm / q;
28 p = -p;
29 loop {
30 let flag = q == 1;
31 let v = (sd + p) / q;
32 p -= v * q;
33 let norm = d - p * p;
34 debug_assert_eq!(norm % q, 0);
35 q = norm / q;
36 p = -p;
37 r.push(Z::from(v));
38 if flag {
39 return r;
40 }
41 }
42}
43
44fn continued_fraction_of_sqrt_large(d: Z) -> Vec<Z> {
45 let sd = d.sqrt_ref().complete();
46 let mut r = vec![sd.clone()];
47 if sd.square_ref().complete() == d {
48 return r;
49 }
50 let mut p = -sd.clone();
51 let mut q = Z::ONE.clone();
52 let norm = &d - p.square_ref().complete();
53 debug_assert!(norm.is_divisible(&q));
54 q = norm.div_exact(&q);
55 p *= -1;
56 loop {
57 let flag = q == *Z::ONE;
58 let v = (&sd + &p).complete() / &q;
59 p -= &v * &q;
60 let norm = &d - p.square_ref().complete();
61 debug_assert!(norm.is_divisible(&q));
62 q = norm.div_exact(&q);
63 p *= -1;
64 r.push(v);
65 if flag {
66 return r;
67 }
68 }
69}
70
71pub fn continued_fraction_of_sqrt(d: Z) -> Result<Vec<Z>, Error> {
83 if d.is_negative() {
84 Err(Error::OutOfDomain)
85 } else if let Some(d) = d.to_i64() {
86 Ok(continued_fraction_of_sqrt_small(d))
87 } else {
88 Ok(continued_fraction_of_sqrt_large(d))
89 }
90}
91
/// Outcome of solving the Pell equation `x^2 - d*y^2 = ±1` with [`solve_pell`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Solution {
    /// `(x, y)` with `x^2 - d*y^2 = -1`; produced when the continued-fraction period of
    /// `sqrt(d)` has odd length.
    Negative(Z, Z),
    /// `(x, y)` with `x^2 - d*y^2 = 1`; produced when the period has even length.
    Positive(Z, Z),
    /// No nontrivial solution: `d` is negative or a perfect square.
    NotExist,
}
102
/// A 2x2 big-integer matrix `[[a, b], [c, d]]` (row-major), used to multiply
/// continued-fraction step matrices when building convergents.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Matrix2x2 {
    a: Z, // row 0, column 0
    b: Z, // row 0, column 1
    c: Z, // row 1, column 0
    d: Z, // row 1, column 1
}
110impl Matrix2x2 {
111 fn new(a: Z) -> Self {
112 Self {
113 a,
114 b: Z::ONE.clone(),
115 c: Z::ONE.clone(),
116 d: Z::ZERO,
117 }
118 }
119}
120impl num_traits::One for Matrix2x2 {
121 fn one() -> Self {
122 Self {
123 a: Z::ONE.clone(),
124 b: Z::ZERO,
125 c: Z::ZERO,
126 d: Z::ONE.clone(),
127 }
128 }
129}
130#[auto_impl_ops::auto_ops]
131impl std::ops::Mul<&Matrix2x2> for &Matrix2x2 {
132 type Output = Matrix2x2;
133 fn mul(self, rhs: &Matrix2x2) -> Self::Output {
134 let a = self.a.clone() * &rhs.a + &self.b * &rhs.c;
135 let b = self.a.clone() * &rhs.b + &self.b * &rhs.d;
136 let c = self.c.clone() * &rhs.a + &self.d * &rhs.c;
137 let d = self.c.clone() * &rhs.b + &self.d * &rhs.d;
138 Matrix2x2 { a, b, c, d }
139 }
140}
141fn tree_product(a: &[Z]) -> Matrix2x2 {
142 let v = a
143 .iter()
144 .map(|a| Matrix2x2::new(a.clone()))
145 .collect::<Vec<_>>();
146 let w = ring_algorithm::build_subproduct_tree::<Matrix2x2>(v);
147 w.into_iter().next().unwrap()
148}
149
150fn solve_pell_aux(mut a: Vec<Z>, d: Z) -> Solution {
151 let n = a.len() - 1;
152 if n == 0 {
153 return Solution::NotExist;
154 }
155 let (p_now, q_now) = if n > 8192 {
156 let m = tree_product(&a[1..n]);
157 let init = Matrix2x2 {
158 a: a[0].clone(),
159 b: Z::ONE.clone(),
160 c: Z::ONE.clone(),
161 d: Z::ZERO,
162 };
163 let Matrix2x2 { a, b, c: _, d: _ } = m * init;
164 (a, b)
165 } else {
166 let _ = a.pop();
167 let mut p_old = Z::ONE.clone();
168 let mut q_old = Z::ZERO;
169 let mut p_now = a[0].clone();
170 let mut q_now = Z::ONE.clone();
171 for ai in a.into_iter().skip(1) {
174 p_old += &ai * &p_now;
175 q_old += &ai * &q_now;
176 std::mem::swap(&mut p_old, &mut p_now);
177 std::mem::swap(&mut q_old, &mut q_now);
178 }
180 (p_now, q_now)
181 };
182 if n % 2 == 0 {
183 debug_assert_eq!(
184 p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
185 *Z::ONE
186 );
187 Solution::Positive(p_now, q_now)
188 } else {
189 debug_assert_eq!(
190 p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
191 -Z::ONE.clone()
192 );
193 Solution::Negative(p_now, q_now)
194 }
195}
196
197pub fn solve_pell(d: Z) -> Solution {
210 let Ok(a) = continued_fraction_of_sqrt(d.clone()) else {
211 return Solution::NotExist;
212 };
213 solve_pell_aux(a, d)
214}
215
216pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
228 let a = continued_fraction_of_sqrt(d.clone()).ok()?;
229 if (a.len() - 1) % 2 == 0 {
230 return None;
231 }
232 let Solution::Negative(x, y) = solve_pell_aux(a, d) else {
233 unreachable!()
234 };
235 Some((x, y))
236}
237
238pub fn solve_pell_positive(d: Z) -> Option<(Z, Z)> {
250 match solve_pell(d.clone()) {
251 Solution::NotExist => None,
252 Solution::Positive(x, y) => Some((x, y)),
253 Solution::Negative(x, y) => {
254 let y2 = 2 * (&x * &y).complete();
255 let x2 = x.square() + y.square() * d;
256 Some((x2, y2))
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().map(|x| Z::from(*x)).collect()
    }

    #[test]
    fn test_continued_fraction_of_sqrt2() {
        let v = continued_fraction_of_sqrt(Z::from(2)).unwrap();
        assert_eq!(v, to_z(&[1, 2]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt3() {
        let v = continued_fraction_of_sqrt(Z::from(3)).unwrap();
        assert_eq!(v, to_z(&[1, 1, 2]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt5() {
        let v = continued_fraction_of_sqrt(Z::from(5)).unwrap();
        assert_eq!(v, to_z(&[2, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt6() {
        let v = continued_fraction_of_sqrt(Z::from(6)).unwrap();
        assert_eq!(v, to_z(&[2, 2, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt7() {
        let v = continued_fraction_of_sqrt(Z::from(7)).unwrap();
        assert_eq!(v, to_z(&[2, 1, 1, 1, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt8() {
        let v = continued_fraction_of_sqrt(Z::from(8)).unwrap();
        assert_eq!(v, to_z(&[2, 1, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt10() {
        let v = continued_fraction_of_sqrt(Z::from(10)).unwrap();
        assert_eq!(v, to_z(&[3, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt11() {
        let v = continued_fraction_of_sqrt(Z::from(11)).unwrap();
        assert_eq!(v, to_z(&[3, 3, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt12() {
        let v = continued_fraction_of_sqrt(Z::from(12)).unwrap();
        assert_eq!(v, to_z(&[3, 2, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt13() {
        let v = continued_fraction_of_sqrt(Z::from(13)).unwrap();
        assert_eq!(v, to_z(&[3, 1, 1, 1, 1, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt31() {
        let v = continued_fraction_of_sqrt(Z::from(31)).unwrap();
        assert_eq!(v, to_z(&[5, 1, 1, 3, 5, 3, 1, 1, 10]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt94() {
        let v = continued_fraction_of_sqrt(Z::from(94)).unwrap();
        assert_eq!(
            v,
            to_z(&[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18])
        );
    }
    #[test]
    fn test_continued_fraction_of_sqrt338() {
        let v = continued_fraction_of_sqrt(Z::from(338)).unwrap();
        assert_eq!(v, to_z(&[18, 2, 1, 1, 2, 36]));
    }

    /// The big-integer implementation must agree with the i64 one wherever both apply.
    #[test]
    fn test_continued_fraction_large_matches_small() {
        for d in 0..500i64 {
            assert_eq!(
                continued_fraction_of_sqrt_large(Z::from(d)),
                continued_fraction_of_sqrt_small(d),
                "d = {d}"
            );
        }
    }

    /// Values above i64::MAX take the big-integer path. These checks use closed forms
    /// that are valid for every k >= 2.
    #[test]
    fn test_continued_fraction_of_sqrt_beyond_i64() {
        let k = Z::from(1u64 << 40);
        let k2 = k.clone().square();
        // sqrt(k^2 + 1) = [k; 2k]
        let v = continued_fraction_of_sqrt(k2.clone() + 1).unwrap();
        assert_eq!(v, vec![k.clone(), k.clone() * 2]);
        // sqrt(k^2 + 2) = [k; k, 2k]
        let v = continued_fraction_of_sqrt(k2.clone() + 2).unwrap();
        assert_eq!(v, vec![k.clone(), k.clone(), k.clone() * 2]);
        // sqrt(k^2 - 1) = [k - 1; 1, 2(k - 1)]
        let km1 = k.clone() - 1;
        let v = continued_fraction_of_sqrt(k2.clone() - 1).unwrap();
        assert_eq!(v, vec![km1.clone(), Z::from(1), km1 * 2]);
        // A perfect square has no periodic part.
        assert_eq!(continued_fraction_of_sqrt(k2).unwrap(), vec![k]);
    }

    #[test]
    fn test_perfect_squares() {
        for r in 0..20i64 {
            let v = continued_fraction_of_sqrt(Z::from(r * r)).unwrap();
            assert_eq!(v, vec![Z::from(r)]);
            assert_eq!(solve_pell(Z::from(r * r)), Solution::NotExist);
        }
    }

    #[test]
    fn test_out_of_domain() {
        assert!(matches!(
            continued_fraction_of_sqrt(Z::from(-2)),
            Err(Error::OutOfDomain)
        ));
        assert_eq!(solve_pell(Z::from(-2)), Solution::NotExist);
        assert_eq!(solve_pell_negative(Z::from(-2)), None);
        assert_eq!(solve_pell_positive(Z::from(-2)), None);
    }

    #[test]
    fn test_solve_pell() {
        let v = solve_pell(Z::from(653));
        assert_eq!(
            v,
            Solution::Negative(Z::from(2291286382u64), Z::from(89664965))
        );
    }
    #[test]
    fn test_solve_pell2() {
        let v = solve_pell(Z::from(115));
        assert_eq!(v, Solution::Positive(Z::from(1126), Z::from(105)));
    }
    #[test]
    fn test_solve_pell3() {
        let v = solve_pell(Z::from(114));
        assert_eq!(v, Solution::Positive(Z::from(1025), Z::from(96)));
    }
    #[test]
    fn test_solve_pell4() {
        let v = solve_pell(Z::from(641));
        assert_eq!(
            v,
            Solution::Negative(Z::from(36120833468u64), Z::from(1426687145))
        );
    }
    #[test]
    fn test_solve_pell5() {
        let Solution::Negative(x, y) = solve_pell(Z::from(1021)) else {
            panic!("not negative")
        };
        assert_eq!(
            x,
            Z::from_str_radix("315217280372584882515030", 10).unwrap()
        );
        assert_eq!(y, Z::from_str_radix("9865001296666956406909", 10).unwrap());
    }

    /// Every returned pair must satisfy its equation, and only perfect squares may yield
    /// no solution.
    #[test]
    fn test_solve_pell_satisfies_equation() {
        for d in 0..300i64 {
            let dz = Z::from(d);
            let norm =
                |x: &Z, y: &Z| x.square_ref().complete() - y.square_ref().complete() * &dz;
            match solve_pell(dz.clone()) {
                Solution::Positive(x, y) => assert_eq!(norm(&x, &y), Z::from(1), "d = {d}"),
                Solution::Negative(x, y) => assert_eq!(norm(&x, &y), Z::from(-1), "d = {d}"),
                Solution::NotExist => assert!(dz.is_perfect_square(), "d = {d}"),
            }
            match solve_pell_positive(dz.clone()) {
                Some((x, y)) => assert_eq!(norm(&x, &y), Z::from(1), "d = {d}"),
                None => assert!(dz.is_perfect_square(), "d = {d}"),
            }
            if let Some((x, y)) = solve_pell_negative(dz.clone()) {
                assert_eq!(norm(&x, &y), Z::from(-1), "d = {d}");
            }
        }
    }

    #[test]
    fn test_solve_pell_negative_and_positive() {
        assert_eq!(solve_pell_negative(Z::from(115)), None);
        assert_eq!(
            solve_pell_negative(Z::from(653)),
            Some((Z::from(2291286382u64), Z::from(89664965)))
        );
        assert_eq!(
            solve_pell_positive(Z::from(115)),
            Some((Z::from(1126), Z::from(105)))
        );
        // The negative solution (1, 1) for d = 2 squares to (3, 2).
        assert_eq!(
            solve_pell_positive(Z::from(2)),
            Some((Z::from(3), Z::from(2)))
        );
    }

    /// Pell solution on the big-integer path: d = k^2 + 1 has the negative solution (k, 1),
    /// and its square is (2k^2 + 1, 2k).
    #[test]
    fn test_solve_pell_beyond_i64() {
        let k = Z::from(1u64 << 40);
        let d = k.clone().square() + 1;
        assert_eq!(
            solve_pell(d.clone()),
            Solution::Negative(k.clone(), Z::from(1))
        );
        let expected_x = k.clone().square() * 2 + 1;
        assert_eq!(solve_pell_positive(d), Some((expected_x, k * 2)));
    }

    /// The product-tree branch of `solve_pell_aux` is only taken for periods longer than
    /// 8192. This test exercises the same formula on short periods and compares it with
    /// the iterative recurrence used by `solve_pell`.
    #[test]
    fn test_tree_product_matches_recurrence() {
        for d in [7, 13, 31, 94, 338, 641, 653, 1021] {
            let a = continued_fraction_of_sqrt(Z::from(d)).unwrap();
            let n = a.len() - 1;
            let m = tree_product(&a[1..n]) * Matrix2x2::new(a[0].clone());
            let (x, y) = match solve_pell(Z::from(d)) {
                Solution::Positive(x, y) | Solution::Negative(x, y) => (x, y),
                Solution::NotExist => panic!("d = {d} unexpectedly has no solution"),
            };
            assert_eq!((m.a, m.b), (x, y), "d = {d}");
        }
    }
}