pell-equation 0.1.2

solve Pell's equation
Documentation
#![doc = include_str!("../README.md")]

use std::{
    fmt,
    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
    sync::Arc,
};
/// Arbitrary-precision rational number.
type Q = rug::Rational;
/// Arbitrary-precision integer.
type Z = rug::Integer;

/// An element `a + b√d` of the quadratic field Q(√d), with rational `a` and `b`.
///
/// The radicand is shared through an [`Arc`] so every element of one field can
/// point at the same `d`; the element-by-element operators assert that both
/// operands agree on it.
#[derive(Clone, Debug, Eq, PartialEq)]
struct QuadraticField {
    /// Rational part.
    a: Q,
    /// Coefficient of √d.
    b: Q,
    /// Radicand `d`.
    d: Arc<Z>,
}
impl fmt::Display for QuadraticField {
    /// Formats the element as `a+b√d`, e.g. `1/2+1/2√3` for `(1 + √3) / 2`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        // The `√` separates the coefficient from the radicand; without it
        // `0 + 1√3` would be rendered as the ambiguous `0+13`.
        write!(f, "{}+{}√{}", self.a, self.b, self.d)
    }
}
// `(a1 + b1√d) + (a2 + b2√d) = (a1 + a2) + (b1 + b2)√d`; both operands must share `d`.
#[auto_impl_ops::auto_ops]
impl AddAssign<&QuadraticField> for QuadraticField {
    fn add_assign(&mut self, other: &Self) {
        assert_eq!(self.d, other.d);
        let Self { a, b, .. } = self;
        *a += &other.a;
        *b += &other.b;
    }
}
// `(a1 + b1√d) - (a2 + b2√d) = (a1 - a2) + (b1 - b2)√d`; both operands must share `d`.
#[auto_impl_ops::auto_ops]
impl SubAssign<&QuadraticField> for QuadraticField {
    fn sub_assign(&mut self, other: &Self) {
        assert_eq!(self.d, other.d);
        let Self { a, b, .. } = self;
        *a -= &other.a;
        *b -= &other.b;
    }
}
// Subtracting an integer only shifts the rational part; the √d coefficient is untouched.
#[auto_impl_ops::auto_ops]
impl SubAssign<&Z> for QuadraticField {
    fn sub_assign(&mut self, other: &Z) {
        let Self { a, .. } = self;
        *a -= other;
    }
}
// `(a1 + b1√d)(a2 + b2√d) = (a1*a2 + b1*b2*d) + (a1*b2 + b1*a2)√d`; both operands must share `d`.
#[auto_impl_ops::auto_ops]
impl MulAssign<&QuadraticField> for QuadraticField {
    fn mul_assign(&mut self, other: &Self) {
        assert_eq!(self.d, other.d);
        let (a1, b1) = (&self.a, &self.b);
        let (a2, b2) = (&other.a, &other.b);
        // Both new components read the old ones, so build them before writing back.
        let rational_part = Q::from(a1 * a2) + Q::from(b1 * b2) * &*self.d;
        let surd_part = Q::from(a1 * b2) + Q::from(b1 * a2);
        self.a = rational_part;
        self.b = surd_part;
    }
}
// Division multiplies numerator and denominator by the conjugate `a2 - b2√d`:
// `(a1 + b1√d) / (a2 + b2√d) = ((a1*a2 - b1*b2*d) + (b1*a2 - a1*b2)√d) / (a2^2 - b2^2*d)`.
// Both operands must share `d`.
#[auto_impl_ops::auto_ops]
impl DivAssign<&QuadraticField> for QuadraticField {
    fn div_assign(&mut self, other: &Self) {
        assert_eq!(self.d, other.d);
        let d = &*self.d;
        let norm = Q::from(other.a.square_ref()) - Q::from(other.b.square_ref()) * d;
        let rational_part =
            (Q::from(&self.a * &other.a) - Q::from(&self.b * &other.b) * d) / &norm;
        let surd_part = (Q::from(&self.b * &other.a) - Q::from(&self.a * &other.b)) / norm;
        self.a = rational_part;
        self.b = surd_part;
    }
}
impl QuadraticField {
    /// Embeds the integer `v` into Q(√d) as `v + 0√d`.
    fn from_integer(v: Z, d: Arc<Z>) -> Self {
        Self {
            a: Q::from(v),
            b: Q::ZERO.clone(),
            d,
        }
    }

    /// Multiplicative identity `1 + 0√d`.
    fn one(d: Arc<Z>) -> Self {
        Self::from_integer(Z::ONE.clone(), d)
    }

    /// Returns `1 / self`.
    ///
    /// `self` must have a non-zero norm `a^2 - d*b^2`; for non-square `d` that
    /// simply means `self` is not zero.
    fn recip(&self) -> Self {
        QuadraticField::one(Arc::clone(&self.d)) / self
    }

    /// Returns `⌊a + b√d⌋`, computed with exact rational arithmetic.
    ///
    /// If `d` is a perfect square the value is rational and is floored directly.
    /// Otherwise the value is bracketed between `a + b⌊√d⌋` and `a + b(⌊√d⌋ + 1)`,
    /// and the bracket is bisected until both ends share the same floor. A midpoint
    /// `m` is compared against `a + b√d` without irrationals by comparing
    /// `(m - a)^2` with `b^2 d`. The sign of `b` decides which outcome means
    /// "`m` is at or below the value".
    ///
    /// Panics if `d` is negative (integer square root of a negative number).
    fn floor(&self) -> Z {
        let sqrt_floor = Z::from(self.d.sqrt_ref());
        if Z::from(sqrt_floor.square_ref()) == *self.d {
            // √d is an integer, so a + b√d is rational. This early exit also
            // guarantees the value below is never equal to a bracket end, which
            // would otherwise stall the bisection when b < 0.
            let exact = &self.a + Q::from(&self.b * &sqrt_floor);
            return exact.floor().numer().clone();
        }
        let mut left = &self.a + Q::from(&self.b * &sqrt_floor);
        let mut right = &self.a + Q::from(&self.b * (sqrt_floor + 1));
        let b_negative = self.b < *Q::ZERO;
        if b_negative {
            // Scaling by a negative b reverses the bracket.
            std::mem::swap(&mut left, &mut right);
        }
        let offset = Q::from(self.b.square_ref()) * &*self.d;
        while Q::from(left.floor_ref()) != Q::from(right.floor_ref()) {
            let mid = Q::from(&left + &right) / 2;
            // y = (mid - a)^2 - b^2 d
            let y = Q::from(&mid - &self.a).square() - &offset;
            // Inside the bracket, mid - a has the sign of b. For b > 0,
            // mid <= a + b√d  <=>  (mid - a)^2 <= b^2 d; for b < 0 the
            // inequality flips because both sides are non-positive.
            let mid_at_or_below = if b_negative {
                y >= *Q::ZERO
            } else {
                y <= *Q::ZERO
            };
            if mid_at_or_below {
                left = mid;
            } else {
                right = mid;
            }
        }
        left.floor().numer().clone()
    }
}

/// Calculate continued fraction of √d
///
/// Computes the [simple continued fraction](https://en.wikipedia.org/wiki/Simple_continued_fraction)
/// of √d and returns `[a0, a1, ..., ak]`. Here `a0 = ⌊√d⌋` and `a1, ..., ak` is one full
/// period of the repeating tail. When `d` is a perfect square, the expansion terminates
/// and the result is just `[√d]`.
///
/// # Examples
///
/// √2 = [1; 2, 2, 2, ...]:
/// ```
/// use rug::Integer;
/// let v = pell_equation::continued_fraction_of_sqrt(Integer::from(2));
/// assert_eq!(v, vec![Integer::from(1), Integer::from(2)]);
/// ```
///
/// # Panics
///
/// Panics if `d` is negative.
pub fn continued_fraction_of_sqrt(d: Z) -> Vec<Z> {
    let root = Z::from(d.sqrt_ref());
    if Z::from(root.square_ref()) == d {
        // √d is an integer: the expansion is that integer alone.
        return vec![root];
    }
    let d = Arc::new(d);
    let sqrt_d = QuadraticField {
        a: Q::ZERO.clone(),
        b: Q::ONE.clone(),
        d: Arc::clone(&d),
    };
    // Peel off the integer part a0 and move to the first complete quotient.
    let head = sqrt_d.floor();
    let first = (sqrt_d - &head).recip();
    let mut terms = vec![head];
    let mut x = first.clone();
    loop {
        let term = x.floor();
        x = (x - &term).recip();
        terms.push(term);
        // Returning to the first complete quotient closes one period.
        if x == first {
            return terms;
        }
    }
}

/// Fundamental solution of `x^2 - d*y^2 = ±1`, tagged with the sign it solves.
///
/// Both variants hold the pair `(x, y)` in that order.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Solution {
    /// `(x, y)` is the fundamental solution of `x^2 - d*y^2 = -1`.
    Negative(Z, Z),
    /// `(x, y)` is the fundamental solution of `x^2 - d*y^2 = 1`.
    Positive(Z, Z),
}

/// Walks the convergents `p_k / q_k` of the continued fraction `a` of √d and
/// returns the first one with `p_k^2 - d*q_k^2 = ±1`.
///
/// `a` is expected to be the output of [`continued_fraction_of_sqrt`] for the
/// same `d`: the integer part followed by one full period.
///
/// # Panics
///
/// Panics if no convergent in `a` solves `x^2 - d*y^2 = ±1`. This is the case
/// when `d` is a perfect square: `a` then has a single term, whose convergent
/// gives `x^2 - d*y^2 = 0`.
fn solve_pell_aux(a: Vec<Z>, d: Z) -> Solution {
    // Seed the recurrence p_k = a_k*p_{k-1} + p_{k-2} (and likewise for q) with
    // (p_{-2}, p_{-1}) = (0, 1) and (q_{-2}, q_{-1}) = (1, 0). With this seed,
    // the first term a_0 is handled by the same loop as the rest and yields (a_0, 1).
    let mut p_prev = Z::ZERO;
    let mut p_curr = Z::ONE.clone();
    let mut q_prev = Z::ONE.clone();
    let mut q_curr = Z::ZERO;
    for term in &a {
        let p_next = Z::from(term * &p_curr) + &p_prev;
        let q_next = Z::from(term * &q_curr) + &q_prev;
        p_prev = std::mem::replace(&mut p_curr, p_next);
        q_prev = std::mem::replace(&mut q_curr, q_next);
        // Norm of p + q√d, evaluated exactly in integers.
        let norm = Z::from(p_curr.square_ref()) - Z::from(q_curr.square_ref()) * &d;
        if norm == -1 {
            return Solution::Negative(p_curr, q_curr);
        }
        if norm == 1 {
            return Solution::Positive(p_curr, q_curr);
        }
    }
    panic!("x^2 - d*y^2 = ±1 has no nontrivial solution: d = {d} is a perfect square");
}

/// Calculate fundamental solution of `x^2 - d*y^2 = ±1`
///
/// If the negative equation `x^2 - d*y^2 = -1` is solvable, its fundamental solution
/// is returned as [`Solution::Negative`]. Otherwise, the fundamental solution of
/// `x^2 - d*y^2 = 1` is returned as [`Solution::Positive`].
///
/// # Examples
///
/// ```
/// use rug::Integer;
/// let v = pell_equation::solve_pell(Integer::from(2));
/// assert_eq!(v, pell_equation::Solution::Negative(Integer::from(1), Integer::from(1)));
/// let w = pell_equation::solve_pell(Integer::from(3));
/// assert_eq!(w, pell_equation::Solution::Positive(Integer::from(2), Integer::from(1)));
/// ```
///
/// # Panics
///
/// Panics if `d` is negative or a perfect square (including 0 and 1).
pub fn solve_pell(d: Z) -> Solution {
    solve_pell_aux(continued_fraction_of_sqrt(d.clone()), d)
}

/// Calculate fundamental solution of `x^2 - d*y^2 = -1`
///
/// The negative equation is solvable exactly when the period of the continued
/// fraction of √d has odd length. In that case, its fundamental solution `(x, y)` is
/// returned. Otherwise, including when `d` is a perfect square, the result is `None`.
///
/// # Examples
///
/// ```
/// use rug::Integer;
/// let v = pell_equation::solve_pell_negative(Integer::from(2));
/// assert_eq!(v, Some((Integer::from(1), Integer::from(1))));
/// let w = pell_equation::solve_pell_negative(Integer::from(3));
/// assert_eq!(w, None);
/// ```
///
/// # Panics
///
/// Panics if `d` is negative.
pub fn solve_pell_negative(d: Z) -> Option<(Z, Z)> {
    let expansion = continued_fraction_of_sqrt(d.clone());
    // `expansion` is `[a0, period...]`, so the period is one term shorter than it.
    let period_len = expansion.len() - 1;
    if period_len % 2 == 1 {
        match solve_pell_aux(expansion, d) {
            Solution::Negative(x, y) => Some((x, y)),
            Solution::Positive(..) => unreachable!(),
        }
    } else {
        None
    }
}

/// Calculate fundamental solution of `x^2 - d*y^2 = 1`
///
/// This function returns `x^2 - d*y^2 = 1` fundamental solution.  
/// ```
/// use rug::Integer;
/// let v = pell_equation::solve_pell_positive(Integer::from(2));
/// assert_eq!(v, (Integer::from(3), Integer::from(2)));
/// let w = pell_equation::solve_pell_positive(Integer::from(3));
/// assert_eq!(w, (Integer::from(2), Integer::from(1)));
/// ```
pub fn solve_pell_positive(d: Z) -> (Z, Z) {
    match solve_pell(d.clone()) {
        Solution::Positive(x, y) => (x, y),
        Solution::Negative(x, y) => {
            let q = QuadraticField {
                a: Q::from(x),
                b: Q::from(y),
                d: d.into(),
            };
            let q2 = &q * &q;
            (q2.a.numer().clone(), q2.b.numer().clone())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Vec<Z>` from small machine integers for readable expectations.
    fn to_z(values: &[i32]) -> Vec<Z> {
        values.iter().map(|&v| Z::from(v)).collect()
    }

    #[test]
    fn test() {
        // ⌊√3⌋ = 1 and 1 / (√3 - 1) = 1/2 + (1/2)√3.
        let d = Arc::new(Z::from(3));
        let sqrt3 = QuadraticField {
            a: Q::ZERO.clone(),
            b: Q::ONE.clone(),
            d: Arc::clone(&d),
        };
        let whole = sqrt3.floor();
        assert_eq!(&whole, Z::ONE);
        let half = Q::from_f64(0.5).unwrap();
        let expected = QuadraticField {
            a: half.clone(),
            b: half,
            d: Arc::clone(&d),
        };
        assert_eq!((sqrt3 - whole).recip(), expected);
    }

    // Reference values: https://planetmath.org/tableofcontinuedfractionsofsqrtnfor1n102
    #[test]
    fn test_continued_fraction_of_sqrt() {
        let table: [(i32, &[i32]); 13] = [
            (2, &[1, 2]),
            (3, &[1, 1, 2]),
            (5, &[2, 4]),
            (6, &[2, 2, 4]),
            (7, &[2, 1, 1, 1, 4]),
            (8, &[2, 1, 4]),
            (10, &[3, 6]),
            (11, &[3, 3, 6]),
            (12, &[3, 2, 6]),
            (13, &[3, 1, 1, 1, 1, 6]),
            (31, &[5, 1, 1, 3, 5, 3, 1, 1, 10]),
            (94, &[9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18]),
            (338, &[18, 2, 1, 1, 2, 36]),
        ];
        for &(d, expected) in table.iter() {
            assert_eq!(
                continued_fraction_of_sqrt(Z::from(d)),
                to_z(expected),
                "d = {}",
                d
            );
        }
    }

    #[test]
    fn test_solve_pell() {
        assert_eq!(
            solve_pell(Z::from(653)),
            Solution::Negative(Z::from(2291286382u64), Z::from(89664965))
        );
    }

    #[test]
    fn test_solve_pell2() {
        assert_eq!(
            solve_pell(Z::from(115)),
            Solution::Positive(Z::from(1126), Z::from(105))
        );
    }
}