pell_equation/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use rug::Complete;
4type Q = rug::Rational;
5type Z = rug::Integer;
6
7/// Calculate continued fraction of √d
8///
9/// Calculate [simple continued fraction](https://en.wikipedia.org/wiki/Simple_continued_fraction)
10/// of √d.  
11/// ex : √2 = [1; 2, 2, 2, ...]
12/// ```
13/// use rug::Integer;
14/// let v = pell_equation::continued_fraction_of_sqrt(Integer::from(2));
15/// assert_eq!(v, vec![Integer::from(1), Integer::from(2)]);
16/// ```
17pub fn continued_fraction_of_sqrt(d: Z) -> Vec<Z> {
18    let sd = Z::from(d.sqrt_ref());
19    let mut r = vec![sd.clone()];
20    if Z::from(sd.square_ref()) == d {
21        return r;
22    }
23    let mut p = -sd.clone();
24    let mut q = Z::ONE.clone();
25    let norm = &d - Z::from(p.square_ref());
26    debug_assert!(norm.is_divisible(&q));
27    q = norm.div_exact(&q);
28    p *= -1;
29    loop {
30        let flag = q == *Z::ONE;
31        let t = Q::from(&sd + &p) / &q;
32        let v = t.floor().numer().clone();
33        p -= &v * &q;
34        let norm = &d - Z::from(p.square_ref());
35        debug_assert!(norm.is_divisible(&q));
36        q = norm.div_exact(&q);
37        p *= -1;
38        r.push(v);
39        if flag {
40            return r;
41        }
42    }
43}
44
45/// Fundamental solution of `x^2 - d*y^2 = ±1`
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum Solution {
48    /// Fundamental solution of `x^2 - d*y^2 = -1`
49    Negative(Z, Z),
50    /// Fundamental solution of `x^2 - d*y^2 = 1`
51    Positive(Z, Z),
52}
53
54fn solve_pell_aux(mut a: Vec<Z>, d: Z) -> Solution {
55    let n = a.len() - 1;
56    let _ = a.pop();
57    let mut p_old = Z::ONE.clone();
58    let mut q_old = Z::ZERO;
59    let mut p_now = a[0].clone();
60    let mut q_now = Z::ONE.clone();
61    // println!("{p_old} {q_old}");
62    // println!("{p_now} {q_now}");
63    for ai in a.into_iter().skip(1) {
64        p_old += &ai * &p_now;
65        q_old += &ai * &q_now;
66        std::mem::swap(&mut p_old, &mut p_now);
67        std::mem::swap(&mut q_old, &mut q_now);
68        // println!("{p_now} {q_now}");
69    }
70    if n % 2 == 0 {
71        debug_assert_eq!(
72            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
73            *Z::ONE
74        );
75        Solution::Positive(p_now, q_now)
76    } else {
77        debug_assert_eq!(
78            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
79            -Z::ONE.clone()
80        );
81        Solution::Negative(p_now, q_now)
82    }
83}
84
85/// Calculate fundamental solution of `x^2 - d*y^2 = ±1`
86///
87/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
88/// Otherwise returns fundamental solution of `x^2 - d*y^2 = 1`.
89/// ```
90/// use rug::Integer;
91/// let v = pell_equation::solve_pell(Integer::from(2));
92/// assert_eq!(v, pell_equation::Solution::Negative(Integer::from(1), Integer::from(1)));
93/// let w = pell_equation::solve_pell(Integer::from(3));
94/// assert_eq!(w, pell_equation::Solution::Positive(Integer::from(2), Integer::from(1)));
95/// ```
96pub fn solve_pell(d: Z) -> Solution {
97    let a = continued_fraction_of_sqrt(d.clone());
98    solve_pell_aux(a, d)
99}
100
101/// Calculate fundamental solution of `x^2 - d*y^2 = -1`
102///
103/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
104/// Otherwise returns None.
105/// ```
106/// use rug::Integer;
107/// let v = pell_equation::solve_pell_negative(Integer::from(2));
108/// assert_eq!(v, Some((Integer::from(1), Integer::from(1))));
109/// let w = pell_equation::solve_pell_negative(Integer::from(3));
110/// assert_eq!(w, None);
111/// ```
112pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
113    let a = continued_fraction_of_sqrt(d.clone());
114    if (a.len() - 1) % 2 == 0 {
115        return None;
116    }
117    let Solution::Negative(x, y) = solve_pell_aux(a, d) else {
118        unreachable!()
119    };
120    Some((x, y))
121}
122
123/// Calculate fundamental solution of `x^2 - d*y^2 = 1`
124///
125/// This function returns `x^2 - d*y^2 = 1` fundamental solution.  
126/// ```
127/// use rug::Integer;
128/// let v = pell_equation::solve_pell_positive(Integer::from(2));
129/// assert_eq!(v, (Integer::from(3), Integer::from(2)));
130/// let w = pell_equation::solve_pell_positive(Integer::from(3));
131/// assert_eq!(w, (Integer::from(2), Integer::from(1)));
132/// ```
133pub fn solve_pell_positive(d: Z) -> (Z, Z) {
134    match solve_pell(d.clone()) {
135        Solution::Positive(x, y) => (x, y),
136        Solution::Negative(x, y) => {
137            let y2 = 2 * (&x * &y).complete();
138            let x2 = x.square() + y.square() * d;
139            (x2, y2)
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Lifts a slice of small literals into big integers.
    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().copied().map(Z::from).collect()
    }

    /// Asserts that the continued fraction of √d is exactly `expected`.
    fn assert_expansion(d: i32, expected: &[i32]) {
        assert_eq!(continued_fraction_of_sqrt(Z::from(d)), to_z(expected));
    }

    /// Parses a decimal literal that does not fit a machine integer.
    fn big(decimal: &str) -> Z {
        Z::from_str_radix(decimal, 10).unwrap()
    }

    // Expected expansions are taken from
    // https://planetmath.org/tableofcontinuedfractionsofsqrtnfor1n102
    #[test]
    fn test_continued_fraction_of_sqrt2() {
        assert_expansion(2, &[1, 2]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt3() {
        assert_expansion(3, &[1, 1, 2]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt5() {
        assert_expansion(5, &[2, 4]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt6() {
        assert_expansion(6, &[2, 2, 4]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt7() {
        assert_expansion(7, &[2, 1, 1, 1, 4]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt8() {
        assert_expansion(8, &[2, 1, 4]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt10() {
        assert_expansion(10, &[3, 6]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt11() {
        assert_expansion(11, &[3, 3, 6]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt12() {
        assert_expansion(12, &[3, 2, 6]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt13() {
        assert_expansion(13, &[3, 1, 1, 1, 1, 6]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt31() {
        assert_expansion(31, &[5, 1, 1, 3, 5, 3, 1, 1, 10]);
    }

    #[test]
    fn test_continued_fraction_of_sqrt94() {
        assert_expansion(
            94,
            &[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18],
        );
    }

    #[test]
    fn test_continued_fraction_of_sqrt338() {
        assert_expansion(338, &[18, 2, 1, 1, 2, 36]);
    }

    #[test]
    fn test_solve_pell() {
        let expected = Solution::Negative(Z::from(2291286382u64), Z::from(89664965));
        assert_eq!(solve_pell(Z::from(653)), expected);
    }

    #[test]
    fn test_solve_pell2() {
        let expected = Solution::Positive(Z::from(1126), Z::from(105));
        assert_eq!(solve_pell(Z::from(115)), expected);
    }

    #[test]
    fn test_solve_pell3() {
        let expected = Solution::Positive(Z::from(1025), Z::from(96));
        assert_eq!(solve_pell(Z::from(114)), expected);
    }

    #[test]
    fn test_solve_pell4() {
        let expected = Solution::Negative(Z::from(36120833468u64), Z::from(1426687145));
        assert_eq!(solve_pell(Z::from(641)), expected);
    }

    #[test]
    fn test_solve_pell5() {
        match solve_pell(Z::from(1021)) {
            Solution::Negative(x, y) => {
                assert_eq!(x, big("315217280372584882515030"));
                assert_eq!(y, big("9865001296666956406909"));
            }
            Solution::Positive(..) => panic!("unexpected positive"),
        }
    }
}