pell_equation/
lib.rs

1#![doc = include_str!("../README.md")]
2// #![warn(clippy::pedantic)]
3
4use num_traits::One;
5use rug::Complete;
6use std::ops::Mul;
/// Arbitrary-precision integer type used throughout this crate (alias for [`rug::Integer`]).
type Z = rug::Integer;
8
9#[non_exhaustive]
10#[derive(Debug, thiserror::Error)]
11/// Error type
12pub enum Error {
13    /// Input value is out of domain
14    #[error("Input value is out of domain")]
15    OutOfDomain,
16}
17
18fn continued_fraction_of_sqrt_small(d: i64) -> Vec<Z> {
19    let sd = d.isqrt();
20    let mut r = vec![Z::from(sd)];
21    if sd * sd == d {
22        return r;
23    }
24    let mut p = -sd;
25    let mut q = 1;
26    let norm = d - p * p;
27    debug_assert_eq!(norm % q, 0);
28    q = norm / q;
29    p = -p;
30    loop {
31        let flag = q == 1;
32        let v = (sd + p) / q;
33        p -= v * q;
34        let norm = d - p * p;
35        debug_assert_eq!(norm % q, 0);
36        q = norm / q;
37        p = -p;
38        r.push(Z::from(v));
39        if flag {
40            return r;
41        }
42    }
43}
44
45fn continued_fraction_of_sqrt_large(d: Z) -> Vec<Z> {
46    let sd = d.sqrt_ref().complete();
47    let mut r = vec![sd.clone()];
48    if sd.square_ref().complete() == d {
49        return r;
50    }
51    let mut p = -sd.clone();
52    let mut q = Z::ONE.clone();
53    let norm = &d - p.square_ref().complete();
54    debug_assert!(norm.is_divisible(&q));
55    q = norm.div_exact(&q);
56    p *= -1;
57    loop {
58        let flag = q == *Z::ONE;
59        let v = (&sd + &p).complete() / &q;
60        p -= &v * &q;
61        let norm = &d - p.square_ref().complete();
62        debug_assert!(norm.is_divisible(&q));
63        q = norm.div_exact(&q);
64        p *= -1;
65        r.push(v);
66        if flag {
67            return r;
68        }
69    }
70}
71
72/// Calculate continued fraction of √d
73///
74/// Calculate [simple continued fraction](https://en.wikipedia.org/wiki/Simple_continued_fraction)
75/// of √d.  
76/// If d is negative returns `Err(Error::OutOfDomain)`.  
77/// ex : √2 = [1; 2, 2, 2, ...]
78/// ```
79/// use rug::Integer;
80/// let v = pell_equation::continued_fraction_of_sqrt(Integer::from(2)).unwrap();
81/// assert_eq!(v, vec![Integer::from(1), Integer::from(2)]);
82/// ```
83pub fn continued_fraction_of_sqrt(d: Z) -> Result<Vec<Z>, Error> {
84    if d.is_negative() {
85        Err(Error::OutOfDomain)
86    } else if let Some(d) = d.to_i64() {
87        Ok(continued_fraction_of_sqrt_small(d))
88    } else {
89        Ok(continued_fraction_of_sqrt_large(d))
90    }
91}
92
93/// Fundamental solution of `x^2 - d*y^2 = ±1`
94#[derive(Debug, Clone, PartialEq, Eq)]
95pub enum Solution {
96    /// Fundamental solution of `x^2 - d*y^2 = -1`
97    Negative(Z, Z),
98    /// Fundamental solution of `x^2 - d*y^2 = 1`
99    Positive(Z, Z),
100    /// Not exsist nontirivial solution
101    NotExist,
102}
103
104#[derive(Debug, Clone, PartialEq, Eq)]
105struct Matrix2x2 {
106    a: Z,
107    b: Z,
108    c: Z,
109    d: Z,
110}
111impl Matrix2x2 {
112    fn new(a: Z) -> Self {
113        Self {
114            a,
115            b: Z::ONE.clone(),
116            c: Z::ONE.clone(),
117            d: Z::ZERO,
118        }
119    }
120}
121impl num_traits::One for Matrix2x2 {
122    fn one() -> Self {
123        Self {
124            a: Z::ONE.clone(),
125            b: Z::ZERO,
126            c: Z::ZERO,
127            d: Z::ONE.clone(),
128        }
129    }
130}
131#[auto_impl_ops::auto_ops]
132impl std::ops::Mul<&Matrix2x2> for &Matrix2x2 {
133    type Output = Matrix2x2;
134    fn mul(self, rhs: &Matrix2x2) -> Self::Output {
135        let a = self.a.clone() * &rhs.a + &self.b * &rhs.c;
136        let b = self.a.clone() * &rhs.b + &self.b * &rhs.d;
137        let c = self.c.clone() * &rhs.a + &self.d * &rhs.c;
138        let d = self.c.clone() * &rhs.b + &self.d * &rhs.d;
139        Matrix2x2 { a, b, c, d }
140    }
141}
142fn tree_product(a: &[Z]) -> Matrix2x2 {
143    let n = (a.len().ilog2() + 1) as usize;
144    let mut v = vec![Matrix2x2::one(); n];
145    for (i, a) in a.iter().rev().enumerate() {
146        let a = Matrix2x2::new(a.clone());
147        v[0] *= a;
148        let mut i = i + 1;
149        let mut j = 0;
150        while i % 2 == 0 {
151            let mut t = Matrix2x2::one();
152            std::mem::swap(&mut t, &mut v[j]);
153            v[j + 1] *= t;
154            i >>= 1;
155            j += 1;
156        }
157    }
158    let mut t = Matrix2x2::one();
159    for i in (0..n).rev() {
160        t *= &v[i];
161    }
162    t
163}
164
165fn solve_pell_aux(mut a: Vec<Z>, d: Z) -> Solution {
166    let n = a.len() - 1;
167    if n == 0 {
168        return Solution::NotExist;
169    }
170    let (p_now, q_now) = if n > 8192 {
171        let m = tree_product(&a[1..n]);
172        let init = Matrix2x2 {
173            a: a[0].clone(),
174            b: Z::ONE.clone(),
175            c: Z::ONE.clone(),
176            d: Z::ZERO,
177        };
178        let Matrix2x2 { a, b, c: _, d: _ } = m * init;
179        (a, b)
180    } else {
181        let _ = a.pop();
182        let mut p_old = Z::ONE.clone();
183        let mut q_old = Z::ZERO;
184        let mut p_now = a[0].clone();
185        let mut q_now = Z::ONE.clone();
186        // println!("{p_old} {q_old}");
187        // println!("{p_now} {q_now}");
188        for ai in a.into_iter().skip(1) {
189            p_old += &ai * &p_now;
190            q_old += &ai * &q_now;
191            std::mem::swap(&mut p_old, &mut p_now);
192            std::mem::swap(&mut q_old, &mut q_now);
193            // println!("{p_now} {q_now}");
194        }
195        (p_now, q_now)
196    };
197    if n % 2 == 0 {
198        debug_assert_eq!(
199            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
200            *Z::ONE
201        );
202        Solution::Positive(p_now, q_now)
203    } else {
204        debug_assert_eq!(
205            p_now.square_ref().complete() - q_now.square_ref().complete() * &d,
206            -Z::ONE.clone()
207        );
208        Solution::Negative(p_now, q_now)
209    }
210}
211
212/// Calculate fundamental solution of `x^2 - d*y^2 = ±1`
213///
214/// If d is negative or perfect square returns `NotExist`.  
215/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
216/// Otherwise returns fundamental solution of `x^2 - d*y^2 = 1`.
217/// ```
218/// use rug::Integer;
219/// let v = pell_equation::solve_pell(Integer::from(2));
220/// assert_eq!(v, pell_equation::Solution::Negative(Integer::from(1), Integer::from(1)));
221/// let w = pell_equation::solve_pell(Integer::from(3));
222/// assert_eq!(w, pell_equation::Solution::Positive(Integer::from(2), Integer::from(1)));
223/// ```
224pub fn solve_pell(d: Z) -> Solution {
225    let Ok(a) = continued_fraction_of_sqrt(d.clone()) else {
226        return Solution::NotExist;
227    };
228    solve_pell_aux(a, d)
229}
230
231/// Calculate fundamental solution of `x^2 - d*y^2 = -1`
232///
233/// If `x^2 - d*y^2 = -1` has nontrivial solution, returns its fundamental solution.  
234/// Otherwise returns `None`.
235/// ```
236/// use rug::Integer;
237/// let v = pell_equation::solve_pell_negative(Integer::from(2));
238/// assert_eq!(v, Some((Integer::from(1), Integer::from(1))));
239/// let w = pell_equation::solve_pell_negative(Integer::from(3));
240/// assert_eq!(w, None);
241/// ```
242pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
243    let a = continued_fraction_of_sqrt(d.clone()).ok()?;
244    if (a.len() - 1) % 2 == 0 {
245        return None;
246    }
247    let Solution::Negative(x, y) = solve_pell_aux(a, d) else {
248        unreachable!()
249    };
250    Some((x, y))
251}
252
253/// Calculate fundamental solution of `x^2 - d*y^2 = 1`
254///
255/// If `x^2 - d*y^2 = 1` has nontrivial solution, returns its fundamental solution.  
256/// Otherwise returns `None`.
257/// ```
258/// use rug::Integer;
259/// let v = pell_equation::solve_pell_positive(Integer::from(2));
260/// assert_eq!(v, Some((Integer::from(3), Integer::from(2))));
261/// let w = pell_equation::solve_pell_positive(Integer::from(3));
262/// assert_eq!(w, Some((Integer::from(2), Integer::from(1))));
263/// ```
264pub fn solve_pell_positive(d: Z) -> Option<(Z, Z)> {
265    match solve_pell(d.clone()) {
266        Solution::NotExist => None,
267        Solution::Positive(x, y) => Some((x, y)),
268        Solution::Negative(x, y) => {
269            let y2 = 2 * (&x * &y).complete();
270            let x2 = x.square() + y.square() * d;
271            Some((x2, y2))
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a slice of small integers into big integers.
    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().copied().map(Z::from).collect()
    }

    /// Assert that the continued fraction of √`d` equals `expected`.
    fn assert_cf(d: i32, expected: &[i32]) {
        let actual = continued_fraction_of_sqrt(Z::from(d)).unwrap();
        assert_eq!(actual, to_z(expected));
    }

    // Reference values:
    // https://planetmath.org/tableofcontinuedfractionsofsqrtnfor1n102
    #[test]
    fn test_continued_fraction_of_sqrt2() {
        assert_cf(2, &[1, 2]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt3() {
        assert_cf(3, &[1, 1, 2]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt5() {
        assert_cf(5, &[2, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt6() {
        assert_cf(6, &[2, 2, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt7() {
        assert_cf(7, &[2, 1, 1, 1, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt8() {
        assert_cf(8, &[2, 1, 4]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt10() {
        assert_cf(10, &[3, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt11() {
        assert_cf(11, &[3, 3, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt12() {
        assert_cf(12, &[3, 2, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt13() {
        assert_cf(13, &[3, 1, 1, 1, 1, 6]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt31() {
        assert_cf(31, &[5, 1, 1, 3, 5, 3, 1, 1, 10]);
    }
    #[test]
    fn test_continued_fraction_of_sqrt94() {
        assert_cf(
            94,
            &[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18],
        );
    }
    #[test]
    fn test_continued_fraction_of_sqrt338() {
        assert_cf(338, &[18, 2, 1, 1, 2, 36]);
    }

    #[test]
    fn test_solve_pell() {
        assert_eq!(
            solve_pell(Z::from(653)),
            Solution::Negative(Z::from(2291286382u64), Z::from(89664965))
        );
    }
    #[test]
    fn test_solve_pell2() {
        assert_eq!(
            solve_pell(Z::from(115)),
            Solution::Positive(Z::from(1126), Z::from(105))
        );
    }
    #[test]
    fn test_solve_pell3() {
        assert_eq!(
            solve_pell(Z::from(114)),
            Solution::Positive(Z::from(1025), Z::from(96))
        );
    }
    #[test]
    fn test_solve_pell4() {
        assert_eq!(
            solve_pell(Z::from(641)),
            Solution::Negative(Z::from(36120833468u64), Z::from(1426687145))
        );
    }
    #[test]
    fn test_solve_pell5() {
        let Solution::Negative(x, y) = solve_pell(Z::from(1021)) else {
            panic!("not negative")
        };
        let expected_x: Z = "315217280372584882515030".parse().unwrap();
        let expected_y: Z = "9865001296666956406909".parse().unwrap();
        assert_eq!(x, expected_x);
        assert_eq!(y, expected_y);
    }
}