Skip to main content

alkahest_cas/solver/
diophantine.rs

1//! Diophantine equations — linear parametric families and binary quadratics.
2//!
3//! ## Sum of two squares
4//!
5//! For `x² + y² = n` with `n ≥ 0`, factor `n` and use **Cornacchia** on primes `p ≡ 1 (mod 4)`,
6//! then **compose** representations via the Brahmagupta–Fibonacci identity.
7//! When factorization is impractical (very large `n`), falls back to scanning `x ≤ √n`.
8//!
9//! ## Generalized Pell
10//!
11//! `x² - D·y² = N` with `D > 0` non-square: search **continued-fraction convergents** of `√D`,
12//! then a bounded `y`-sweep `N + D·y² = □`.  Solutions multiply by the unit `u² - D·v² = 1`.
//! `N = 0`: only `(0,0)` if `D` is non-square; if `D = s²`, the solutions fill the two lines `x = ±s·y`, of which the parametric line `x = s·t`, `y = t` is returned.
14
15use crate::errors::AlkahestError;
16use crate::kernel::{Domain, ExprId, ExprPool};
17use crate::poly::groebner::ideal::GbPoly;
18use rug::ops::Pow;
19use rug::Integer;
20use std::collections::BTreeMap;
21use std::fmt;
22
23use super::{expr_to_gbpoly, SolverError};
24
/// Errors returned by [`diophantine`].
#[derive(Debug, Clone)]
pub enum DiophantineError {
    /// The equation could not be read as a polynomial in the listed variables; carries the
    /// converter's message.
    NotPolynomial(String),
    /// Coefficients could not be brought to rational integers, even after clearing denominators.
    NonIntegerCoefficients,
    /// The degree or monomial pattern of the equation is outside what the solver handles.
    Unsupported(String),
    /// The instance has no integer solutions.
    NoSolution,
}
37
38impl fmt::Display for DiophantineError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            DiophantineError::NotPolynomial(s) => write!(f, "diophantine: {s}"),
42            DiophantineError::NonIntegerCoefficients => {
43                write!(f, "diophantine: coefficients must be rational integers")
44            }
45            DiophantineError::Unsupported(s) => write!(f, "diophantine: unsupported: {s}"),
46            DiophantineError::NoSolution => write!(f, "diophantine: no integer solution"),
47        }
48    }
49}
50
51impl std::error::Error for DiophantineError {}
52
53impl AlkahestError for DiophantineError {
54    fn code(&self) -> &'static str {
55        match self {
56            DiophantineError::NotPolynomial(_) => "E-DIOPH-001",
57            DiophantineError::NonIntegerCoefficients => "E-DIOPH-002",
58            DiophantineError::Unsupported(_) => "E-DIOPH-003",
59            DiophantineError::NoSolution => "E-DIOPH-004",
60        }
61    }
62
63    fn remediation(&self) -> Option<&'static str> {
64        match self {
65            DiophantineError::NotPolynomial(_) => Some(
66                "pass a single polynomial equation in the listed symbols with integer/rational coefficients",
67            ),
68            DiophantineError::NonIntegerCoefficients => Some(
69                "rewrite so all coefficients are integers (no fractional parameters)",
70            ),
71            DiophantineError::Unsupported(_) => Some(
72                "supported: linear two-variable, x²+y²=n, x²−D·y²=N (no xy term); huge integers may need a smaller instance",
73            ),
74            DiophantineError::NoSolution => Some(
75                "check divisibility for linear equations; for quadratics verify solvability over ℤ",
76            ),
77        }
78    }
79}
80
81impl From<SolverError> for DiophantineError {
82    fn from(e: SolverError) -> Self {
83        DiophantineError::NotPolynomial(e.to_string())
84    }
85}
86
87/// Result of [`diophantine`].
88#[derive(Debug, Clone)]
89pub enum DiophantineSolution {
90    /// `a·x + b·y + … = 0`: values are `x(t)`, `y(t)`, … in the same order as `vars`,
91    /// with integer parameter `t`.
92    ParametricLinear {
93        parameter: ExprId,
94        values: Vec<ExprId>,
95    },
96    /// Explicit list of integer tuples (each parallel to `vars`).
97    Finite(Vec<Vec<ExprId>>),
98    /// `x² - D·y² = 1`: fundamental unit `(x0, y0)`; all solutions via
99    /// `(x0 + y0√D)^k`, `k ∈ ℤ`.
100    PellFundamental { d: ExprId, x0: ExprId, y0: ExprId },
101    /// `x² - D·y² = N` with `N ≠ 1`: minimal found pair `(x0, y0)` and unit `(ux, uy)` with
102    /// `ux² - D·uy² = 1`.  All solutions satisfy
103    /// `x + y√D = (x0 + y0√D)·(ux + uy√D)^k`, `k ∈ ℤ`.
104    PellGeneralized {
105        d: ExprId,
106        n: ExprId,
107        x0: ExprId,
108        y0: ExprId,
109        unit_x: ExprId,
110        unit_y: ExprId,
111    },
112    /// No integer solutions.
113    NoSolution,
114}
115
116fn lcm_rational_denominators(poly: &GbPoly) -> Integer {
117    let mut l = Integer::from(1);
118    for c in poly.terms.values() {
119        let den: Integer = c.denom().into();
120        l = l.lcm(&den);
121    }
122    l
123}
124
125fn gbpoly_integer_coeffs(poly: &GbPoly) -> Result<BTreeMap<Vec<u32>, Integer>, DiophantineError> {
126    let scale = lcm_rational_denominators(poly);
127    let mut out = BTreeMap::new();
128    for (e, c) in &poly.terms {
129        let num: Integer = c.numer().into();
130        let den: Integer = c.denom().into();
131        let prod = num * &scale;
132        let scaled = div_exact(&prod, &den).ok_or(DiophantineError::NonIntegerCoefficients)?;
133        if scaled != 0 {
134            out.insert(e.clone(), scaled);
135        }
136    }
137    Ok(out)
138}
139
140fn term_gcd(iv: &[Integer]) -> Integer {
141    let mut g = iv.first().cloned().unwrap_or_else(|| Integer::from(0));
142    for x in iv.iter().skip(1) {
143        g = g.gcd(x);
144    }
145    g
146}
147
148fn div_exact(a: &Integer, g: &Integer) -> Option<Integer> {
149    let (q, r) = a.clone().div_rem_euc_ref(g).into();
150    if r == 0 {
151        Some(q)
152    } else {
153        None
154    }
155}
156
157/// Extended gcd: `(g, u, v)` with `u·a + v·b = g = gcd(a,b)`.
158fn extended_gcd(a: &Integer, b: &Integer) -> (Integer, Integer, Integer) {
159    let mut old_r = a.clone();
160    let mut r = b.clone();
161    let mut old_s = Integer::from(1);
162    let mut s = Integer::from(0);
163    let mut old_t = Integer::from(0);
164    let mut t = Integer::from(1);
165    while r != 0 {
166        let q = old_r.clone() / &r;
167        let mut tmp = old_r - &q * &r;
168        old_r = r;
169        r = tmp;
170        tmp = old_s - &q * &s;
171        old_s = s;
172        s = tmp;
173        tmp = old_t - &q * &t;
174        old_t = t;
175        t = tmp;
176    }
177    (old_r, old_s, old_t)
178}
179
180/// `(a²+b²)(c²+d²) = (ac−bd)² + (ad+bc)²`
181fn compose_sum_sq(x: &Integer, y: &Integer, c: &Integer, d: &Integer) -> (Integer, Integer) {
182    let nx: Integer = x.clone() * c - y.clone() * d;
183    let ny: Integer = x.clone() * d + y.clone() * c;
184    (nx, ny)
185}
186
187fn is_perfect_square(n: &Integer) -> bool {
188    if n.cmp0().is_lt() {
189        return false;
190    }
191    let (_, r) = n.clone().sqrt_rem(Integer::new());
192    r == 0
193}
194
195/// Legendre symbol (a / p) for odd prime p, a not divisible by p → ±1.
196fn legendre(a: &Integer, p: &Integer) -> i32 {
197    let exp = (p.clone() - 1) / 2;
198    let ls = a
199        .clone()
200        .pow_mod(&exp, p)
201        .unwrap_or_else(|_| Integer::from(0));
202    if ls == 1 {
203        1
204    } else if ls == p.clone() - 1 {
205        -1
206    } else {
207        0
208    }
209}
210
211/// Tonelli–Shanks: square root of `n` mod odd prime `p` (when it exists).
212fn tonelli_shanks(n: &Integer, p: &Integer) -> Option<Integer> {
213    let (_, rrem) = n.clone().div_rem_euc_ref(p).into();
214    if rrem == 0 {
215        return Some(Integer::from(0));
216    }
217    if legendre(n, p) != 1 {
218        return None;
219    }
220    if p.clone() % 4u32 == 3 {
221        let exp = (p.clone() + 1) / 4;
222        return n.clone().pow_mod(&exp, p).ok();
223    }
224
225    let mut q: Integer = p.clone() - Integer::from(1);
226    let mut s = 0u32;
227    while q.clone() % 2u32 == 0 {
228        q /= 2u32;
229        s += 1;
230    }
231
232    let mut z = Integer::from(2);
233    while legendre(&z, p) != -1 {
234        z += 1;
235        if z >= *p {
236            return None;
237        }
238    }
239
240    let mut m = s;
241    let mut c = z.clone().pow_mod(&q, p).ok()?;
242    let mut t = n.clone().pow_mod(&q, p).ok()?;
243    let mut r = n.clone().pow_mod(&((q.clone() + 1) / 2), p).ok()?;
244
245    while t != 1 {
246        let mut i = 0u32;
247        let mut tt = t.clone();
248        while tt != 1 {
249            tt = (tt.clone() * &tt) % p;
250            i += 1;
251            if i > m {
252                return None;
253            }
254        }
255        let exp = m - i - 1;
256        let two_exp = Integer::from(1) << exp;
257        let b = c.clone().pow_mod(&two_exp, p).ok()?;
258        r = (r.clone() * &b) % p;
259        t = (t * &b * &b) % p;
260        c = (b.clone() * &b) % p;
261        m = i;
262    }
263    Some(r)
264}
265
266/// Cornacchia: `x² + d·y² = p` for odd prime `p`, `gcd(d,p)=1`, `(−d/p)=1`.
267/// Returns `(x, y)` with `x, y ≥ 0`.
268fn cornacchia_prime(d: &Integer, p: &Integer) -> Option<(Integer, Integer)> {
269    if *p == 2 {
270        if *d == 1 {
271            return Some((Integer::from(1), Integer::from(1)));
272        }
273        return None;
274    }
275    if p.clone() % 2 == 0 {
276        return None;
277    }
278
279    // (−d / p) = 1
280    let negd = (p.clone() - (d.clone() % p)) % p;
281    if legendre(&negd, p) != 1 {
282        return None;
283    }
284
285    let mut r0 = tonelli_shanks(&negd, p)?;
286    if r0.clone() > p.clone() / 2 {
287        r0 = p.clone() - &r0;
288    }
289
290    let mut r = p.clone();
291    let mut s = r0;
292    while s.clone() * &s > *p {
293        let rem = r.clone() % &s;
294        r = s;
295        s = rem;
296    }
297
298    let diff = p.clone() - &s * &s;
299    if diff.cmp0().is_lt() {
300        return None;
301    }
302    let q = div_exact(&diff, d)?;
303    let (_, rr) = q.clone().sqrt_rem(Integer::new());
304    if rr != 0 {
305        return None;
306    }
307    let y = q.sqrt();
308    Some((s, y))
309}
310
311/// `x² + y² = p` for prime `p`.
312fn prime_as_sum_two_squares(p: &Integer) -> Option<(Integer, Integer)> {
313    cornacchia_prime(&Integer::from(1), p)
314}
315
316fn pollard_step(g: &Integer, c: &Integer, x: &Integer) -> Integer {
317    (x.clone() * x + c) % g
318}
319
320/// One nontrivial factor of composite `n` (not necessarily prime).
321fn pollard_rho_factor(n: &Integer) -> Option<Integer> {
322    if n <= &Integer::from(3) || is_probable_prime(n) {
323        return None;
324    }
325    let mut x = Integer::from(2);
326    let mut y = Integer::from(2);
327    let mut d = Integer::from(1);
328    let c = Integer::from(1);
329    while d == 1 {
330        x = pollard_step(n, &c, &x);
331        y = pollard_step(n, &c, &pollard_step(n, &c, &y));
332        let diff = if x.clone() >= y {
333            x.clone() - &y
334        } else {
335            y.clone() - &x
336        };
337        d = diff.gcd(n);
338        if d == *n {
339            return None;
340        }
341    }
342    if d > 1 && d < *n {
343        Some(d)
344    } else {
345        None
346    }
347}
348
349/// Deterministic probable-prime (Miller–Rabin with small bases) for odd `n > 2`.
350fn is_probable_prime(n: &Integer) -> bool {
351    if n <= &Integer::from(1) {
352        return false;
353    }
354    if n <= &Integer::from(3) {
355        return true;
356    }
357    if n.clone() % 2u32 == 0 {
358        return false;
359    }
360    n.is_probably_prime(40) != rug::integer::IsPrime::No
361}
362
363/// Distinct prime factors with multiplicity, `n ≥ 2`.
364fn factor_positive(mut n: Integer) -> Vec<(Integer, u32)> {
365    let mut fac: Vec<(Integer, u32)> = Vec::new();
366
367    let push_pow = |fac: &mut Vec<(Integer, u32)>, p: Integer, e: u32| {
368        if e > 0 {
369            fac.push((p, e));
370        }
371    };
372
373    let small: [u32; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
374    for &pr in &small {
375        let p = Integer::from(pr);
376        if n <= 1 {
377            break;
378        }
379        let mut e = 0u32;
380        while n.clone() % &p == 0 {
381            n /= &p;
382            e += 1;
383        }
384        push_pow(&mut fac, p, e);
385    }
386
387    let mut stack: Vec<Integer> = Vec::new();
388    if n > 1 {
389        stack.push(n);
390    }
391    let mut prime_parts: Vec<Integer> = Vec::new();
392    while let Some(m) = stack.pop() {
393        if m <= 1 {
394            continue;
395        }
396        if is_probable_prime(&m) {
397            prime_parts.push(m);
398            continue;
399        }
400        let mut split = None;
401        for _ in 0..16 {
402            if let Some(d) = pollard_rho_factor(&m) {
403                let other = m.clone() / &d;
404                split = Some((d, other));
405                break;
406            }
407        }
408        if let Some((d, other)) = split {
409            stack.push(d);
410            stack.push(other);
411        } else {
412            prime_parts.push(m);
413        }
414    }
415
416    prime_parts.sort();
417    let mut i = 0usize;
418    while i < prime_parts.len() {
419        let p = prime_parts[i].clone();
420        let mut e = 0u32;
421        while i < prime_parts.len() && prime_parts[i] == p {
422            e += 1;
423            i += 1;
424        }
425        push_pow(&mut fac, p, e);
426    }
427
428    fac
429}
430
431fn scan_sum_two_squares_pairs(n: &Integer) -> Vec<(Integer, Integer)> {
432    let mut pts: Vec<(Integer, Integer)> = Vec::new();
433    let mut x = Integer::from(0);
434    let max_x = n.clone().sqrt();
435    while x <= max_x {
436        let r = n.clone() - &x * &x;
437        if is_perfect_square(&r) {
438            let y = r.sqrt();
439            if x <= y {
440                pts.push((x.clone(), y.clone()));
441                if x < y {
442                    pts.push((y.clone(), x.clone()));
443                }
444            }
445        }
446        x += 1;
447    }
448    pts
449}
450
451fn merge_distinct_pairs(acc: &mut Vec<(Integer, Integer)>, more: Vec<(Integer, Integer)>) {
452    use std::collections::BTreeSet;
453    let mut seen: BTreeSet<String> = acc.iter().map(|(a, b)| format!("{a},{b}")).collect();
454    for (x, y) in more {
455        let k = format!("{x},{y}");
456        if seen.insert(k) {
457            acc.push((x, y));
458        }
459    }
460}
461
462/// Ordered pairs `(x,y)` with `x,y ≥ 0` and `x² + y² = n`: one orbit from Cornacchia composition,
463/// plus any further orbits found by a bounded scan when `n` is moderate (bit size ≤ 256).
464fn sum_two_squares_representatives(n: &Integer) -> Vec<(Integer, Integer)> {
465    if n.cmp0().is_lt() {
466        return vec![];
467    }
468    if *n == 0 {
469        return vec![(Integer::from(0), Integer::from(0))];
470    }
471
472    if n.significant_bits() > 4000 {
473        return vec![];
474    }
475
476    let mut rest = n.clone();
477    let mut e2 = 0u32;
478    while rest.clone() % 2u32 == 0 {
479        rest /= 2u32;
480        e2 += 1;
481    }
482
483    if rest == 1 {
484        // n = 2^e2
485        let mut x = Integer::from(1);
486        let mut y = Integer::from(0);
487        for _ in 0..e2 {
488            let c = compose_sum_sq(&x, &y, &Integer::from(1), &Integer::from(1));
489            x = c.0;
490            y = c.1;
491        }
492        return canonical_pairs(x, y);
493    }
494
495    let facs = factor_positive(rest);
496    for (p, e) in &facs {
497        let m4 = p.clone() % 4;
498        if m4 == 3 && e % 2 == 1 {
499            return vec![];
500        }
501    }
502
503    let mut xr = Integer::from(1);
504    let mut yr = Integer::from(0);
505    for (p, e) in facs {
506        let m4 = p.clone() % 4;
507        if m4 == 3 {
508            debug_assert!(e % 2 == 0);
509            let half = e / 2;
510            let pk = p.clone().pow(half);
511            xr *= &pk;
512            yr *= &pk;
513            continue;
514        }
515        if p == 2 {
516            for _ in 0..e {
517                let c = compose_sum_sq(&xr, &yr, &Integer::from(1), &Integer::from(1));
518                xr = c.0;
519                yr = c.1;
520            }
521            continue;
522        }
523        // p ≡ 1 (mod 4)
524        let (up, vp) = match prime_as_sum_two_squares(&p) {
525            Some(t) => t,
526            None => return vec![],
527        };
528        let mut xq = Integer::from(1);
529        let mut yq = Integer::from(0);
530        for _ in 0..e {
531            let c = compose_sum_sq(&xq, &yq, &up, &vp);
532            xq = c.0;
533            yq = c.1;
534        }
535        let c = compose_sum_sq(&xr, &yr, &xq, &yq);
536        xr = c.0;
537        yr = c.1;
538    }
539
540    for _ in 0..e2 {
541        let c = compose_sum_sq(&xr, &yr, &Integer::from(1), &Integer::from(1));
542        xr = c.0;
543        yr = c.1;
544    }
545
546    let mut out = canonical_pairs(xr, yr);
547    if n.significant_bits() <= 256 {
548        merge_distinct_pairs(&mut out, scan_sum_two_squares_pairs(n));
549    }
550    out
551}
552
553fn canonical_pairs(x: Integer, y: Integer) -> Vec<(Integer, Integer)> {
554    let x = x.abs();
555    let y = y.abs();
556    let mut pts = Vec::new();
557    if x <= y {
558        pts.push((x.clone(), y.clone()));
559        if x < y {
560            pts.push((y, x));
561        }
562    } else {
563        pts.push((y.clone(), x.clone()));
564        if y < x {
565            pts.push((x, y));
566        }
567    }
568    pts
569}
570
571fn solve_sum_two_squares_scan(pool: &ExprPool, n: &Integer) -> DiophantineSolution {
572    let n = n.clone();
573    if n < 0 {
574        return DiophantineSolution::NoSolution;
575    }
576    if n == 0 {
577        let z = pool.integer(0);
578        return DiophantineSolution::Finite(vec![vec![z, z]]);
579    }
580    let mut pts: Vec<(Integer, Integer)> = Vec::new();
581    let mut x = Integer::from(0);
582    let max_x = n.clone().sqrt();
583    while x <= max_x {
584        let r = n.clone() - &x * &x;
585        if is_perfect_square(&r) {
586            let y = r.sqrt();
587            if x <= y {
588                pts.push((x.clone(), y.clone()));
589                if x < y {
590                    pts.push((y.clone(), x.clone()));
591                }
592            }
593        }
594        x += 1;
595    }
596    if pts.is_empty() {
597        return DiophantineSolution::NoSolution;
598    }
599    let sols: Vec<Vec<ExprId>> = pts
600        .into_iter()
601        .map(|(xi, yi)| vec![pool.integer(xi), pool.integer(yi)])
602        .collect();
603    DiophantineSolution::Finite(sols)
604}
605
606fn solve_sum_two_squares(
607    pool: &ExprPool,
608    _a: &Integer,
609    n: &Integer,
610    _vx: ExprId,
611    _vy: ExprId,
612) -> DiophantineSolution {
613    let rep = sum_two_squares_representatives(n);
614    if !rep.is_empty() {
615        let sols: Vec<Vec<ExprId>> = rep
616            .into_iter()
617            .map(|(xi, yi)| vec![pool.integer(xi), pool.integer(yi)])
618            .collect();
619        return DiophantineSolution::Finite(sols);
620    }
621    // Fallback when factorization failed or n has no two-square representation.
622    solve_sum_two_squares_scan(pool, n)
623}
624
625/// One step of continued fraction for `√d`; updates `(h,k)` convergents.
626#[allow(clippy::too_many_arguments)]
627fn sqrt_cf_step(
628    d: &Integer,
629    a0: &Integer,
630    m: &mut Integer,
631    d_cf: &mut Integer,
632    a: &mut Integer,
633    h_prev: &mut Integer,
634    k_prev: &mut Integer,
635    h: &mut Integer,
636    k: &mut Integer,
637) -> Option<()> {
638    *m = (&*d_cf * &*a - &*m).into();
639    let num = d.clone() - &*m * &*m;
640    *d_cf = div_exact(&num, d_cf)?;
641    if *d_cf == 0 {
642        return None;
643    }
644    let sum: Integer = (a0 + &*m).into();
645    *a = div_exact(&sum, d_cf)?;
646    let h_new: Integer = (&*a * &*h + &*h_prev).into();
647    let k_new: Integer = (&*a * &*k + &*k_prev).into();
648    *h_prev = h.clone();
649    *k_prev = k.clone();
650    *h = h_new;
651    *k = k_new;
652    Some(())
653}
654
655fn pell_norm(h: &Integer, k: &Integer, d: &Integer) -> Integer {
656    h.clone() * h - d.clone() * k * k
657}
658
659/// Minimal positive solution to `x² - d·y² = 1` (`d` non-square), via convergents.
660fn pell_fundamental_xy(d: &Integer) -> Option<(Integer, Integer)> {
661    pell_convergent_solution(d, &Integer::from(1))
662}
663
664/// Some `(x, y)` with `x² - d·y² = target` if found among convergents or a bounded search.
665fn pell_convergent_solution(d: &Integer, target: &Integer) -> Option<(Integer, Integer)> {
666    let d = d.clone();
667    if d <= 0 {
668        return None;
669    }
670    let (_, rem) = d.clone().sqrt_rem(Integer::new());
671    if rem == 0 {
672        return None;
673    }
674    let a0 = d.clone().sqrt();
675    let mut m = Integer::from(0);
676    let mut d_cf = Integer::from(1);
677    let mut a = a0.clone();
678
679    let mut h_prev = Integer::from(1);
680    let mut h = a0.clone();
681    let mut k_prev = Integer::from(0);
682    let mut k = Integer::from(1);
683
684    let max_steps = 500_000u64;
685    for _ in 0..max_steps {
686        let lhs = pell_norm(&h, &k, &d);
687        if lhs == *target {
688            return Some((h, k));
689        }
690        sqrt_cf_step(
691            &d,
692            &a0,
693            &mut m,
694            &mut d_cf,
695            &mut a,
696            &mut h_prev,
697            &mut k_prev,
698            &mut h,
699            &mut k,
700        )?;
701    }
702    None
703}
704
705/// Try `x² = target + d·y²` for increasing `y`.
706fn pell_y_sweep(d: &Integer, target: &Integer) -> Option<(Integer, Integer)> {
707    let bound = Integer::from(2_000_000);
708    let mut y = Integer::from(0);
709    while y <= bound {
710        let rhs = target.clone() + d.clone() * &y * &y;
711        if rhs.cmp0().is_ge() && is_perfect_square(&rhs) {
712            let x = rhs.sqrt();
713            if pell_norm(&x, &y, d) == *target {
714                return Some((x, y));
715            }
716        }
717        y += 1;
718    }
719    None
720}
721
722fn solve_pell_like(
723    pool: &ExprPool,
724    pos: &Integer,
725    neg: &Integer,
726    rhs: &Integer,
727) -> Result<DiophantineSolution, DiophantineError> {
728    if *pos == 0 || *neg == 0 {
729        return Err(DiophantineError::Unsupported("degenerate quadratic".into()));
730    }
731    let g = pos.clone().gcd(neg).gcd(&rhs.clone().abs());
732    let p = div_exact(pos, &g).unwrap();
733    let nn = div_exact(neg, &g).unwrap();
734    let r = div_exact(rhs, &g).unwrap();
735    // p·X² - nn·Y² = r
736
737    if r == 0 {
738        // p·X² = nn·Y²: if nn/p or p/nn is a perfect square, parametrize; else only (0,0).
739        if let Some(s2) = div_exact(&nn, &p) {
740            if is_perfect_square(&s2) {
741                let s = s2.sqrt();
742                let t = pool.symbol("_t", Domain::Integer);
743                let x_e = pool.mul(vec![pool.integer(s), t]);
744                return Ok(DiophantineSolution::ParametricLinear {
745                    parameter: t,
746                    values: vec![x_e, t],
747                });
748            }
749        }
750        if let Some(t2) = div_exact(&p, &nn) {
751            if is_perfect_square(&t2) {
752                let tc = t2.sqrt();
753                let t = pool.symbol("_t", Domain::Integer);
754                let y_e = pool.mul(vec![pool.integer(tc), t]);
755                return Ok(DiophantineSolution::ParametricLinear {
756                    parameter: t,
757                    values: vec![t, y_e],
758                });
759            }
760        }
761        let z = pool.integer(0);
762        return Ok(DiophantineSolution::Finite(vec![vec![z, z]]));
763    }
764
765    let g2 = p.clone().gcd(&nn);
766    let (_, rem) = r.clone().div_rem_euc_ref(&g2).into();
767    if rem != 0 {
768        return Ok(DiophantineSolution::NoSolution);
769    }
770    let p2 = div_exact(&p, &g2).unwrap();
771    let n2 = div_exact(&nn, &g2).unwrap();
772    let r2 = div_exact(&r, &g2).unwrap();
773
774    if p2 != 1 {
775        return Err(DiophantineError::Unsupported(
776            "Pell-type equation must reduce to x² - D·y² = N (leading x² coefficient 1 after gcd)"
777                .into(),
778        ));
779    }
780
781    let (ux, uy) = match pell_fundamental_xy(&n2) {
782        Some(u) => u,
783        None => {
784            return Err(DiophantineError::Unsupported(
785                "no fundamental unit (D may be a perfect square)".into(),
786            ));
787        }
788    };
789
790    if r2 == 0 {
791        unreachable!("handled above");
792    }
793
794    if r2 == 1 {
795        return Ok(DiophantineSolution::PellFundamental {
796            d: pool.integer(n2),
797            x0: pool.integer(ux),
798            y0: pool.integer(uy),
799        });
800    }
801
802    let part = pell_convergent_solution(&n2, &r2)
803        .or_else(|| pell_y_sweep(&n2, &r2))
804        .ok_or(DiophantineError::NoSolution)?;
805
806    Ok(DiophantineSolution::PellGeneralized {
807        d: pool.integer(n2.clone()),
808        n: pool.integer(r2),
809        x0: pool.integer(part.0),
810        y0: pool.integer(part.1),
811        unit_x: pool.integer(ux),
812        unit_y: pool.integer(uy),
813    })
814}
815
816fn solve_linear_two_var(
817    pool: &ExprPool,
818    a: &Integer,
819    b: &Integer,
820    c: &Integer,
821    _vx: ExprId,
822    _vy: ExprId,
823) -> Result<DiophantineSolution, DiophantineError> {
824    let rhs = -c.clone();
825    let g = a.clone().gcd(b);
826    let (_, rem) = rhs.clone().div_rem_euc_ref(&g).into();
827    if rem != 0 {
828        return Ok(DiophantineSolution::NoSolution);
829    }
830    let (g0, u, v) = extended_gcd(a, b);
831    debug_assert_eq!(g0, g);
832    let a1 = div_exact(a, &g).unwrap();
833    let b1 = div_exact(b, &g).unwrap();
834    let rhs1 = div_exact(&rhs, &g).unwrap();
835    let x0 = &u * &rhs1;
836    let y0 = &v * &rhs1;
837    let t = pool.symbol("_t", Domain::Integer);
838    let bt = pool.mul(vec![pool.integer(b1.clone()), t]);
839    let neg_one = pool.integer(-1_i32);
840    let neg_at = pool.mul(vec![neg_one, pool.integer(a1.clone()), t]);
841    let xt = pool.add(vec![pool.integer(x0), bt]);
842    let yt = pool.add(vec![pool.integer(y0), neg_at]);
843    Ok(DiophantineSolution::ParametricLinear {
844        parameter: t,
845        values: vec![xt, yt],
846    })
847}
848
849fn classify_and_solve(
850    pool: &ExprPool,
851    terms: &BTreeMap<Vec<u32>, Integer>,
852    vars: &[ExprId],
853) -> Result<DiophantineSolution, DiophantineError> {
854    if vars.len() != 2 {
855        return Err(DiophantineError::Unsupported(
856            "exactly two variables are required".into(),
857        ));
858    }
859    let vx = vars[0];
860    let vy = vars[1];
861
862    let mut max_deg = 0u32;
863    for e in terms.keys() {
864        let tdeg: u32 = e.iter().sum();
865        max_deg = max_deg.max(tdeg);
866    }
867
868    if max_deg > 2 {
869        return Err(DiophantineError::Unsupported(
870            "degree > 2 is not supported".into(),
871        ));
872    }
873
874    if max_deg <= 1 {
875        let c00 = terms
876            .get(&vec![0, 0])
877            .cloned()
878            .unwrap_or_else(|| Integer::from(0));
879        let c10 = terms
880            .get(&vec![1, 0])
881            .cloned()
882            .unwrap_or_else(|| Integer::from(0));
883        let c01 = terms
884            .get(&vec![0, 1])
885            .cloned()
886            .unwrap_or_else(|| Integer::from(0));
887        if terms.len() > 3 {
888            return Err(DiophantineError::Unsupported(
889                "linear equation with unexpected monomials".into(),
890            ));
891        }
892        for e in terms.keys() {
893            let s: u32 = e.iter().sum();
894            if s > 1 {
895                return Err(DiophantineError::Unsupported(
896                    "mixed-degree polynomial".into(),
897                ));
898            }
899        }
900        return solve_linear_two_var(pool, &c10, &c01, &c00, vx, vy);
901    }
902
903    let c20 = terms
904        .get(&vec![2, 0])
905        .cloned()
906        .unwrap_or_else(|| Integer::from(0));
907    let c11 = terms
908        .get(&vec![1, 1])
909        .cloned()
910        .unwrap_or_else(|| Integer::from(0));
911    let c02 = terms
912        .get(&vec![0, 2])
913        .cloned()
914        .unwrap_or_else(|| Integer::from(0));
915    let c10 = terms
916        .get(&vec![1, 0])
917        .cloned()
918        .unwrap_or_else(|| Integer::from(0));
919    let c01 = terms
920        .get(&vec![0, 1])
921        .cloned()
922        .unwrap_or_else(|| Integer::from(0));
923    let c00 = terms
924        .get(&vec![0, 0])
925        .cloned()
926        .unwrap_or_else(|| Integer::from(0));
927
928    if c10 != 0 || c01 != 0 || c11 != 0 {
929        return Err(DiophantineError::Unsupported(
930            "quadratic with linear or xy terms is not implemented".into(),
931        ));
932    }
933
934    let g_content = term_gcd(&[c20.clone(), c02.clone(), c00.clone()]);
935    if g_content == 0 {
936        return Err(DiophantineError::Unsupported("zero polynomial".into()));
937    }
938    let a2 = div_exact(&c20, &g_content).unwrap();
939    let b2 = div_exact(&c02, &g_content).unwrap();
940    let cc = div_exact(&c00, &g_content).unwrap();
941
942    if a2 == 0 && b2 == 0 {
943        return Err(DiophantineError::Unsupported("no quadratic terms".into()));
944    }
945
946    if (a2 > 0 && b2 > 0) || (a2 < 0 && b2 < 0) {
947        if a2 != b2 {
948            return Err(DiophantineError::Unsupported(
949                "x² and y² must have equal coefficients for the ellipse case".into(),
950            ));
951        }
952        let a_abs = a2.clone().abs();
953        let (_, rem) = cc.clone().div_rem_euc_ref(&a_abs).into();
954        if rem != 0 {
955            return Ok(DiophantineSolution::NoSolution);
956        }
957        let n = -cc / &a_abs;
958        return Ok(solve_sum_two_squares(pool, &a_abs, &n, vx, vy));
959    }
960
961    if (a2 > 0 && b2 < 0) || (a2 < 0 && b2 > 0) {
962        let pos = if a2 > 0 { a2.clone() } else { b2.clone().abs() };
963        let neg = if a2 > 0 {
964            b2.clone().abs()
965        } else {
966            a2.clone().abs()
967        };
968        let rhs = -cc;
969
970        if rhs == 0 {
971            let (_, remd) = neg.clone().sqrt_rem(Integer::new());
972            if remd != 0 {
973                let z = pool.integer(0);
974                return Ok(DiophantineSolution::Finite(vec![vec![z, z]]));
975            }
976            let s = neg.sqrt();
977            let t = pool.symbol("_t", Domain::Integer);
978            let st = pool.mul(vec![pool.integer(s), t]);
979            return Ok(DiophantineSolution::ParametricLinear {
980                parameter: t,
981                values: vec![st, t],
982            });
983        }
984
985        return solve_pell_like(pool, &pos, &neg, &rhs);
986    }
987
988    Err(DiophantineError::Unsupported(
989        "unrecognized binary quadratic shape".into(),
990    ))
991}
992
993/// Solve a single Diophantine equation in integer unknowns.
994pub fn diophantine(
995    pool: &ExprPool,
996    equation: ExprId,
997    vars: &[ExprId],
998) -> Result<DiophantineSolution, DiophantineError> {
999    if vars.len() != 2 {
1000        return Err(DiophantineError::Unsupported(
1001            "exactly two variables are required".into(),
1002        ));
1003    }
1004    let poly = expr_to_gbpoly(equation, vars, pool)?;
1005    let int_terms = gbpoly_integer_coeffs(&poly)?;
1006    for c in poly.terms.values() {
1007        if !c.is_integer() {
1008            return Err(DiophantineError::NonIntegerCoefficients);
1009        }
1010    }
1011    classify_and_solve(pool, &int_terms, vars)
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{ExprData, ExprPool};

    /// Create the two integer unknowns `x` and `y` used by every test.
    fn xy(pool: &ExprPool) -> (ExprId, ExprId) {
        (
            pool.symbol("x", Domain::Integer),
            pool.symbol("y", Domain::Integer),
        )
    }

    /// Read an integer leaf back out of the pool; panics on any other node kind.
    fn int_value(pool: &ExprPool, id: ExprId) -> i64 {
        match pool.get(id) {
            ExprData::Integer(i) => i.0.to_i64().expect("test integers fit in i64"),
            _ => panic!("expected an integer node"),
        }
    }

    /// Unwrap a `Finite` solution into `(x, y)` integer pairs; panics on any other variant.
    fn finite_pairs(pool: &ExprPool, sol: DiophantineSolution) -> Vec<(i64, i64)> {
        match sol {
            DiophantineSolution::Finite(rows) => rows
                .iter()
                .map(|row| (int_value(pool, row[0]), int_value(pool, row[1])))
                .collect(),
            other => panic!("expected finite set, got {other:?}"),
        }
    }

    #[test]
    fn linear_3x_5y_1() {
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![
            pool.mul(vec![pool.integer(3), x]),
            pool.mul(vec![pool.integer(5), y]),
            pool.integer(-1),
        ]);
        match diophantine(&pool, eq, &[x, y]).unwrap() {
            DiophantineSolution::ParametricLinear { values, .. } => {
                // One parametric expression per unknown.
                assert_eq!(values.len(), 2);
            }
            other => panic!("expected parametric linear, got {other:?}"),
        }
    }

    #[test]
    fn pell_x2_2y2_1() {
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let x2 = pool.pow(x, pool.integer(2));
        let y2 = pool.pow(y, pool.integer(2));
        let eq = pool.add(vec![
            x2,
            pool.mul(vec![pool.integer(-2), y2]),
            pool.integer(-1),
        ]);
        match diophantine(&pool, eq, &[x, y]).unwrap() {
            DiophantineSolution::PellFundamental { x0, y0, .. } => {
                assert_eq!(int_value(&pool, x0), 3);
                assert_eq!(int_value(&pool, y0), 2);
            }
            other => panic!("expected Pell fundamental, got {other:?}"),
        }
    }

    #[test]
    fn sum_squares_5() {
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.pow(y, pool.integer(2)),
            pool.integer(-5),
        ]);
        let pairs = finite_pairs(&pool, diophantine(&pool, eq, &[x, y]).unwrap());
        assert_eq!(pairs.len(), 2);
        // Every reported pair must actually solve the equation, not just be counted.
        for &(a, b) in &pairs {
            assert_eq!(a * a + b * b, 5, "({a}, {b}) does not satisfy x² + y² = 5");
        }
    }

    #[test]
    fn sum_squares_65_two_orbits() {
        // 65 = 1²+8² = 4²+7²
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.pow(y, pool.integer(2)),
            pool.integer(-65),
        ]);
        let pairs = finite_pairs(&pool, diophantine(&pool, eq, &[x, y]).unwrap());
        for &(a, b) in &pairs {
            assert_eq!(a * a + b * b, 65, "({a}, {b}) does not satisfy x² + y² = 65");
        }
        let sets: std::collections::HashSet<(i64, i64)> = pairs.into_iter().collect();
        assert!(sets.contains(&(1, 8)));
        assert!(sets.contains(&(8, 1)));
        assert!(sets.contains(&(4, 7)));
        assert!(sets.contains(&(7, 4)));
    }

    #[test]
    fn pell_generalized_n_minus1() {
        // x² - 2 y² = -1  →  (1,1) fundamental for negative Pell
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![
            pool.pow(x, pool.integer(2)),
            pool.mul(vec![pool.integer(-2), pool.pow(y, pool.integer(2))]),
            pool.integer(1),
        ]);
        let r = diophantine(&pool, eq, &[x, y]).unwrap();
        // Either variant is accepted: the unit path is an implementation detail.
        assert!(
            matches!(
                r,
                DiophantineSolution::PellGeneralized { .. }
                    | DiophantineSolution::PellFundamental { .. }
            ),
            "expected Pell generalized or fundamental: {:?}",
            r
        );
    }

    #[test]
    fn linear_no_solution() {
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![
            pool.mul(vec![pool.integer(2), x]),
            pool.mul(vec![pool.integer(4), y]),
            pool.integer(1),
        ]);
        let r = diophantine(&pool, eq, &[x, y]).unwrap();
        assert!(matches!(r, DiophantineSolution::NoSolution));
    }

    #[test]
    fn rejects_wrong_variable_count() {
        let pool = ExprPool::new();
        let (x, y) = xy(&pool);
        let eq = pool.add(vec![x, y, pool.integer(-1)]);
        // The arity guard fires before any polynomial conversion.
        assert!(matches!(
            diophantine(&pool, eq, &[x]),
            Err(DiophantineError::Unsupported(_))
        ));
        assert!(matches!(
            diophantine(&pool, eq, &[x, y, x]),
            Err(DiophantineError::Unsupported(_))
        ));
    }

    #[test]
    fn cornacchia_prime_13() {
        let p = Integer::from(13);
        let (a, b) = prime_as_sum_two_squares(&p).unwrap();
        assert_eq!(a.clone() * &a + b.clone() * &b, p);
    }
}