1use crate::errors::AlkahestError;
16use crate::kernel::{Domain, ExprId, ExprPool};
17use crate::poly::groebner::ideal::GbPoly;
18use rug::ops::Pow;
19use rug::Integer;
20use std::collections::BTreeMap;
21use std::fmt;
22
23use super::{expr_to_gbpoly, SolverError};
24
/// Ways in which [`diophantine`] can fail.
#[derive(Clone, Debug)]
pub enum DiophantineError {
    /// The equation could not be turned into a polynomial in the requested
    /// variables; the payload is the underlying conversion message.
    NotPolynomial(String),
    /// At least one polynomial coefficient is not an integer.
    NonIntegerCoefficients,
    /// The equation's shape or size is outside what this solver handles; the
    /// payload describes why.
    Unsupported(String),
    /// No integer solution was found (e.g. the particular-solution search for a
    /// generalized Pell equation came up empty).
    NoSolution,
}
37
38impl fmt::Display for DiophantineError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 DiophantineError::NotPolynomial(s) => write!(f, "diophantine: {s}"),
42 DiophantineError::NonIntegerCoefficients => {
43 write!(f, "diophantine: coefficients must be rational integers")
44 }
45 DiophantineError::Unsupported(s) => write!(f, "diophantine: unsupported: {s}"),
46 DiophantineError::NoSolution => write!(f, "diophantine: no integer solution"),
47 }
48 }
49}
50
51impl std::error::Error for DiophantineError {}
52
53impl AlkahestError for DiophantineError {
54 fn code(&self) -> &'static str {
55 match self {
56 DiophantineError::NotPolynomial(_) => "E-DIOPH-001",
57 DiophantineError::NonIntegerCoefficients => "E-DIOPH-002",
58 DiophantineError::Unsupported(_) => "E-DIOPH-003",
59 DiophantineError::NoSolution => "E-DIOPH-004",
60 }
61 }
62
63 fn remediation(&self) -> Option<&'static str> {
64 match self {
65 DiophantineError::NotPolynomial(_) => Some(
66 "pass a single polynomial equation in the listed symbols with integer/rational coefficients",
67 ),
68 DiophantineError::NonIntegerCoefficients => Some(
69 "rewrite so all coefficients are integers (no fractional parameters)",
70 ),
71 DiophantineError::Unsupported(_) => Some(
72 "supported: linear two-variable, x²+y²=n, x²−D·y²=N (no xy term); huge integers may need a smaller instance",
73 ),
74 DiophantineError::NoSolution => Some(
75 "check divisibility for linear equations; for quadratics verify solvability over ℤ",
76 ),
77 }
78 }
79}
80
81impl From<SolverError> for DiophantineError {
82 fn from(e: SolverError) -> Self {
83 DiophantineError::NotPolynomial(e.to_string())
84 }
85}
86
87#[derive(Debug, Clone)]
89pub enum DiophantineSolution {
90 ParametricLinear {
93 parameter: ExprId,
94 values: Vec<ExprId>,
95 },
96 Finite(Vec<Vec<ExprId>>),
98 PellFundamental { d: ExprId, x0: ExprId, y0: ExprId },
101 PellGeneralized {
105 d: ExprId,
106 n: ExprId,
107 x0: ExprId,
108 y0: ExprId,
109 unit_x: ExprId,
110 unit_y: ExprId,
111 },
112 NoSolution,
114}
115
116fn lcm_rational_denominators(poly: &GbPoly) -> Integer {
117 let mut l = Integer::from(1);
118 for c in poly.terms.values() {
119 let den: Integer = c.denom().into();
120 l = l.lcm(&den);
121 }
122 l
123}
124
125fn gbpoly_integer_coeffs(poly: &GbPoly) -> Result<BTreeMap<Vec<u32>, Integer>, DiophantineError> {
126 let scale = lcm_rational_denominators(poly);
127 let mut out = BTreeMap::new();
128 for (e, c) in &poly.terms {
129 let num: Integer = c.numer().into();
130 let den: Integer = c.denom().into();
131 let prod = num * &scale;
132 let scaled = div_exact(&prod, &den).ok_or(DiophantineError::NonIntegerCoefficients)?;
133 if scaled != 0 {
134 out.insert(e.clone(), scaled);
135 }
136 }
137 Ok(out)
138}
139
140fn term_gcd(iv: &[Integer]) -> Integer {
141 let mut g = iv.first().cloned().unwrap_or_else(|| Integer::from(0));
142 for x in iv.iter().skip(1) {
143 g = g.gcd(x);
144 }
145 g
146}
147
148fn div_exact(a: &Integer, g: &Integer) -> Option<Integer> {
149 let (q, r) = a.clone().div_rem_euc_ref(g).into();
150 if r == 0 {
151 Some(q)
152 } else {
153 None
154 }
155}
156
157fn extended_gcd(a: &Integer, b: &Integer) -> (Integer, Integer, Integer) {
159 let mut old_r = a.clone();
160 let mut r = b.clone();
161 let mut old_s = Integer::from(1);
162 let mut s = Integer::from(0);
163 let mut old_t = Integer::from(0);
164 let mut t = Integer::from(1);
165 while r != 0 {
166 let q = old_r.clone() / &r;
167 let mut tmp = old_r - &q * &r;
168 old_r = r;
169 r = tmp;
170 tmp = old_s - &q * &s;
171 old_s = s;
172 s = tmp;
173 tmp = old_t - &q * &t;
174 old_t = t;
175 t = tmp;
176 }
177 (old_r, old_s, old_t)
178}
179
180fn compose_sum_sq(x: &Integer, y: &Integer, c: &Integer, d: &Integer) -> (Integer, Integer) {
182 let nx: Integer = x.clone() * c - y.clone() * d;
183 let ny: Integer = x.clone() * d + y.clone() * c;
184 (nx, ny)
185}
186
187fn is_perfect_square(n: &Integer) -> bool {
188 if n.cmp0().is_lt() {
189 return false;
190 }
191 let (_, r) = n.clone().sqrt_rem(Integer::new());
192 r == 0
193}
194
195fn legendre(a: &Integer, p: &Integer) -> i32 {
197 let exp = (p.clone() - 1) / 2;
198 let ls = a
199 .clone()
200 .pow_mod(&exp, p)
201 .unwrap_or_else(|_| Integer::from(0));
202 if ls == 1 {
203 1
204 } else if ls == p.clone() - 1 {
205 -1
206 } else {
207 0
208 }
209}
210
/// Square root of `n` modulo `p` (Tonelli–Shanks).
///
/// Returns `Some(0)` when `p | n`, `None` when [`legendre`] does not report `n` as a
/// quadratic residue (or a `pow_mod` step fails), and otherwise some `r` with
/// `r² ≡ n (mod p)`; which of the two roots is not specified.
/// NOTE(review): assumes `p` is an odd prime; for composite `p` the output is meaningless.
fn tonelli_shanks(n: &Integer, p: &Integer) -> Option<Integer> {
    let (_, rrem) = n.clone().div_rem_euc_ref(p).into();
    if rrem == 0 {
        return Some(Integer::from(0));
    }
    if legendre(n, p) != 1 {
        return None;
    }
    // p ≡ 3 (mod 4): the root is n^((p+1)/4) directly.
    if p.clone() % 4u32 == 3 {
        let exp = (p.clone() + 1) / 4;
        return n.clone().pow_mod(&exp, p).ok();
    }

    // Write p − 1 = q·2^s with q odd.
    let mut q: Integer = p.clone() - Integer::from(1);
    let mut s = 0u32;
    while q.clone() % 2u32 == 0 {
        q /= 2u32;
        s += 1;
    }

    // Smallest quadratic non-residue z >= 2.
    let mut z = Integer::from(2);
    while legendre(&z, p) != -1 {
        z += 1;
        if z >= *p {
            return None;
        }
    }

    // Loop invariants: r² ≡ n·t (mod p) and c has multiplicative order 2^m.
    let mut m = s;
    let mut c = z.clone().pow_mod(&q, p).ok()?;
    let mut t = n.clone().pow_mod(&q, p).ok()?;
    let mut r = n.clone().pow_mod(&((q.clone() + 1) / 2), p).ok()?;

    while t != 1 {
        // Least i with t^(2^i) ≡ 1 (mod p).
        let mut i = 0u32;
        let mut tt = t.clone();
        while tt != 1 {
            tt = (tt.clone() * &tt) % p;
            i += 1;
            if i > m {
                return None;
            }
        }
        // NOTE(review): i == m would underflow `m - i - 1`; for a residue modulo a
        // prime, i < m always holds.
        let exp = m - i - 1;
        let two_exp = Integer::from(1) << exp;
        let b = c.clone().pow_mod(&two_exp, p).ok()?;
        r = (r.clone() * &b) % p;
        t = (t * &b * &b) % p;
        c = (b.clone() * &b) % p;
        m = i;
    }
    Some(r)
}
265
/// Cornacchia's algorithm: find `(x, y)` with `x² + d·y² = p`, `x, y >= 0`.
///
/// `p = 2` is special-cased and only answered for `d = 1` (`1² + 1² = 2`); the only
/// caller in this file passes `d = 1`. Returns `None` for other even `p`, when `−d` is
/// not a quadratic residue mod `p`, or when the final square check fails.
/// NOTE(review): assumes `p` is prime and `0 < d < p`; neither is checked here.
fn cornacchia_prime(d: &Integer, p: &Integer) -> Option<(Integer, Integer)> {
    if *p == 2 {
        if *d == 1 {
            return Some((Integer::from(1), Integer::from(1)));
        }
        return None;
    }
    if p.clone() % 2 == 0 {
        return None;
    }

    // −d reduced into [0, p) (for non-negative d).
    let negd = (p.clone() - (d.clone() % p)) % p;
    // A solution needs −d to be a square modulo p.
    if legendre(&negd, p) != 1 {
        return None;
    }

    // r0² ≡ −d (mod p); use the representative ≤ p/2 (either root gives the same
    // Euclidean remainder tail).
    let mut r0 = tonelli_shanks(&negd, p)?;
    if r0.clone() > p.clone() / 2 {
        r0 = p.clone() - &r0;
    }

    // Euclid on (p, r0) until the remainder s satisfies s² <= p.
    let mut r = p.clone();
    let mut s = r0;
    while s.clone() * &s > *p {
        let rem = r.clone() % &s;
        r = s;
        s = rem;
    }

    // Accept x = s only if (p − s²)/d is an exact perfect square y².
    let diff = p.clone() - &s * &s;
    if diff.cmp0().is_lt() {
        return None;
    }
    let q = div_exact(&diff, d)?;
    let (_, rr) = q.clone().sqrt_rem(Integer::new());
    if rr != 0 {
        return None;
    }
    let y = q.sqrt();
    Some((s, y))
}
310
311fn prime_as_sum_two_squares(p: &Integer) -> Option<(Integer, Integer)> {
313 cornacchia_prime(&Integer::from(1), p)
314}
315
316fn pollard_step(g: &Integer, c: &Integer, x: &Integer) -> Integer {
317 (x.clone() * x + c) % g
318}
319
320fn pollard_rho_factor(n: &Integer) -> Option<Integer> {
322 if n <= &Integer::from(3) || is_probable_prime(n) {
323 return None;
324 }
325 let mut x = Integer::from(2);
326 let mut y = Integer::from(2);
327 let mut d = Integer::from(1);
328 let c = Integer::from(1);
329 while d == 1 {
330 x = pollard_step(n, &c, &x);
331 y = pollard_step(n, &c, &pollard_step(n, &c, &y));
332 let diff = if x.clone() >= y {
333 x.clone() - &y
334 } else {
335 y.clone() - &x
336 };
337 d = diff.gcd(n);
338 if d == *n {
339 return None;
340 }
341 }
342 if d > 1 && d < *n {
343 Some(d)
344 } else {
345 None
346 }
347}
348
349fn is_probable_prime(n: &Integer) -> bool {
351 if n <= &Integer::from(1) {
352 return false;
353 }
354 if n <= &Integer::from(3) {
355 return true;
356 }
357 if n.clone() % 2u32 == 0 {
358 return false;
359 }
360 n.is_probably_prime(40) != rug::integer::IsPrime::No
361}
362
363fn factor_positive(mut n: Integer) -> Vec<(Integer, u32)> {
365 let mut fac: Vec<(Integer, u32)> = Vec::new();
366
367 let push_pow = |fac: &mut Vec<(Integer, u32)>, p: Integer, e: u32| {
368 if e > 0 {
369 fac.push((p, e));
370 }
371 };
372
373 let small: [u32; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
374 for &pr in &small {
375 let p = Integer::from(pr);
376 if n <= 1 {
377 break;
378 }
379 let mut e = 0u32;
380 while n.clone() % &p == 0 {
381 n /= &p;
382 e += 1;
383 }
384 push_pow(&mut fac, p, e);
385 }
386
387 let mut stack: Vec<Integer> = Vec::new();
388 if n > 1 {
389 stack.push(n);
390 }
391 let mut prime_parts: Vec<Integer> = Vec::new();
392 while let Some(m) = stack.pop() {
393 if m <= 1 {
394 continue;
395 }
396 if is_probable_prime(&m) {
397 prime_parts.push(m);
398 continue;
399 }
400 let mut split = None;
401 for _ in 0..16 {
402 if let Some(d) = pollard_rho_factor(&m) {
403 let other = m.clone() / &d;
404 split = Some((d, other));
405 break;
406 }
407 }
408 if let Some((d, other)) = split {
409 stack.push(d);
410 stack.push(other);
411 } else {
412 prime_parts.push(m);
413 }
414 }
415
416 prime_parts.sort();
417 let mut i = 0usize;
418 while i < prime_parts.len() {
419 let p = prime_parts[i].clone();
420 let mut e = 0u32;
421 while i < prime_parts.len() && prime_parts[i] == p {
422 e += 1;
423 i += 1;
424 }
425 push_pow(&mut fac, p, e);
426 }
427
428 fac
429}
430
431fn scan_sum_two_squares_pairs(n: &Integer) -> Vec<(Integer, Integer)> {
432 let mut pts: Vec<(Integer, Integer)> = Vec::new();
433 let mut x = Integer::from(0);
434 let max_x = n.clone().sqrt();
435 while x <= max_x {
436 let r = n.clone() - &x * &x;
437 if is_perfect_square(&r) {
438 let y = r.sqrt();
439 if x <= y {
440 pts.push((x.clone(), y.clone()));
441 if x < y {
442 pts.push((y.clone(), x.clone()));
443 }
444 }
445 }
446 x += 1;
447 }
448 pts
449}
450
451fn merge_distinct_pairs(acc: &mut Vec<(Integer, Integer)>, more: Vec<(Integer, Integer)>) {
452 use std::collections::BTreeSet;
453 let mut seen: BTreeSet<String> = acc.iter().map(|(a, b)| format!("{a},{b}")).collect();
454 for (x, y) in more {
455 let k = format!("{x},{y}");
456 if seen.insert(k) {
457 acc.push((x, y));
458 }
459 }
460}
461
462fn sum_two_squares_representatives(n: &Integer) -> Vec<(Integer, Integer)> {
465 if n.cmp0().is_lt() {
466 return vec![];
467 }
468 if *n == 0 {
469 return vec![(Integer::from(0), Integer::from(0))];
470 }
471
472 if n.significant_bits() > 4000 {
473 return vec![];
474 }
475
476 let mut rest = n.clone();
477 let mut e2 = 0u32;
478 while rest.clone() % 2u32 == 0 {
479 rest /= 2u32;
480 e2 += 1;
481 }
482
483 if rest == 1 {
484 let mut x = Integer::from(1);
486 let mut y = Integer::from(0);
487 for _ in 0..e2 {
488 let c = compose_sum_sq(&x, &y, &Integer::from(1), &Integer::from(1));
489 x = c.0;
490 y = c.1;
491 }
492 return canonical_pairs(x, y);
493 }
494
495 let facs = factor_positive(rest);
496 for (p, e) in &facs {
497 let m4 = p.clone() % 4;
498 if m4 == 3 && e % 2 == 1 {
499 return vec![];
500 }
501 }
502
503 let mut xr = Integer::from(1);
504 let mut yr = Integer::from(0);
505 for (p, e) in facs {
506 let m4 = p.clone() % 4;
507 if m4 == 3 {
508 debug_assert!(e % 2 == 0);
509 let half = e / 2;
510 let pk = p.clone().pow(half);
511 xr *= &pk;
512 yr *= &pk;
513 continue;
514 }
515 if p == 2 {
516 for _ in 0..e {
517 let c = compose_sum_sq(&xr, &yr, &Integer::from(1), &Integer::from(1));
518 xr = c.0;
519 yr = c.1;
520 }
521 continue;
522 }
523 let (up, vp) = match prime_as_sum_two_squares(&p) {
525 Some(t) => t,
526 None => return vec![],
527 };
528 let mut xq = Integer::from(1);
529 let mut yq = Integer::from(0);
530 for _ in 0..e {
531 let c = compose_sum_sq(&xq, &yq, &up, &vp);
532 xq = c.0;
533 yq = c.1;
534 }
535 let c = compose_sum_sq(&xr, &yr, &xq, &yq);
536 xr = c.0;
537 yr = c.1;
538 }
539
540 for _ in 0..e2 {
541 let c = compose_sum_sq(&xr, &yr, &Integer::from(1), &Integer::from(1));
542 xr = c.0;
543 yr = c.1;
544 }
545
546 let mut out = canonical_pairs(xr, yr);
547 if n.significant_bits() <= 256 {
548 merge_distinct_pairs(&mut out, scan_sum_two_squares_pairs(n));
549 }
550 out
551}
552
553fn canonical_pairs(x: Integer, y: Integer) -> Vec<(Integer, Integer)> {
554 let x = x.abs();
555 let y = y.abs();
556 let mut pts = Vec::new();
557 if x <= y {
558 pts.push((x.clone(), y.clone()));
559 if x < y {
560 pts.push((y, x));
561 }
562 } else {
563 pts.push((y.clone(), x.clone()));
564 if y < x {
565 pts.push((x, y));
566 }
567 }
568 pts
569}
570
571fn solve_sum_two_squares_scan(pool: &ExprPool, n: &Integer) -> DiophantineSolution {
572 let n = n.clone();
573 if n < 0 {
574 return DiophantineSolution::NoSolution;
575 }
576 if n == 0 {
577 let z = pool.integer(0);
578 return DiophantineSolution::Finite(vec![vec![z, z]]);
579 }
580 let mut pts: Vec<(Integer, Integer)> = Vec::new();
581 let mut x = Integer::from(0);
582 let max_x = n.clone().sqrt();
583 while x <= max_x {
584 let r = n.clone() - &x * &x;
585 if is_perfect_square(&r) {
586 let y = r.sqrt();
587 if x <= y {
588 pts.push((x.clone(), y.clone()));
589 if x < y {
590 pts.push((y.clone(), x.clone()));
591 }
592 }
593 }
594 x += 1;
595 }
596 if pts.is_empty() {
597 return DiophantineSolution::NoSolution;
598 }
599 let sols: Vec<Vec<ExprId>> = pts
600 .into_iter()
601 .map(|(xi, yi)| vec![pool.integer(xi), pool.integer(yi)])
602 .collect();
603 DiophantineSolution::Finite(sols)
604}
605
606fn solve_sum_two_squares(
607 pool: &ExprPool,
608 _a: &Integer,
609 n: &Integer,
610 _vx: ExprId,
611 _vy: ExprId,
612) -> DiophantineSolution {
613 let rep = sum_two_squares_representatives(n);
614 if !rep.is_empty() {
615 let sols: Vec<Vec<ExprId>> = rep
616 .into_iter()
617 .map(|(xi, yi)| vec![pool.integer(xi), pool.integer(yi)])
618 .collect();
619 return DiophantineSolution::Finite(sols);
620 }
621 solve_sum_two_squares_scan(pool, n)
623}
624
625#[allow(clippy::too_many_arguments)]
627fn sqrt_cf_step(
628 d: &Integer,
629 a0: &Integer,
630 m: &mut Integer,
631 d_cf: &mut Integer,
632 a: &mut Integer,
633 h_prev: &mut Integer,
634 k_prev: &mut Integer,
635 h: &mut Integer,
636 k: &mut Integer,
637) -> Option<()> {
638 *m = (&*d_cf * &*a - &*m).into();
639 let num = d.clone() - &*m * &*m;
640 *d_cf = div_exact(&num, d_cf)?;
641 if *d_cf == 0 {
642 return None;
643 }
644 let sum: Integer = (a0 + &*m).into();
645 *a = div_exact(&sum, d_cf)?;
646 let h_new: Integer = (&*a * &*h + &*h_prev).into();
647 let k_new: Integer = (&*a * &*k + &*k_prev).into();
648 *h_prev = h.clone();
649 *k_prev = k.clone();
650 *h = h_new;
651 *k = k_new;
652 Some(())
653}
654
655fn pell_norm(h: &Integer, k: &Integer, d: &Integer) -> Integer {
656 h.clone() * h - d.clone() * k * k
657}
658
659fn pell_fundamental_xy(d: &Integer) -> Option<(Integer, Integer)> {
661 pell_convergent_solution(d, &Integer::from(1))
662}
663
664fn pell_convergent_solution(d: &Integer, target: &Integer) -> Option<(Integer, Integer)> {
666 let d = d.clone();
667 if d <= 0 {
668 return None;
669 }
670 let (_, rem) = d.clone().sqrt_rem(Integer::new());
671 if rem == 0 {
672 return None;
673 }
674 let a0 = d.clone().sqrt();
675 let mut m = Integer::from(0);
676 let mut d_cf = Integer::from(1);
677 let mut a = a0.clone();
678
679 let mut h_prev = Integer::from(1);
680 let mut h = a0.clone();
681 let mut k_prev = Integer::from(0);
682 let mut k = Integer::from(1);
683
684 let max_steps = 500_000u64;
685 for _ in 0..max_steps {
686 let lhs = pell_norm(&h, &k, &d);
687 if lhs == *target {
688 return Some((h, k));
689 }
690 sqrt_cf_step(
691 &d,
692 &a0,
693 &mut m,
694 &mut d_cf,
695 &mut a,
696 &mut h_prev,
697 &mut k_prev,
698 &mut h,
699 &mut k,
700 )?;
701 }
702 None
703}
704
705fn pell_y_sweep(d: &Integer, target: &Integer) -> Option<(Integer, Integer)> {
707 let bound = Integer::from(2_000_000);
708 let mut y = Integer::from(0);
709 while y <= bound {
710 let rhs = target.clone() + d.clone() * &y * &y;
711 if rhs.cmp0().is_ge() && is_perfect_square(&rhs) {
712 let x = rhs.sqrt();
713 if pell_norm(&x, &y, d) == *target {
714 return Some((x, y));
715 }
716 }
717 y += 1;
718 }
719 None
720}
721
722fn solve_pell_like(
723 pool: &ExprPool,
724 pos: &Integer,
725 neg: &Integer,
726 rhs: &Integer,
727) -> Result<DiophantineSolution, DiophantineError> {
728 if *pos == 0 || *neg == 0 {
729 return Err(DiophantineError::Unsupported("degenerate quadratic".into()));
730 }
731 let g = pos.clone().gcd(neg).gcd(&rhs.clone().abs());
732 let p = div_exact(pos, &g).unwrap();
733 let nn = div_exact(neg, &g).unwrap();
734 let r = div_exact(rhs, &g).unwrap();
735 if r == 0 {
738 if let Some(s2) = div_exact(&nn, &p) {
740 if is_perfect_square(&s2) {
741 let s = s2.sqrt();
742 let t = pool.symbol("_t", Domain::Integer);
743 let x_e = pool.mul(vec![pool.integer(s), t]);
744 return Ok(DiophantineSolution::ParametricLinear {
745 parameter: t,
746 values: vec![x_e, t],
747 });
748 }
749 }
750 if let Some(t2) = div_exact(&p, &nn) {
751 if is_perfect_square(&t2) {
752 let tc = t2.sqrt();
753 let t = pool.symbol("_t", Domain::Integer);
754 let y_e = pool.mul(vec![pool.integer(tc), t]);
755 return Ok(DiophantineSolution::ParametricLinear {
756 parameter: t,
757 values: vec![t, y_e],
758 });
759 }
760 }
761 let z = pool.integer(0);
762 return Ok(DiophantineSolution::Finite(vec![vec![z, z]]));
763 }
764
765 let g2 = p.clone().gcd(&nn);
766 let (_, rem) = r.clone().div_rem_euc_ref(&g2).into();
767 if rem != 0 {
768 return Ok(DiophantineSolution::NoSolution);
769 }
770 let p2 = div_exact(&p, &g2).unwrap();
771 let n2 = div_exact(&nn, &g2).unwrap();
772 let r2 = div_exact(&r, &g2).unwrap();
773
774 if p2 != 1 {
775 return Err(DiophantineError::Unsupported(
776 "Pell-type equation must reduce to x² - D·y² = N (leading x² coefficient 1 after gcd)"
777 .into(),
778 ));
779 }
780
781 let (ux, uy) = match pell_fundamental_xy(&n2) {
782 Some(u) => u,
783 None => {
784 return Err(DiophantineError::Unsupported(
785 "no fundamental unit (D may be a perfect square)".into(),
786 ));
787 }
788 };
789
790 if r2 == 0 {
791 unreachable!("handled above");
792 }
793
794 if r2 == 1 {
795 return Ok(DiophantineSolution::PellFundamental {
796 d: pool.integer(n2),
797 x0: pool.integer(ux),
798 y0: pool.integer(uy),
799 });
800 }
801
802 let part = pell_convergent_solution(&n2, &r2)
803 .or_else(|| pell_y_sweep(&n2, &r2))
804 .ok_or(DiophantineError::NoSolution)?;
805
806 Ok(DiophantineSolution::PellGeneralized {
807 d: pool.integer(n2.clone()),
808 n: pool.integer(r2),
809 x0: pool.integer(part.0),
810 y0: pool.integer(part.1),
811 unit_x: pool.integer(ux),
812 unit_y: pool.integer(uy),
813 })
814}
815
816fn solve_linear_two_var(
817 pool: &ExprPool,
818 a: &Integer,
819 b: &Integer,
820 c: &Integer,
821 _vx: ExprId,
822 _vy: ExprId,
823) -> Result<DiophantineSolution, DiophantineError> {
824 let rhs = -c.clone();
825 let g = a.clone().gcd(b);
826 let (_, rem) = rhs.clone().div_rem_euc_ref(&g).into();
827 if rem != 0 {
828 return Ok(DiophantineSolution::NoSolution);
829 }
830 let (g0, u, v) = extended_gcd(a, b);
831 debug_assert_eq!(g0, g);
832 let a1 = div_exact(a, &g).unwrap();
833 let b1 = div_exact(b, &g).unwrap();
834 let rhs1 = div_exact(&rhs, &g).unwrap();
835 let x0 = &u * &rhs1;
836 let y0 = &v * &rhs1;
837 let t = pool.symbol("_t", Domain::Integer);
838 let bt = pool.mul(vec![pool.integer(b1.clone()), t]);
839 let neg_one = pool.integer(-1_i32);
840 let neg_at = pool.mul(vec![neg_one, pool.integer(a1.clone()), t]);
841 let xt = pool.add(vec![pool.integer(x0), bt]);
842 let yt = pool.add(vec![pool.integer(y0), neg_at]);
843 Ok(DiophantineSolution::ParametricLinear {
844 parameter: t,
845 values: vec![xt, yt],
846 })
847}
848
849fn classify_and_solve(
850 pool: &ExprPool,
851 terms: &BTreeMap<Vec<u32>, Integer>,
852 vars: &[ExprId],
853) -> Result<DiophantineSolution, DiophantineError> {
854 if vars.len() != 2 {
855 return Err(DiophantineError::Unsupported(
856 "exactly two variables are required".into(),
857 ));
858 }
859 let vx = vars[0];
860 let vy = vars[1];
861
862 let mut max_deg = 0u32;
863 for e in terms.keys() {
864 let tdeg: u32 = e.iter().sum();
865 max_deg = max_deg.max(tdeg);
866 }
867
868 if max_deg > 2 {
869 return Err(DiophantineError::Unsupported(
870 "degree > 2 is not supported".into(),
871 ));
872 }
873
874 if max_deg <= 1 {
875 let c00 = terms
876 .get(&vec![0, 0])
877 .cloned()
878 .unwrap_or_else(|| Integer::from(0));
879 let c10 = terms
880 .get(&vec![1, 0])
881 .cloned()
882 .unwrap_or_else(|| Integer::from(0));
883 let c01 = terms
884 .get(&vec![0, 1])
885 .cloned()
886 .unwrap_or_else(|| Integer::from(0));
887 if terms.len() > 3 {
888 return Err(DiophantineError::Unsupported(
889 "linear equation with unexpected monomials".into(),
890 ));
891 }
892 for e in terms.keys() {
893 let s: u32 = e.iter().sum();
894 if s > 1 {
895 return Err(DiophantineError::Unsupported(
896 "mixed-degree polynomial".into(),
897 ));
898 }
899 }
900 return solve_linear_two_var(pool, &c10, &c01, &c00, vx, vy);
901 }
902
903 let c20 = terms
904 .get(&vec![2, 0])
905 .cloned()
906 .unwrap_or_else(|| Integer::from(0));
907 let c11 = terms
908 .get(&vec![1, 1])
909 .cloned()
910 .unwrap_or_else(|| Integer::from(0));
911 let c02 = terms
912 .get(&vec![0, 2])
913 .cloned()
914 .unwrap_or_else(|| Integer::from(0));
915 let c10 = terms
916 .get(&vec![1, 0])
917 .cloned()
918 .unwrap_or_else(|| Integer::from(0));
919 let c01 = terms
920 .get(&vec![0, 1])
921 .cloned()
922 .unwrap_or_else(|| Integer::from(0));
923 let c00 = terms
924 .get(&vec![0, 0])
925 .cloned()
926 .unwrap_or_else(|| Integer::from(0));
927
928 if c10 != 0 || c01 != 0 || c11 != 0 {
929 return Err(DiophantineError::Unsupported(
930 "quadratic with linear or xy terms is not implemented".into(),
931 ));
932 }
933
934 let g_content = term_gcd(&[c20.clone(), c02.clone(), c00.clone()]);
935 if g_content == 0 {
936 return Err(DiophantineError::Unsupported("zero polynomial".into()));
937 }
938 let a2 = div_exact(&c20, &g_content).unwrap();
939 let b2 = div_exact(&c02, &g_content).unwrap();
940 let cc = div_exact(&c00, &g_content).unwrap();
941
942 if a2 == 0 && b2 == 0 {
943 return Err(DiophantineError::Unsupported("no quadratic terms".into()));
944 }
945
946 if (a2 > 0 && b2 > 0) || (a2 < 0 && b2 < 0) {
947 if a2 != b2 {
948 return Err(DiophantineError::Unsupported(
949 "x² and y² must have equal coefficients for the ellipse case".into(),
950 ));
951 }
952 let a_abs = a2.clone().abs();
953 let (_, rem) = cc.clone().div_rem_euc_ref(&a_abs).into();
954 if rem != 0 {
955 return Ok(DiophantineSolution::NoSolution);
956 }
957 let n = -cc / &a_abs;
958 return Ok(solve_sum_two_squares(pool, &a_abs, &n, vx, vy));
959 }
960
961 if (a2 > 0 && b2 < 0) || (a2 < 0 && b2 > 0) {
962 let pos = if a2 > 0 { a2.clone() } else { b2.clone().abs() };
963 let neg = if a2 > 0 {
964 b2.clone().abs()
965 } else {
966 a2.clone().abs()
967 };
968 let rhs = -cc;
969
970 if rhs == 0 {
971 let (_, remd) = neg.clone().sqrt_rem(Integer::new());
972 if remd != 0 {
973 let z = pool.integer(0);
974 return Ok(DiophantineSolution::Finite(vec![vec![z, z]]));
975 }
976 let s = neg.sqrt();
977 let t = pool.symbol("_t", Domain::Integer);
978 let st = pool.mul(vec![pool.integer(s), t]);
979 return Ok(DiophantineSolution::ParametricLinear {
980 parameter: t,
981 values: vec![st, t],
982 });
983 }
984
985 return solve_pell_like(pool, &pos, &neg, &rhs);
986 }
987
988 Err(DiophantineError::Unsupported(
989 "unrecognized binary quadratic shape".into(),
990 ))
991}
992
993pub fn diophantine(
995 pool: &ExprPool,
996 equation: ExprId,
997 vars: &[ExprId],
998) -> Result<DiophantineSolution, DiophantineError> {
999 if vars.len() != 2 {
1000 return Err(DiophantineError::Unsupported(
1001 "exactly two variables are required".into(),
1002 ));
1003 }
1004 let poly = expr_to_gbpoly(equation, vars, pool)?;
1005 let int_terms = gbpoly_integer_coeffs(&poly)?;
1006 for c in poly.terms.values() {
1007 if !c.is_integer() {
1008 return Err(DiophantineError::NonIntegerCoefficients);
1009 }
1010 }
1011 classify_and_solve(pool, &int_terms, vars)
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{ExprData, ExprPool};

    /// A fresh pool together with integer-domain symbols `x` and `y`.
    fn pool_with_xy() -> (ExprPool, ExprId, ExprId) {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Integer);
        let y = pool.symbol("y", Domain::Integer);
        (pool, x, y)
    }

    /// Value of an integer leaf as `i32`; panics for any other node.
    fn int_leaf(pool: &ExprPool, id: ExprId) -> i32 {
        match pool.get(id) {
            ExprData::Integer(i) => i.0.to_i32().unwrap(),
            _ => panic!(),
        }
    }

    /// The expression `x² + y² − n`.
    fn sum_of_squares_minus(pool: &ExprPool, x: ExprId, y: ExprId, n: i32) -> ExprId {
        pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.pow(y, pool.integer(2)),
            pool.integer(-n),
        ])
    }

    #[test]
    fn linear_3x_5y_1() {
        // 3x + 5y − 1 = 0 has a one-parameter family of solutions.
        let (pool, x, y) = pool_with_xy();
        let eq = pool.add(vec![
            pool.mul(vec![pool.integer(3), x]),
            pool.mul(vec![pool.integer(5), y]),
            pool.integer(-1),
        ]);
        let solution = diophantine(&pool, eq, &[x, y]).unwrap();
        assert!(
            matches!(solution, DiophantineSolution::ParametricLinear { .. }),
            "expected parametric linear"
        );
    }

    #[test]
    fn pell_x2_2y2_1() {
        // x² − 2y² = 1 has fundamental solution (3, 2).
        let (pool, x, y) = pool_with_xy();
        let eq = pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.mul(vec![pool.integer(-2), pool.pow(y, pool.integer(2))]),
            pool.integer(-1),
        ]);
        match diophantine(&pool, eq, &[x, y]).unwrap() {
            DiophantineSolution::PellFundamental { x0, y0, .. } => {
                assert_eq!(int_leaf(&pool, x0), 3);
                assert_eq!(int_leaf(&pool, y0), 2);
            }
            _ => panic!("expected Pell fundamental"),
        }
    }

    #[test]
    fn sum_squares_5() {
        // 5 = 1² + 2² = 2² + 1².
        let (pool, x, y) = pool_with_xy();
        let eq = sum_of_squares_minus(&pool, x, y, 5);
        match diophantine(&pool, eq, &[x, y]).unwrap() {
            DiophantineSolution::Finite(rows) => assert_eq!(rows.len(), 2),
            _ => panic!("expected finite set"),
        }
    }

    #[test]
    fn sum_squares_65_two_orbits() {
        // 65 = 1² + 8² = 4² + 7²: two essentially different representations.
        let (pool, x, y) = pool_with_xy();
        let eq = sum_of_squares_minus(&pool, x, y, 65);
        match diophantine(&pool, eq, &[x, y]).unwrap() {
            DiophantineSolution::Finite(rows) => {
                let found: std::collections::HashSet<(i32, i32)> = rows
                    .iter()
                    .map(|row| (int_leaf(&pool, row[0]), int_leaf(&pool, row[1])))
                    .collect();
                for expected in [(1, 8), (8, 1), (4, 7), (7, 4)] {
                    assert!(found.contains(&expected));
                }
            }
            _ => panic!("expected finite set"),
        }
    }

    #[test]
    fn pell_generalized_n_minus1() {
        // x² − 2y² = −1.
        let (pool, x, y) = pool_with_xy();
        let eq = pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.mul(vec![pool.integer(-2), pool.pow(y, pool.integer(2))]),
            pool.integer(1),
        ]);
        let r = diophantine(&pool, eq, &[x, y]).unwrap();
        assert!(
            matches!(
                r,
                DiophantineSolution::PellGeneralized { .. }
                    | DiophantineSolution::PellFundamental { .. }
            ),
            "expected Pell generalized or fundamental: {:?}",
            r
        );
    }

    #[test]
    fn linear_no_solution() {
        // gcd(2, 4) = 2 does not divide 1.
        let (pool, x, y) = pool_with_xy();
        let eq = pool.add(vec![
            pool.mul(vec![pool.integer(2), x]),
            pool.mul(vec![pool.integer(4), y]),
            pool.integer(1),
        ]);
        let r = diophantine(&pool, eq, &[x, y]).unwrap();
        assert!(matches!(r, DiophantineSolution::NoSolution));
    }

    #[test]
    fn cornacchia_prime_13() {
        let p = Integer::from(13);
        let (a, b) = prime_as_sum_two_squares(&p).unwrap();
        assert_eq!(Integer::from(&a * &a) + Integer::from(&b * &b), p);
    }
}