#![doc = include_str!("../README.md")]
#![warn(clippy::pedantic)]
#![allow(clippy::many_single_char_names)]
mod continued_fraction;
use continued_fraction::ContinuedFractionIterator;
use rug::{integer::IntegerExt64, Complete};
/// Shorthand for `rug::Integer`, the arbitrary-precision integer type used for
/// every potentially large quantity in this crate.
type Z = rug::Integer;
/// Errors reported by this crate.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The input value lies outside the domain accepted by the operation.
    ///
    /// NOTE(review): presumably produced by `ContinuedFractionIterator::new`
    /// (surfaced through `continued_fraction_of_sqrt`); the exact domain is
    /// defined in the `continued_fraction` module — confirm there.
    #[error("Input value is out of domain")]
    OutOfDomain,
}
/// Outcome of the Pell-equation solvers in this crate.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Solution {
    /// A pair `(x, y)` satisfying `x^2 - d*y^2 = -1`.
    Negative(Z, Z),
    /// A pair `(x, y)` satisfying `x^2 - d*y^2 = 1`.
    Positive(Z, Z),
    /// No solution of the requested kind: `d` is negative or a perfect square,
    /// or (from `solve_pell_negative`) only the `+1` equation was solved.
    NotExist,
    /// A computed value (intermediate or final) had more significant bits than
    /// the caller-supplied `limit`, so the search was abandoned.
    ReachedLimit,
}
/// Early-returns `Solution::ReachedLimit` from the enclosing function when the
/// `Option<u64>` named by `$limit` is `Some(bits)` and `$target` has more than
/// `bits` significant bits. A `None` limit never triggers.
macro_rules! check_limit {
    ($target:expr, $limit:ident) => {
        if $limit.is_some_and(|max_bits| $target.significant_bits_64() > max_bits) {
            return Solution::ReachedLimit;
        }
    };
}
/// Two-value variant of `check_limit!`: early-returns `Solution::ReachedLimit`
/// if either `$target1` or `$target2` exceeds the bit limit. `$target1` is
/// checked first; `$target2` is only inspected when `$target1` is within bounds.
macro_rules! check_limit2 {
    ($target1:expr, $target2:expr, $limit:ident) => {{
        check_limit!($target1, $limit);
        check_limit!($target2, $limit);
    }};
}
/// Builds `Solution::Negative(x, y)` from `y` by taking `x = sqrt(d*y^2 - 1)`.
///
/// In debug builds it asserts that `d*y^2 - 1` is a perfect square, i.e. that `y`
/// really belongs to a solution of `x^2 - d*y^2 = -1`. Both `x` and `y` are then
/// checked against `$limit` (early-returning `Solution::ReachedLimit`).
macro_rules! gen_x_neg {
    ($y:ident, $d:ident, $limit:ident) => {{
        let y_squared = $y.square_ref().complete();
        let x_squared = Z::from($d * y_squared - 1);
        debug_assert!(x_squared.is_perfect_square());
        let x = x_squared.sqrt();
        check_limit2!(x, $y, $limit);
        Solution::Negative(x, $y)
    }};
}
/// Builds `Solution::Positive(x, y)` from `y` by taking `x = sqrt(d*y^2 + 1)`.
///
/// In debug builds it asserts that `d*y^2 + 1` is a perfect square, i.e. that `y`
/// really belongs to a solution of `x^2 - d*y^2 = 1`. Both `x` and `y` are then
/// checked against `$limit` (early-returning `Solution::ReachedLimit`).
macro_rules! gen_x_pos {
    ($y:ident, $d:ident, $limit:ident) => {{
        let y_squared = $y.square_ref().complete();
        let x_squared = Z::from($d * y_squared + 1);
        debug_assert!(x_squared.is_perfect_square());
        let x = x_squared.sqrt();
        check_limit2!(x, $y, $limit);
        Solution::Positive(x, $y)
    }};
}
/// Returns the terms of the continued fraction expansion of `sqrt(d)`.
///
/// The terms are whatever `ContinuedFractionIterator` yields. The crate's tests
/// show the integer part followed by one full period, e.g. `sqrt(7)` gives
/// `[2, 1, 1, 1, 4]`.
///
/// # Errors
///
/// Propagates the error returned by `ContinuedFractionIterator::new` — presumably
/// [`Error::OutOfDomain`] for an unsupported `d`; confirm the exact domain in the
/// `continued_fraction` module.
pub fn continued_fraction_of_sqrt(d: Z) -> Result<Vec<Z>, Error> {
    let terms: Vec<Z> = ContinuedFractionIterator::new(d)?.collect();
    Ok(terms)
}
/// One step of the `(p, q)` recurrence used by `solve_pell_aux_small`.
///
/// The state `(p, q)` stands for the complete quotient `(sqrt(d) + p) / q`, and
/// `sd` must be `floor(sqrt(d))`. With `a = (sd + p) / q`, the two candidate
/// numerators are `a*q - p` and `a*q - p + q`. The candidate whose norm
/// `|d - p'^2|` is strictly smaller wins; ties go to the second candidate.
///
/// Returns `(took_first, chosen_a, p', |d - p'^2| / q)`, where `chosen_a` is `a`
/// for the first candidate and `a + 1` for the second.
///
/// The norms are computed in `i128`. The second candidate can exceed
/// `floor(sqrt(d))`, so for `d` close to `i64::MAX` its square does not fit in
/// `i64`. This already happens on the first step, because
/// `(floor(sqrt(i64::MAX)) + 1)^2 > i64::MAX`. Squaring in `i64` would panic in
/// debug builds and silently wrap in release builds.
///
/// # Panics
///
/// Panics if the next `q` does not fit in `i64`; the recurrence is not expected
/// to produce such a value.
fn next_pq_small(d: i64, sd: i64, p: i64, q: i64) -> (bool, i64, i64, i64) {
    let int1 = (sd + p) / q;
    let p1 = int1 * q - p;
    let p2 = p1 + q;
    // |d - c^2| evaluated without any risk of i64 overflow.
    let norm = |candidate: i64| -> i128 {
        let c = i128::from(candidate);
        (i128::from(d) - c * c).abs()
    };
    let norm1 = norm(p1);
    let norm2 = norm(p2);
    let take_first = norm1 < norm2;
    let (int_new, p_new, norm_new) = if take_first {
        (int1, p1, norm1)
    } else {
        (int1 + 1, p2, norm2)
    };
    let q_wide = i128::from(q);
    debug_assert_eq!(norm_new % q_wide, 0);
    let q_new = i64::try_from(norm_new / q_wide)
        .expect("next q of the continued-fraction recurrence must fit in i64");
    (take_first, int_new, p_new, q_new)
}
/// Solves the Pell equation for a `d` that fits in `i64`.
///
/// The caller (`solve_pell_aux`) has already rejected negative `d` and perfect
/// squares. Each iteration advances `(p, q)` with `next_pq_small` and extends a
/// running sequence `b` via `b_new = a * b_cur + b_old` or `a * b_cur - b_old`.
/// The sign is chosen by the flag the previous step returned (`e_old`).
///
/// Instead of walking a whole period, the loop stops at the first of three
/// conditions on `(p, q)`. It then builds `y` from the `b` values accumulated so
/// far and recovers `x` as an integer square root via `gen_x_pos!` / `gen_x_neg!`.
///
/// NOTE(review): the stopping conditions and `y` formulas look like the usual
/// half-period shortcuts for the continued fraction of `sqrt(d)`. They are
/// cross-checked against a full-period reference (`tests::solve_pell_rcf`) for
/// every `d < 30000` in `tests::test_pell_aux_2`.
///
/// Returns `Solution::ReachedLimit` once a `b` value, or the final `x` or `y`,
/// has more than `limit` significant bits.
fn solve_pell_aux_small(d: i64, limit: Option<u64>) -> Solution {
    // floor(sqrt(d)); the caller guarantees d >= 0.
    let sd = d.isqrt();
    // Half of the previous iteration's `q` when that `q` was even, else `None`.
    let mut q_old = None;
    // State (p, q) of the complete quotient (sqrt(d) + p) / q, starting at sqrt(d).
    let mut p = 0i64;
    let mut q = 1i64;
    // Seeds of the `b` recurrence: (b_old, b_cur) = (1, 0).
    let mut b_old = Z::ONE.clone();
    let mut b_cur = Z::ZERO;
    let mut e_old = true;
    loop {
        let (e_cur, int, p_new, q_new) = next_pq_small(d, sd, p, q);
        // The sign of the b_old term follows the flag returned by the previous step.
        let b_new = if e_old {
            (int * &b_cur + &b_old).complete()
        } else {
            (int * &b_cur - &b_old).complete()
        };
        check_limit!(b_new, limit);
        if p == p_new {
            // `p` repeated: y = b_cur * (b_new ± b_old) is used for x^2 - d*y^2 = 1.
            let t = if e_old { b_old } else { -b_old };
            let b = b_cur * (b_new + t);
            return gen_x_pos!(b, d, limit);
        } else if q == q_new {
            // `q` repeated: y = b_new^2 ± b_cur^2; `e_cur` selects the -1 or +1 equation.
            return if e_cur {
                let b = b_new.square_ref() + b_cur.square();
                gen_x_neg!(b, d, limit)
            } else {
                let b = b_new.square_ref() - b_cur.square();
                gen_x_pos!(b, d, limit)
            };
        } else if let Some(q_old) = q_old {
            // After a minus-sign step, `p` equal to the current `q` plus half the
            // previous `q` gives y = 2*b_cur^2 - b_new*b_old for x^2 - d*y^2 = -1.
            if q + q_old == p && !e_old {
                let b = Z::from(2 * b_cur.square() - &b_new * &b_old);
                return gen_x_neg!(b, d, limit);
            }
        }
        q_old = if q % 2 == 0 { Some(q / 2) } else { None };
        p = p_new;
        q = q_new;
        b_old = b_cur;
        b_cur = b_new;
        e_old = e_cur;
    }
}
/// Arbitrary-precision counterpart of `next_pq_small`: one step of the `(p, q)`
/// recurrence for the complete quotient `(sqrt(d) + p) / q`.
///
/// `sd` must be `floor(sqrt(d))`. With `a = (sd + p) / q`, the candidate
/// numerators `a*q - p` and `a*q - p + q` are compared by their norm
/// `|d - p'^2|`. The strictly smaller one wins; a tie picks the second.
///
/// Returns `(took_first, chosen_a, p', |d - p'^2| / q)`, where `chosen_a` is `a`
/// for the first candidate and `a + 1` for the second.
fn next_pq_large(d: &Z, sd: &Z, p: &Z, q: &Z) -> (bool, Z, Z, Z) {
    let a_floor = (sd + p).complete() / q;
    let p_first = (&a_floor * q - p).complete();
    let p_second = (&p_first + q).complete();
    let norm_first = (d - p_first.square_ref()).complete().abs();
    let norm_second = (d - p_second.square_ref()).complete().abs();
    let take_first = norm_first < norm_second;
    let (a_chosen, p_chosen, norm_chosen) = if take_first {
        (a_floor, p_first, norm_first)
    } else {
        (a_floor + 1, p_second, norm_second)
    };
    // The recurrence keeps q | d - p'^2, so the division below is exact.
    debug_assert!(norm_chosen.is_divisible(q));
    let q_next = norm_chosen.div_exact(q);
    (take_first, a_chosen, p_chosen, q_next)
}
/// Solves the Pell equation for a `d` that does not fit in `i64`.
///
/// It follows the same loop as `solve_pell_aux_small`, but `p`, `q` and the
/// partial quotients are kept in `Z` and stepped with `next_pq_large`. The caller
/// (`solve_pell_aux`) has already rejected negative `d` and perfect squares.
///
/// The loop stops at the first of three conditions on `(p, q)`, builds `y` from
/// the accumulated `b` values, and recovers `x` via `gen_x_pos!` / `gen_x_neg!`.
///
/// NOTE(review): correctness of the stopping conditions is cross-checked in the
/// tests only for small `d` (the i64 path); this path shares the logic but is not
/// independently exercised there.
///
/// Returns `Solution::ReachedLimit` once a `b` value, or the final `x` or `y`,
/// has more than `limit` significant bits.
fn solve_pell_aux_large(d: &Z, limit: Option<u64>) -> Solution {
    // floor(sqrt(d)).
    let sd = d.sqrt_ref().complete();
    // Half of the previous iteration's `q` when that `q` was even, else `None`.
    let mut q_old = None;
    // State (p, q) of the complete quotient (sqrt(d) + p) / q, starting at sqrt(d).
    let mut p = Z::ZERO;
    let mut q = Z::ONE.clone();
    // Seeds of the `b` recurrence: (b_old, b_cur) = (1, 0).
    let mut b_old = Z::ONE.clone();
    let mut b_cur = Z::ZERO;
    let mut e_old = true;
    loop {
        let (e_cur, int, p_new, q_new) = next_pq_large(d, &sd, &p, &q);
        // The sign of the b_old term follows the flag returned by the previous step.
        let b_new = if e_old {
            (&int * &b_cur + &b_old).complete()
        } else {
            (&int * &b_cur - &b_old).complete()
        };
        check_limit!(b_new, limit);
        if p == p_new {
            // `p` repeated: y = b_cur * (b_new ± b_old) is used for x^2 - d*y^2 = 1.
            let t = if e_old { b_old } else { -b_old };
            let b = b_cur * (b_new + t);
            return gen_x_pos!(b, d, limit);
        } else if q == q_new {
            // `q` repeated: y = b_new^2 ± b_cur^2; `e_cur` selects the -1 or +1 equation.
            return if e_cur {
                let b = b_new.square_ref() + b_cur.square();
                gen_x_neg!(b, d, limit)
            } else {
                let b = b_new.square_ref() - b_cur.square();
                gen_x_pos!(b, d, limit)
            };
        } else if let Some(q_old) = q_old {
            // After a minus-sign step, `p` equal to the current `q` plus half the
            // previous `q` gives y = 2*b_cur^2 - b_new*b_old for x^2 - d*y^2 = -1.
            if &q + q_old == p && !e_old {
                let b = Z::from(2 * b_cur.square() - &b_new * &b_old);
                return gen_x_neg!(b, d, limit);
            }
        }
        q_old = if q.is_divisible_u(2) {
            Some(q.div_exact_u(2))
        } else {
            None
        };
        p = p_new;
        q = q_new;
        b_old = b_cur;
        b_cur = b_new;
        e_old = e_cur;
    }
}
/// Shared entry point of the public solvers.
///
/// Negative `d` and perfect squares are rejected up front with
/// `Solution::NotExist`. Otherwise it dispatches to the `i64` implementation
/// when `d` fits in an `i64`, and to the arbitrary-precision one when it does not.
fn solve_pell_aux(d: &Z, limit: Option<u64>) -> Solution {
    let rejected = d.is_negative() || d.is_perfect_square();
    if rejected {
        return Solution::NotExist;
    }
    match d.to_i64() {
        Some(small_d) => solve_pell_aux_small(small_d, limit),
        None => solve_pell_aux_large(d, limit),
    }
}
/// Solves the Pell equation for `d`, preferring the negative form.
///
/// Returns `Solution::Negative(x, y)` with `x^2 - d*y^2 = -1` when the search
/// ends on such a solution, and otherwise `Solution::Positive(x, y)` with
/// `x^2 - d*y^2 = 1`. Negative `d` and perfect squares (e.g. 1, 4, 9) give
/// `Solution::NotExist`.
///
/// `limit`: with `Some(bits)`, the search gives up and returns
/// `Solution::ReachedLimit` as soon as a value it computes has more than `bits`
/// significant bits. `None` disables the limit.
#[must_use]
pub fn solve_pell(d: &Z, limit: Option<u64>) -> Solution {
    solve_pell_aux(d, limit)
}
/// Solves the negative Pell equation `x^2 - d*y^2 = -1`.
///
/// It runs the same search as [`solve_pell`]. When that search ends on a
/// solution of the `+1` equation instead, `Solution::NotExist` is returned.
/// All other outcomes, including `Solution::ReachedLimit`, pass through
/// unchanged. `limit` behaves as in [`solve_pell`].
#[must_use]
pub fn solve_pell_negative(d: &Z, limit: Option<u64>) -> Solution {
    let solution = solve_pell_aux(d, limit);
    if matches!(solution, Solution::Positive(..)) {
        Solution::NotExist
    } else {
        solution
    }
}
/// Solves the positive Pell equation `x^2 - d*y^2 = 1`.
///
/// When [`solve_pell`] yields a solution `(x, y)` of the `-1` equation, it is
/// squared: `(x + y*sqrt(d))^2 = (x^2 + d*y^2) + 2xy*sqrt(d)`. The squared pair
/// is checked against `limit` once more. Every other outcome of [`solve_pell`]
/// is returned as is.
#[must_use]
pub fn solve_pell_positive(d: &Z, limit: Option<u64>) -> Solution {
    match solve_pell(d, limit) {
        Solution::Negative(x, y) => {
            let two_xy = (&x * &y).complete() * 2;
            let x_squared = x.square();
            let new_x = x_squared + y.square() * d;
            check_limit2!(new_x, two_xy, limit);
            Solution::Positive(new_x, two_xy)
        }
        other => other,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Converts a slice of `i32` into a `Vec<Z>`.
    fn to_z(v: &[i32]) -> Vec<Z> {
        v.iter().map(|x| Z::from(*x)).collect()
    }
    /// Checks `continued_fraction_of_sqrt` against expected expansions of
    /// `sqrt(d)` for every non-square `d` from 2 to 101. Each entry is the
    /// integer part followed by one period.
    #[test]
    fn test_continued_fraction_of_sqrt() {
        let continued_fractions = vec![
            (2, vec![1, 2]),
            (3, vec![1, 1, 2]),
            (5, vec![2, 4]),
            (6, vec![2, 2, 4]),
            (7, vec![2, 1, 1, 1, 4]),
            (8, vec![2, 1, 4]),
            (10, vec![3, 6]),
            (11, vec![3, 3, 6]),
            (12, vec![3, 2, 6]),
            (13, vec![3, 1, 1, 1, 1, 6]),
            (14, vec![3, 1, 2, 1, 6]),
            (15, vec![3, 1, 6]),
            (17, vec![4, 8]),
            (18, vec![4, 4, 8]),
            (19, vec![4, 2, 1, 3, 1, 2, 8]),
            (20, vec![4, 2, 8]),
            (21, vec![4, 1, 1, 2, 1, 1, 8]),
            (22, vec![4, 1, 2, 4, 2, 1, 8]),
            (23, vec![4, 1, 3, 1, 8]),
            (24, vec![4, 1, 8]),
            (26, vec![5, 10]),
            (27, vec![5, 5, 10]),
            (28, vec![5, 3, 2, 3, 10]),
            (29, vec![5, 2, 1, 1, 2, 10]),
            (30, vec![5, 2, 10]),
            (31, vec![5, 1, 1, 3, 5, 3, 1, 1, 10]),
            (32, vec![5, 1, 1, 1, 10]),
            (33, vec![5, 1, 2, 1, 10]),
            (34, vec![5, 1, 4, 1, 10]),
            (35, vec![5, 1, 10]),
            (37, vec![6, 12]),
            (38, vec![6, 6, 12]),
            (39, vec![6, 4, 12]),
            (40, vec![6, 3, 12]),
            (41, vec![6, 2, 2, 12]),
            (42, vec![6, 2, 12]),
            (43, vec![6, 1, 1, 3, 1, 5, 1, 3, 1, 1, 12]),
            (44, vec![6, 1, 1, 1, 2, 1, 1, 1, 12]),
            (45, vec![6, 1, 2, 2, 2, 1, 12]),
            (46, vec![6, 1, 3, 1, 1, 2, 6, 2, 1, 1, 3, 1, 12]),
            (47, vec![6, 1, 5, 1, 12]),
            (48, vec![6, 1, 12]),
            (50, vec![7, 14]),
            (51, vec![7, 7, 14]),
            (52, vec![7, 4, 1, 2, 1, 4, 14]),
            (53, vec![7, 3, 1, 1, 3, 14]),
            (54, vec![7, 2, 1, 6, 1, 2, 14]),
            (55, vec![7, 2, 2, 2, 14]),
            (56, vec![7, 2, 14]),
            (57, vec![7, 1, 1, 4, 1, 1, 14]),
            (58, vec![7, 1, 1, 1, 1, 1, 1, 14]),
            (59, vec![7, 1, 2, 7, 2, 1, 14]),
            (60, vec![7, 1, 2, 1, 14]),
            (61, vec![7, 1, 4, 3, 1, 2, 2, 1, 3, 4, 1, 14]),
            (62, vec![7, 1, 6, 1, 14]),
            (63, vec![7, 1, 14]),
            (65, vec![8, 16]),
            (66, vec![8, 8, 16]),
            (67, vec![8, 5, 2, 1, 1, 7, 1, 1, 2, 5, 16]),
            (68, vec![8, 4, 16]),
            (69, vec![8, 3, 3, 1, 4, 1, 3, 3, 16]),
            (70, vec![8, 2, 1, 2, 1, 2, 16]),
            (71, vec![8, 2, 2, 1, 7, 1, 2, 2, 16]),
            (72, vec![8, 2, 16]),
            (73, vec![8, 1, 1, 5, 5, 1, 1, 16]),
            (74, vec![8, 1, 1, 1, 1, 16]),
            (75, vec![8, 1, 1, 1, 16]),
            (76, vec![8, 1, 2, 1, 1, 5, 4, 5, 1, 1, 2, 1, 16]),
            (77, vec![8, 1, 3, 2, 3, 1, 16]),
            (78, vec![8, 1, 4, 1, 16]),
            (79, vec![8, 1, 7, 1, 16]),
            (80, vec![8, 1, 16]),
            (82, vec![9, 18]),
            (83, vec![9, 9, 18]),
            (84, vec![9, 6, 18]),
            (85, vec![9, 4, 1, 1, 4, 18]),
            (86, vec![9, 3, 1, 1, 1, 8, 1, 1, 1, 3, 18]),
            (87, vec![9, 3, 18]),
            (88, vec![9, 2, 1, 1, 1, 2, 18]),
            (89, vec![9, 2, 3, 3, 2, 18]),
            (90, vec![9, 2, 18]),
            (91, vec![9, 1, 1, 5, 1, 5, 1, 1, 18]),
            (92, vec![9, 1, 1, 2, 4, 2, 1, 1, 18]),
            (93, vec![9, 1, 1, 1, 4, 6, 4, 1, 1, 1, 18]),
            (94, vec![9, 1, 2, 3, 1, 1, 5, 1, 8, 1, 5, 1, 1, 3, 2, 1, 18]),
            (95, vec![9, 1, 2, 1, 18]),
            (96, vec![9, 1, 3, 1, 18]),
            (97, vec![9, 1, 5, 1, 1, 1, 1, 1, 1, 5, 1, 18]),
            (98, vec![9, 1, 8, 1, 18]),
            (99, vec![9, 1, 18]),
            (101, vec![10, 20]),
        ];
        for (d, w) in continued_fractions {
            let d = Z::from(d);
            let v = continued_fraction_of_sqrt(d).unwrap();
            assert_eq!(v, to_z(&w));
        }
    }
    /// Non-square-free input: 338 = 2 * 13^2.
    #[test]
    fn test_continued_fraction_of_sqrt338() {
        let v = continued_fraction_of_sqrt(Z::from(338)).unwrap();
        assert_eq!(v, to_z(&[18, 2, 1, 1, 2, 36]));
    }
    /// d = 653: the expected result is a solution of x^2 - 653*y^2 = -1.
    #[test]
    fn test_solve_pell() {
        let v = solve_pell(&Z::from(653), None);
        assert_eq!(
            v,
            Solution::Negative(Z::from(2_291_286_382u64), Z::from(89_664_965))
        );
    }
    /// d = 641: the expected result is a solution of x^2 - 641*y^2 = -1.
    #[test]
    fn test_solve_pell2() {
        let v = solve_pell(&Z::from(641), None);
        assert_eq!(
            v,
            Solution::Negative(Z::from(36_120_833_468u64), Z::from(1_426_687_145))
        );
    }
    /// d = 1021: the solution exceeds `u64`, so expected values are parsed from strings.
    #[test]
    fn test_solve_pell3() {
        let Solution::Negative(x, y) = solve_pell(&Z::from(1021), None) else {
            panic!("not negative")
        };
        assert_eq!(
            x,
            Z::from_str_radix("315217280372584882515030", 10).unwrap()
        );
        assert_eq!(y, Z::from_str_radix("9865001296666956406909", 10).unwrap());
    }
    /// Asserts that `solve_pell_aux(d, None)` equals `expected`.
    fn test_pell_aux_aux(d: i32, expected: &Solution) {
        let solution = solve_pell_aux(&Z::from(d), None);
        assert_eq!(&solution, expected);
    }
    /// Expected `solve_pell_aux` results for d = 1..=115; perfect squares give `NotExist`.
    #[test]
    #[allow(clippy::too_many_lines)]
    fn test_pell_aux() {
        use Solution::{Negative, NotExist, Positive};
        let list = [
            (1, NotExist),
            (2, Negative(Z::from(1), Z::from(1))),
            (3, Positive(Z::from(2), Z::from(1))),
            (4, NotExist),
            (5, Negative(Z::from(2), Z::from(1))),
            (6, Positive(Z::from(5), Z::from(2))),
            (7, Positive(Z::from(8), Z::from(3))),
            (8, Positive(Z::from(3), Z::from(1))),
            (9, NotExist),
            (10, Negative(Z::from(3), Z::from(1))),
            (11, Positive(Z::from(10), Z::from(3))),
            (12, Positive(Z::from(7), Z::from(2))),
            (13, Negative(Z::from(18), Z::from(5))),
            (14, Positive(Z::from(15), Z::from(4))),
            (15, Positive(Z::from(4), Z::from(1))),
            (16, NotExist),
            (17, Negative(Z::from(4), Z::from(1))),
            (18, Positive(Z::from(17), Z::from(4))),
            (19, Positive(Z::from(170), Z::from(39))),
            (20, Positive(Z::from(9), Z::from(2))),
            (21, Positive(Z::from(55), Z::from(12))),
            (22, Positive(Z::from(197), Z::from(42))),
            (23, Positive(Z::from(24), Z::from(5))),
            (24, Positive(Z::from(5), Z::from(1))),
            (25, NotExist),
            (26, Negative(Z::from(5), Z::from(1))),
            (27, Positive(Z::from(26), Z::from(5))),
            (28, Positive(Z::from(127), Z::from(24))),
            (29, Negative(Z::from(70), Z::from(13))),
            (30, Positive(Z::from(11), Z::from(2))),
            (31, Positive(Z::from(1520), Z::from(273))),
            (32, Positive(Z::from(17), Z::from(3))),
            (33, Positive(Z::from(23), Z::from(4))),
            (34, Positive(Z::from(35), Z::from(6))),
            (35, Positive(Z::from(6), Z::from(1))),
            (36, NotExist),
            (37, Negative(Z::from(6), Z::from(1))),
            (38, Positive(Z::from(37), Z::from(6))),
            (39, Positive(Z::from(25), Z::from(4))),
            (40, Positive(Z::from(19), Z::from(3))),
            (41, Negative(Z::from(32), Z::from(5))),
            (42, Positive(Z::from(13), Z::from(2))),
            (43, Positive(Z::from(3482), Z::from(531))),
            (44, Positive(Z::from(199), Z::from(30))),
            (45, Positive(Z::from(161), Z::from(24))),
            (46, Positive(Z::from(24335), Z::from(3588))),
            (47, Positive(Z::from(48), Z::from(7))),
            (48, Positive(Z::from(7), Z::from(1))),
            (49, NotExist),
            (50, Negative(Z::from(7), Z::from(1))),
            (51, Positive(Z::from(50), Z::from(7))),
            (52, Positive(Z::from(649), Z::from(90))),
            (53, Negative(Z::from(182), Z::from(25))),
            (54, Positive(Z::from(485), Z::from(66))),
            (55, Positive(Z::from(89), Z::from(12))),
            (56, Positive(Z::from(15), Z::from(2))),
            (57, Positive(Z::from(151), Z::from(20))),
            (58, Negative(Z::from(99), Z::from(13))),
            (59, Positive(Z::from(530), Z::from(69))),
            (60, Positive(Z::from(31), Z::from(4))),
            (61, Negative(Z::from(29718), Z::from(3805))),
            (62, Positive(Z::from(63), Z::from(8))),
            (63, Positive(Z::from(8), Z::from(1))),
            (64, NotExist),
            (65, Negative(Z::from(8), Z::from(1))),
            (66, Positive(Z::from(65), Z::from(8))),
            (67, Positive(Z::from(48842), Z::from(5967))),
            (68, Positive(Z::from(33), Z::from(4))),
            (69, Positive(Z::from(7775), Z::from(936))),
            (70, Positive(Z::from(251), Z::from(30))),
            (71, Positive(Z::from(3480), Z::from(413))),
            (72, Positive(Z::from(17), Z::from(2))),
            (73, Negative(Z::from(1068), Z::from(125))),
            (74, Negative(Z::from(43), Z::from(5))),
            (75, Positive(Z::from(26), Z::from(3))),
            (76, Positive(Z::from(57799), Z::from(6630))),
            (77, Positive(Z::from(351), Z::from(40))),
            (78, Positive(Z::from(53), Z::from(6))),
            (79, Positive(Z::from(80), Z::from(9))),
            (80, Positive(Z::from(9), Z::from(1))),
            (81, NotExist),
            (82, Negative(Z::from(9), Z::from(1))),
            (83, Positive(Z::from(82), Z::from(9))),
            (84, Positive(Z::from(55), Z::from(6))),
            (85, Negative(Z::from(378), Z::from(41))),
            (86, Positive(Z::from(10405), Z::from(1122))),
            (87, Positive(Z::from(28), Z::from(3))),
            (88, Positive(Z::from(197), Z::from(21))),
            (89, Negative(Z::from(500), Z::from(53))),
            (90, Positive(Z::from(19), Z::from(2))),
            (91, Positive(Z::from(1574), Z::from(165))),
            (92, Positive(Z::from(1151), Z::from(120))),
            (93, Positive(Z::from(12151), Z::from(1260))),
            (94, Positive(Z::from(2_143_295), Z::from(221_064))),
            (95, Positive(Z::from(39), Z::from(4))),
            (96, Positive(Z::from(49), Z::from(5))),
            (97, Negative(Z::from(5604), Z::from(569))),
            (98, Positive(Z::from(99), Z::from(10))),
            (99, Positive(Z::from(10), Z::from(1))),
            (100, NotExist),
            (101, Negative(Z::from(10), Z::from(1))),
            (102, Positive(Z::from(101), Z::from(10))),
            (103, Positive(Z::from(227_528), Z::from(22419))),
            (104, Positive(Z::from(51), Z::from(5))),
            (105, Positive(Z::from(41), Z::from(4))),
            (106, Negative(Z::from(4005), Z::from(389))),
            (107, Positive(Z::from(962), Z::from(93))),
            (108, Positive(Z::from(1351), Z::from(130))),
            (109, Negative(Z::from(8_890_182), Z::from(851_525))),
            (110, Positive(Z::from(21), Z::from(2))),
            (111, Positive(Z::from(295), Z::from(28))),
            (112, Positive(Z::from(127), Z::from(12))),
            (113, Negative(Z::from(776), Z::from(73))),
            (114, Positive(Z::from(1025), Z::from(96))),
            (115, Positive(Z::from(1126), Z::from(105))),
        ];
        for (d, expected) in list {
            test_pell_aux_aux(d, &expected);
        }
    }
    /// Reference solver used to cross-check `solve_pell_aux`.
    ///
    /// It accumulates convergent denominators over the expansion returned by
    /// `ContinuedFractionIterator::new_cut_tail`, skipping the integer part.
    /// NOTE(review): `new_cut_tail` presumably yields one period without its final
    /// term; confirm in the `continued_fraction` module. `negative` flips once per
    /// term, so the parity of the term count picks the -1 or +1 equation.
    fn solve_pell_rcf(d: &Z, limit: Option<u64>) -> Solution {
        if d.is_negative() || d.is_perfect_square() {
            return Solution::NotExist;
        }
        let (negative, b) = {
            let iter = ContinuedFractionIterator::new_cut_tail(d.clone())
                .unwrap()
                .skip(1);
            let mut b_old = Z::ZERO;
            let mut b_now = Z::ONE.clone();
            let mut negative = true;
            for ai in iter {
                // b_next = a_i * b_now + b_old, then shift (b_old, b_now) forward.
                b_old += &ai * &b_now;
                check_limit!(b_old, limit);
                std::mem::swap(&mut b_old, &mut b_now);
                negative ^= true;
            }
            (negative, b_now)
        };
        if negative {
            gen_x_neg!(b, d, limit)
        } else {
            gen_x_pos!(b, d, limit)
        }
    }
    /// Cross-checks `solve_pell_aux` against `solve_pell_rcf` for every d in
    /// 2..30000. Perfect squares are included; both must report `NotExist` for them.
    #[test]
    fn test_pell_aux_2() {
        for d in 2..30000 {
            let d = Z::from(d);
            let s2 = solve_pell_aux(&d, None);
            let s1 = solve_pell_rcf(&d, None);
            assert_eq!(s1, s2);
        }
    }
}