pell_equation/
lib.rs

1#![doc = include_str!("../README.md")]
2// #![warn(clippy::pedantic)]
3
4use rug::Complete;
5use std::ops::Mul;
6type Z = rug::Integer;
7
8#[non_exhaustive]
9#[derive(Debug, thiserror::Error)]
10/// Error type
11pub enum Error {
12    /// Input value is out of domain
13    #[error("Input value is out of domain")]
14    OutOfDomain,
15}
16
17fn continued_fraction_of_sqrt_small(d: i64) -> Vec<Z> {
18    let sd = d.isqrt();
19    let mut r = vec![Z::from(sd)];
20    if sd * sd == d {
21        return r;
22    }
23    let mut p = -sd;
24    let mut q = 1;
25    let norm = d - p * p;
26    debug_assert_eq!(norm % q, 0);
27    q = norm / q;
28    p = -p;
29    loop {
30        let flag = q == 1;
31        let v = (sd + p) / q;
32        p -= v * q;
33        let norm = d - p * p;
34        debug_assert_eq!(norm % q, 0);
35        q = norm / q;
36        p = -p;
37        r.push(Z::from(v));
38        if flag {
39            return r;
40        }
41    }
42}
43
44fn continued_fraction_of_sqrt_large(d: Z) -> Vec<Z> {
45    let sd = d.sqrt_ref().complete();
46    let mut r = vec![sd.clone()];
47    if sd.square_ref().complete() == d {
48        return r;
49    }
50    let mut p = -sd.clone();
51    let mut q = Z::ONE.clone();
52    let norm = &d - p.square_ref().complete();
53    debug_assert!(norm.is_divisible(&q));
54    q = norm.div_exact(&q);
55    p *= -1;
56    loop {
57        let flag = q == *Z::ONE;
58        let v = (&sd + &p).complete() / &q;
59        p -= &v * &q;
60        let norm = &d - p.square_ref().complete();
61        debug_assert!(norm.is_divisible(&q));
62        q = norm.div_exact(&q);
63        p *= -1;
64        r.push(v);
65        if flag {
66            return r;
67        }
68    }
69}
70
71/// Calculate continued fraction of √d
72///
73/// Calculate [simple continued fraction](https://en.wikipedia.org/wiki/Simple_continued_fraction)
74/// of √d.  
75/// If d is negative returns `Err(Error::OutOfDomain)`.  
76/// ex : √2 = [1; 2, 2, 2, ...]
77/// ```
78/// use rug::Integer;
79/// let v = pell_equation::continued_fraction_of_sqrt(Integer::from(2)).unwrap();
80/// assert_eq!(v, vec![Integer::from(1), Integer::from(2)]);
81/// ```
82pub fn continued_fraction_of_sqrt(d: Z) -> Result<Vec<Z>, Error> {
83    if d.is_negative() {
84        Err(Error::OutOfDomain)
85    } else if let Some(d) = d.to_i64() {
86        Ok(continued_fraction_of_sqrt_small(d))
87    } else {
88        Ok(continued_fraction_of_sqrt_large(d))
89    }
90}
91
92/// Fundamental solution of `x^2 - d*y^2 = ±1`
93#[derive(Debug, Clone, PartialEq, Eq)]
94pub enum Solution {
95    /// Fundamental solution of `x^2 - d*y^2 = -1`
96    Negative(Z, Z),
97    /// Fundamental solution of `x^2 - d*y^2 = 1`
98    Positive(Z, Z),
99    /// Not exsist nontirivial solution
100    NotExist,
101}
102
103#[derive(Debug, Clone, PartialEq, Eq)]
104struct Matrix2x2 {
105    a: Z,
106    b: Z,
107    c: Z,
108    d: Z,
109}
110impl Matrix2x2 {
111    fn new(a: Z) -> Self {
112        Self {
113            a,
114            b: Z::ONE.clone(),
115            c: Z::ONE.clone(),
116            d: Z::ZERO,
117        }
118    }
119}
120impl num_traits::One for Matrix2x2 {
121    fn one() -> Self {
122        Self {
123            a: Z::ONE.clone(),
124            b: Z::ZERO,
125            c: Z::ZERO,
126            d: Z::ONE.clone(),
127        }
128    }
129}
130#[auto_impl_ops::auto_ops]
131impl std::ops::Mul<&Matrix2x2> for &Matrix2x2 {
132    type Output = Matrix2x2;
133    fn mul(self, rhs: &Matrix2x2) -> Self::Output {
134        let a = self.a.clone() * &rhs.a + &self.b * &rhs.c;
135        let b = self.a.clone() * &rhs.b + &self.b * &rhs.d;
136        let c = self.c.clone() * &rhs.a + &self.d * &rhs.c;
137        let d = self.c.clone() * &rhs.b + &self.d * &rhs.d;
138        Matrix2x2 { a, b, c, d }
139    }
140}
141fn tree_product(a: &[Z]) -> Matrix2x2 {
142    let v = a
143        .iter()
144        .map(|a| Matrix2x2::new(a.clone()))
145        .collect::<Vec<_>>();
146    let w = ring_algorithm::build_subproduct_tree::<Matrix2x2>(v);
147    w.into_iter().next().unwrap()
148}
149
150fn solve_pell_aux(mut a: Vec<Z>, d: Z) -> Solution {
151    let n = a.len() - 1;
152    if n == 0 {
153        return Solution::NotExist;
154    }
155    let (p_now, q_now) = if n > 8192 {
156        let m = tree_product(&a[1..n]);
157        let init = Matrix2x2 {
158            a: a[0].clone(),
159            b: Z::ONE.clone(),
160            c: Z::ONE.clone(),
161            d: Z::ZERO,
162        };
163        let Matrix2x2 { a, b, c: _, d: _ } = m * init;
164        (a, b)
165    } else {
166        let _ = a.pop();
167        let mut p_old = Z::ONE.clone();
168        let mut q_old = Z::ZERO;
169        let mut p_now = a[0].clone();
170        let mut q_now = Z::ONE.clone();
171        // println!("{p_old} {q_old}");
172        // println!("{p_now} {q_now}");
173        for ai in a.into_iter().skip(1) {
174            p_old += &ai * &p_now;
175            q_old += &ai * &q_now;
176            std::mem::swap(&mut p_old, &mut p_now);
177            std::mem::swap(&mut q_old, &mut q_now);
178            // println!("{p_now} {q_now}");
179        }
180        (p_now, q_now)
181    };
182    if n % 2 == 0 {
183        debug_assert_eq!(
184            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
185            *Z::ONE
186        );
187        Solution::Positive(p_now, q_now)
188    } else {
189        debug_assert_eq!(
190            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
191            -Z::ONE.clone()
192        );
193        Solution::Negative(p_now, q_now)
194    }
195}
196
197/// Calculate fundamental solution of `x^2 - d*y^2 = ±1`
198///
199/// If d is negative or perfect square returns `NotExist`.  
200/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
201/// Otherwise returns fundamental solution of `x^2 - d*y^2 = 1`.
202/// ```
203/// use rug::Integer;
204/// let v = pell_equation::solve_pell(Integer::from(2));
205/// assert_eq!(v, pell_equation::Solution::Negative(Integer::from(1), Integer::from(1)));
206/// let w = pell_equation::solve_pell(Integer::from(3));
207/// assert_eq!(w, pell_equation::Solution::Positive(Integer::from(2), Integer::from(1)));
208/// ```
209pub fn solve_pell(d: Z) -> Solution {
210    let Ok(a) = continued_fraction_of_sqrt(d.clone()) else {
211        return Solution::NotExist;
212    };
213    solve_pell_aux(a, d)
214}
215
216/// Calculate fundamental solution of `x^2 - d*y^2 = -1`
217///
218/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
219/// Otherwise returns `None`.
220/// ```
221/// use rug::Integer;
222/// let v = pell_equation::solve_pell_negative(Integer::from(2));
223/// assert_eq!(v, Some((Integer::from(1), Integer::from(1))));
224/// let w = pell_equation::solve_pell_negative(Integer::from(3));
225/// assert_eq!(w, None);
226/// ```
227pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
228    let a = continued_fraction_of_sqrt(d.clone()).ok()?;
229    if (a.len() - 1) % 2 == 0 {
230        return None;
231    }
232    let Solution::Negative(x, y) = solve_pell_aux(a, d) else {
233        unreachable!()
234    };
235    Some((x, y))
236}
237
238/// Calculate fundamental solution of `x^2 - d*y^2 = 1`
239///
240/// If `x^2 - d*y^2 = 1` has nontrivial solution, returns its fundamental solution.  
241/// Otherwise returns `None`.
242/// ```
243/// use rug::Integer;
244/// let v = pell_equation::solve_pell_positive(Integer::from(2));
245/// assert_eq!(v, Some((Integer::from(3), Integer::from(2))));
246/// let w = pell_equation::solve_pell_positive(Integer::from(3));
247/// assert_eq!(w, Some((Integer::from(2), Integer::from(1))));
248/// ```
249pub fn solve_pell_positive(d: Z) -> Option<(Z, Z)> {
250    match solve_pell(d.clone()) {
251        Solution::NotExist => None,
252        Solution::Positive(x, y) => Some((x, y)),
253        Solution::Negative(x, y) => {
254            let y2 = 2 * (&x * &y).complete();
255            let x2 = x.square() + y.square() * d;
256            Some((x2, y2))
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().map(|x| Z::from(*x)).collect()
    }
    /// `x^2 - d*y^2`, used to check that a returned pair really solves the equation.
    fn pell_norm(x: &Z, y: &Z, d: &Z) -> Z {
        x.clone().square() - y.clone().square() * d
    }
    // https://planetmath.org/tableofcontinuedfractionsofsqrtnfor1n102
    #[test]
    fn test_continued_fraction_of_sqrt2() {
        let v = continued_fraction_of_sqrt(Z::from(2)).unwrap();
        assert_eq!(v, to_z(&[1, 2]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt3() {
        let v = continued_fraction_of_sqrt(Z::from(3)).unwrap();
        assert_eq!(v, to_z(&[1, 1, 2]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt5() {
        let v = continued_fraction_of_sqrt(Z::from(5)).unwrap();
        assert_eq!(v, to_z(&[2, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt6() {
        let v = continued_fraction_of_sqrt(Z::from(6)).unwrap();
        assert_eq!(v, to_z(&[2, 2, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt7() {
        let v = continued_fraction_of_sqrt(Z::from(7)).unwrap();
        assert_eq!(v, to_z(&[2, 1, 1, 1, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt8() {
        let v = continued_fraction_of_sqrt(Z::from(8)).unwrap();
        assert_eq!(v, to_z(&[2, 1, 4]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt10() {
        let v = continued_fraction_of_sqrt(Z::from(10)).unwrap();
        assert_eq!(v, to_z(&[3, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt11() {
        let v = continued_fraction_of_sqrt(Z::from(11)).unwrap();
        assert_eq!(v, to_z(&[3, 3, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt12() {
        let v = continued_fraction_of_sqrt(Z::from(12)).unwrap();
        assert_eq!(v, to_z(&[3, 2, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt13() {
        let v = continued_fraction_of_sqrt(Z::from(13)).unwrap();
        assert_eq!(v, to_z(&[3, 1, 1, 1, 1, 6]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt31() {
        let v = continued_fraction_of_sqrt(Z::from(31)).unwrap();
        assert_eq!(v, to_z(&[5, 1, 1, 3, 5, 3, 1, 1, 10]));
    }
    #[test]
    fn test_continued_fraction_of_sqrt94() {
        let v = continued_fraction_of_sqrt(Z::from(94)).unwrap();
        assert_eq!(
            v,
            to_z(&[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18])
        );
    }
    #[test]
    fn test_continued_fraction_of_sqrt338() {
        let v = continued_fraction_of_sqrt(Z::from(338)).unwrap();
        assert_eq!(v, to_z(&[18, 2, 1, 1, 2, 36]));
    }
    #[test]
    fn test_continued_fraction_of_negative_is_out_of_domain() {
        assert!(matches!(
            continued_fraction_of_sqrt(Z::from(-1)),
            Err(Error::OutOfDomain)
        ));
    }
    #[test]
    fn test_continued_fraction_of_perfect_square() {
        assert_eq!(continued_fraction_of_sqrt(Z::from(0)).unwrap(), to_z(&[0]));
        assert_eq!(continued_fraction_of_sqrt(Z::from(1)).unwrap(), to_z(&[1]));
        assert_eq!(continued_fraction_of_sqrt(Z::from(49)).unwrap(), to_z(&[7]));
    }
    #[test]
    fn test_small_and_large_agree() {
        for d in 0..500i64 {
            assert_eq!(
                continued_fraction_of_sqrt_small(d),
                continued_fraction_of_sqrt_large(Z::from(d)),
                "d = {d}"
            );
        }
    }
    #[test]
    fn test_continued_fraction_beyond_i64() {
        // d = n^2 + 1 gives [n; 2n] and d = n^2 + 2 gives [n; n, 2n]; n = 2^35 puts d past i64.
        let n = Z::from(1u64 << 35);
        let two_n = Z::from(1u64 << 36);
        let v = continued_fraction_of_sqrt(Z::from((1u128 << 70) + 1)).unwrap();
        assert_eq!(v, vec![n.clone(), two_n.clone()]);
        let w = continued_fraction_of_sqrt(Z::from((1u128 << 70) + 2)).unwrap();
        assert_eq!(w, vec![n.clone(), n, two_n]);
    }
    #[test]
    fn test_solve_pell() {
        let v = solve_pell(Z::from(653));
        assert_eq!(
            v,
            Solution::Negative(Z::from(2291286382u64), Z::from(89664965))
        );
    }
    #[test]
    fn test_solve_pell2() {
        let v = solve_pell(Z::from(115));
        assert_eq!(v, Solution::Positive(Z::from(1126), Z::from(105)));
    }
    #[test]
    fn test_solve_pell3() {
        let v = solve_pell(Z::from(114));
        assert_eq!(v, Solution::Positive(Z::from(1025), Z::from(96)));
    }
    #[test]
    fn test_solve_pell4() {
        let v = solve_pell(Z::from(641));
        assert_eq!(
            v,
            Solution::Negative(Z::from(36120833468u64), Z::from(1426687145))
        );
    }
    #[test]
    fn test_solve_pell5() {
        let Solution::Negative(x, y) = solve_pell(Z::from(1021)) else {
            panic!("not negative")
        };
        assert_eq!(
            x,
            Z::from_str_radix("315217280372584882515030", 10).unwrap()
        );
        assert_eq!(y, Z::from_str_radix("9865001296666956406909", 10).unwrap());
    }
    #[test]
    fn test_solve_pell_no_solution() {
        assert_eq!(solve_pell(Z::from(-7)), Solution::NotExist);
        assert_eq!(solve_pell(Z::from(0)), Solution::NotExist);
        assert_eq!(solve_pell(Z::from(25)), Solution::NotExist);
        assert_eq!(solve_pell_negative(Z::from(-2)), None);
        assert_eq!(solve_pell_negative(Z::from(25)), None);
        assert_eq!(solve_pell_positive(Z::from(-2)), None);
        assert_eq!(solve_pell_positive(Z::from(25)), None);
    }
    #[test]
    fn test_solve_pell_beyond_i64() {
        let n = Z::from(1u64 << 35);
        // d = n^2 + 1 has period 1, so (n, 1) solves the negative equation.
        assert_eq!(
            solve_pell(Z::from((1u128 << 70) + 1)),
            Solution::Negative(n.clone(), Z::from(1))
        );
        // d = n^2 + 2 has period 2, so (n^2 + 1, n) solves the positive equation.
        assert_eq!(
            solve_pell(Z::from((1u128 << 70) + 2)),
            Solution::Positive(Z::from((1u128 << 70) + 1), n)
        );
    }
    #[test]
    fn test_solve_pell_negative_and_positive() {
        assert_eq!(
            solve_pell_negative(Z::from(653)),
            Some((Z::from(2291286382u64), Z::from(89664965)))
        );
        assert_eq!(solve_pell_negative(Z::from(115)), None);
        assert_eq!(
            solve_pell_positive(Z::from(115)),
            Some((Z::from(1126), Z::from(105)))
        );
        // 653 only yields a negative fundamental solution; the positive one is its square.
        let d = Z::from(653);
        let (x, y) = solve_pell_positive(d.clone()).unwrap();
        assert_eq!(pell_norm(&x, &y, &d), Z::from(1));
    }
    #[test]
    fn test_solve_pell_long_period_uses_tree_product() {
        // Periods longer than 8192 take the `tree_product` branch of `solve_pell_aux`.
        // Values of d near 10^9 typically have periods in the tens of thousands.
        let (d, expansion) = (1_000_000_000i64..1_000_001_000)
            .map(|d| (d, continued_fraction_of_sqrt_small(d)))
            .find(|(_, expansion)| expansion.len() - 1 > 8192)
            .expect("no d with period > 8192 in the search range");
        let period = expansion.len() - 1;
        let d = Z::from(d);
        match solve_pell(d.clone()) {
            Solution::Negative(x, y) => {
                assert_eq!(period % 2, 1);
                assert_eq!(pell_norm(&x, &y, &d), Z::from(-1));
            }
            Solution::Positive(x, y) => {
                assert_eq!(period % 2, 0);
                assert_eq!(pell_norm(&x, &y, &d), Z::from(1));
            }
            Solution::NotExist => panic!("non-square d must have a solution"),
        }
    }
}