Skip to main content

alkahest_cas/poly/
real_roots.rs

1//! V2-4 — Real root isolation via Vincent–Akritas–Strzeboński (VAS).
2//!
3//! # Algorithm
4//!
5//! The public entry point [`real_roots`] implements the **VAS continued-fraction
6//! method** for isolating all real roots of a univariate polynomial with integer
7//! coefficients.  The core loop is the **Möbius-based Descartes bisection** (VCA):
8//!
9//! 1. Extract the squarefree part `p / gcd(p, p')` to eliminate repeated roots.
10//! 2. Separate positive and negative roots (negative = negated positive roots of
11//!    `p(−x)`).
12//! 3. Maintain a stack of `(poly, Möbius (a,b,c,d))` frames where
13//!    `x = (a·t + b)/(c·t + d)`.  The positive real roots of `poly(t)` biject with
14//!    the real roots of `p(x)` in the tracking interval.
15//! 4. At each frame:
16//!    - **Descartes' rule**: count sign variations `V` in the non-zero coefficients.
17//!      `V = 0` → no roots; `V = 1` → exactly one root, record the interval.
//!    - **VAS CF step**: compute an integer lower bound `k` on the smallest
//!      positive root (by evaluating `p` at integer points, not from a
//!      Cauchy-type coefficient bound); if `k ≥ 1`, shift `p(x+k)` (Taylor
//!      translate) before splitting — the key VAS speedup over plain bisection.
21//!    - **Bisect at t = 1**: push the right child `q(t+1)` and the left child
22//!      `(t+1)ⁿ q(1/(t+1))` = `taylor_shift_1(reverse(q))`.
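//! 5. Roots exactly at the split point `t = 1` (or at `t = 0` after a CF shift)
//!    are detected by checking `p(1) = 0` (resp. `p(0) = 0`) before bisecting,
//!    recorded as exact-point intervals, and deflated.  A deflated frame is never
//!    reported as an interval directly — it is first CF-shifted or bisected — so
//!    no isolating interval has a deflated root in its interior.  Such a root can
//!    still be an *endpoint* of a neighbouring interval, which should therefore
//!    be read as open.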
27//!
28//! # Public API
29//!
30//! - [`real_roots`] — isolate all real roots of a [`UniPoly`].
31//! - [`real_roots_symbolic`] — same, starting from a symbolic [`ExprId`].
32//! - [`refine_root`] — narrow a [`RootInterval`] to a given bit-precision.
33//! - [`RootInterval`] — rational isolating interval `[lo, hi]`.
34//! - [`RealRootError`] — error type.
35
36use crate::ball::ArbBall;
37use crate::kernel::{ExprId, ExprPool};
38use crate::poly::error::ConversionError;
39use crate::poly::unipoly::UniPoly;
40use rug::Integer;
41use std::fmt;
42
43// ---------------------------------------------------------------------------
44// Error type
45// ---------------------------------------------------------------------------
46
47/// Error returned by [`real_roots`] and [`real_roots_symbolic`].
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum RealRootError {
50    /// The expression could not be converted to a univariate polynomial with
51    /// integer coefficients.
52    NotAPolynomial(ConversionError),
53    /// The polynomial is identically zero.
54    ZeroPolynomial,
55}
56
57impl From<ConversionError> for RealRootError {
58    fn from(e: ConversionError) -> Self {
59        RealRootError::NotAPolynomial(e)
60    }
61}
62
63impl fmt::Display for RealRootError {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        match self {
66            RealRootError::NotAPolynomial(e) => write!(f, "not a polynomial: {e}"),
67            RealRootError::ZeroPolynomial => {
68                write!(f, "zero polynomial has infinitely many roots (E-ROOT-002)")
69            }
70        }
71    }
72}
73
74impl std::error::Error for RealRootError {}
75
76impl crate::errors::AlkahestError for RealRootError {
77    fn code(&self) -> &'static str {
78        match self {
79            RealRootError::NotAPolynomial(_) => "E-ROOT-001",
80            RealRootError::ZeroPolynomial => "E-ROOT-002",
81        }
82    }
83
84    fn remediation(&self) -> Option<&'static str> {
85        match self {
86            RealRootError::NotAPolynomial(_) => Some(
87                "ensure the input is a polynomial expression with integer coefficients \
88                 in a single variable",
89            ),
90            RealRootError::ZeroPolynomial => {
91                Some("real_roots is only defined for non-zero polynomials")
92            }
93        }
94    }
95}
96
97// ---------------------------------------------------------------------------
98// RootInterval — rational isolating interval
99// ---------------------------------------------------------------------------
100
101/// A closed rational interval `[lo, hi]` containing exactly one real root of a
102/// squarefree polynomial.  For an exact rational root `r`, `lo == hi == r`.
103#[derive(Debug, Clone)]
104pub struct RootInterval {
105    pub lo: rug::Rational,
106    pub hi: rug::Rational,
107}
108
109impl RootInterval {
110    /// Construct from two rational endpoints with `lo ≤ hi`.
111    pub fn new(lo: rug::Rational, hi: rug::Rational) -> Self {
112        debug_assert!(lo <= hi, "RootInterval requires lo ≤ hi");
113        RootInterval { lo, hi }
114    }
115
116    /// Approximate lower bound as `f64`.
117    pub fn lo_f64(&self) -> f64 {
118        self.lo.to_f64()
119    }
120
121    /// Approximate upper bound as `f64`.
122    pub fn hi_f64(&self) -> f64 {
123        self.hi.to_f64()
124    }
125
126    /// Width `hi − lo` as a [`rug::Rational`].
127    pub fn width(&self) -> rug::Rational {
128        self.hi.clone() - self.lo.clone()
129    }
130
131    /// Lower bound as `(numerator_string, denominator_string)` in decimal.
132    pub fn lo_exact(&self) -> (String, String) {
133        (self.lo.numer().to_string(), self.lo.denom().to_string())
134    }
135
136    /// Upper bound as `(numerator_string, denominator_string)` in decimal.
137    pub fn hi_exact(&self) -> (String, String) {
138        (self.hi.numer().to_string(), self.hi.denom().to_string())
139    }
140}
141
142impl fmt::Display for RootInterval {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        write!(f, "[{}, {}]", self.lo, self.hi)
145    }
146}
147
148// ---------------------------------------------------------------------------
149// Primitive polynomial operations on Vec<Integer>
150// ---------------------------------------------------------------------------
151// Polynomials are stored as coefficient vectors in **ascending degree order**:
152// index 0 is the constant term.
153
154/// Count sign variations in the non-zero coefficients (Descartes' rule of signs).
155fn sign_variations(coeffs: &[Integer]) -> usize {
156    let nonzero: Vec<&Integer> = coeffs.iter().filter(|c| **c != 0).collect();
157    if nonzero.len() < 2 {
158        return 0;
159    }
160    let mut count = 0;
161    for w in nonzero.windows(2) {
162        let pos0 = *w[0] > 0;
163        let pos1 = *w[1] > 0;
164        if pos0 != pos1 {
165            count += 1;
166        }
167    }
168    count
169}
170
171/// Compute `p(x + 1)` using the O(n²) de Casteljau / Taylor-shift algorithm.
172///
173/// For each `i = 0..n−1`, for each `j = (i..n−1)` in reverse:
174/// `c[j] += c[j+1]`.
175fn taylor_shift_by_1(coeffs: &[Integer]) -> Vec<Integer> {
176    let mut c: Vec<Integer> = coeffs.to_vec();
177    let n = c.len();
178    for i in 0..n.saturating_sub(1) {
179        for j in (i..n - 1).rev() {
180            let cjp1 = c[j + 1].clone();
181            c[j] += cjp1;
182        }
183    }
184    c
185}
186
187/// Compute `p(x + k)` for a non-negative integer `k`.
188fn taylor_shift_by(coeffs: &[Integer], k: u64) -> Vec<Integer> {
189    if k == 0 {
190        return coeffs.to_vec();
191    }
192    let mut c = coeffs.to_vec();
193    let n = c.len();
194    for i in 0..n.saturating_sub(1) {
195        for j in (i..n - 1).rev() {
196            let delta = c[j + 1].clone() * k;
197            c[j] += delta;
198        }
199    }
200    c
201}
202
203/// Reverse the coefficient vector: `[c₀,…,cₙ] → [cₙ,…,c₀]`.
204fn reverse_coeffs(coeffs: &[Integer]) -> Vec<Integer> {
205    coeffs.iter().cloned().rev().collect()
206}
207
208/// Remove trailing zeros (eliminates zero leading coefficients).
209fn trim_trailing_zeros(c: &mut Vec<Integer>) {
210    while c.last().is_some_and(|v| *v == 0) {
211        c.pop();
212    }
213}
214
215/// Sum all coefficients: evaluates `p(1)`.
216fn evaluate_at_1(coeffs: &[Integer]) -> Integer {
217    coeffs.iter().fold(Integer::from(0), |acc, c| acc + c)
218}
219
220/// Divide by `t` (caller guarantees `c[0] == 0`).
221fn divide_by_t(coeffs: &[Integer]) -> Vec<Integer> {
222    debug_assert_eq!(coeffs[0], 0, "divide_by_t: constant term must be zero");
223    coeffs[1..].to_vec()
224}
225
226/// Divide `p` by `(t − 1)` via synthetic division (caller guarantees `p(1) = 0`).
227///
228/// Recurrence: `q[n−1] = c[n]`, `q[k−1] = c[k] + q[k]` for `k = n−1 … 1`.
229fn divide_by_t_minus_1(coeffs: &[Integer]) -> Vec<Integer> {
230    let n = coeffs.len() - 1;
231    if n == 0 {
232        return vec![];
233    }
234    let mut q = vec![Integer::from(0); n];
235    q[n - 1] = coeffs[n].clone();
236    for k in (1..n).rev() {
237        let qk = q[k].clone();
238        q[k - 1] = coeffs[k].clone() + qk;
239    }
240    q
241}
242
243/// Remove the content (integer GCD of all coefficients).
244fn remove_content(coeffs: &[Integer]) -> Vec<Integer> {
245    if coeffs.is_empty() {
246        return vec![];
247    }
248    let g = coeffs.iter().fold(Integer::from(0), |acc, c| {
249        let ca = c.clone().abs();
250        acc.gcd(&ca)
251    });
252    if g == 0 || g == 1 {
253        return coeffs.to_vec();
254    }
255    coeffs.iter().map(|c| c.clone() / g.clone()).collect()
256}
257
258/// Formal derivative: `[c₀,c₁,…,cₙ] → [c₁, 2c₂, …, ncₙ]`.
259fn derivative_coeffs(coeffs: &[Integer]) -> Vec<Integer> {
260    if coeffs.len() <= 1 {
261        return vec![];
262    }
263    coeffs[1..]
264        .iter()
265        .enumerate()
266        .map(|(i, c)| c.clone() * (i as u64 + 1))
267        .collect()
268}
269
270// ---------------------------------------------------------------------------
271// Polynomial GCD via subresultant-style pseudo-remainder
272// ---------------------------------------------------------------------------
273
274/// Pseudo-remainder of `a ÷ b` using coefficient-exact arithmetic.
275///
276/// Computes `R` satisfying `lc(b)^d · a = Q · b + R`.
277/// All arithmetic stays in ℤ; no rational numbers required.
278fn poly_pseudo_rem(a: &[Integer], b: &[Integer]) -> Vec<Integer> {
279    let db = b.len().saturating_sub(1);
280    if db == 0 {
281        // `b` is a non-zero constant → remainder is 0.
282        if b.iter().any(|c| *c != 0) {
283            return vec![];
284        }
285        return a.to_vec();
286    }
287    let lc_b = b.last().unwrap().clone();
288    let mut r = a.to_vec();
289
290    while r.len().saturating_sub(1) >= db {
291        let dr = r.len() - 1;
292        let shift = dr - db;
293        let lc_r = r.last().unwrap().clone();
294
295        // r ← lc(b) · r − lc(r) · xˢʰⁱᶠᵗ · b
296        // Coefficients at positions 0..shift: multiply by lc(b).
297        for coeff in r[..shift].iter_mut() {
298            *coeff = lc_b.clone() * coeff.clone();
299        }
300        // Coefficients at positions shift..shift+b.len():
301        // scale by lc(b) and subtract lc(r)·b[i].
302        for i in 0..b.len() {
303            let old = r[i + shift].clone();
304            r[i + shift] = lc_b.clone() * old - lc_r.clone() * b[i].clone();
305        }
306
307        r.pop();
308        trim_trailing_zeros(&mut r);
309    }
310    r
311}
312
313/// GCD of two integer polynomials (normalised to positive leading coefficient).
314fn poly_gcd(a: &[Integer], b: &[Integer]) -> Vec<Integer> {
315    let b_zero = b.iter().all(|c| *c == 0);
316    if b_zero {
317        let mut g = remove_content(a);
318        trim_trailing_zeros(&mut g);
319        if g.last().is_some_and(|v| *v < 0) {
320            for c in g.iter_mut() {
321                *c = Integer::from(0) - c.clone();
322            }
323        }
324        return g;
325    }
326
327    let prem = poly_pseudo_rem(a, b);
328    let prem_zero = prem.iter().all(|c| *c == 0);
329    if prem_zero {
330        return poly_gcd(b, &[]);
331    }
332    let mut r = remove_content(&prem);
333    trim_trailing_zeros(&mut r);
334    poly_gcd(b, &r)
335}
336
337/// Exact polynomial division `a / b` (requires `b | a`).
338fn poly_exact_div(a: &[Integer], b: &[Integer]) -> Vec<Integer> {
339    let da = a.len() as i64 - 1;
340    let db = b.len() as i64 - 1;
341    if da < db || b.iter().all(|c| *c == 0) {
342        return vec![Integer::from(0)];
343    }
344    let deg_q = (da - db) as usize;
345    let mut q = vec![Integer::from(0); deg_q + 1];
346    let mut r = a.to_vec();
347    let lc_b = b.last().unwrap().clone();
348
349    for i in (0..=deg_q).rev() {
350        let lc_r = r[i + b.len() - 1].clone();
351        let qi = lc_r / lc_b.clone();
352        q[i] = qi.clone();
353        for (j, bj) in b.iter().enumerate() {
354            let old = r[i + j].clone();
355            r[i + j] = old - qi.clone() * bj.clone();
356        }
357    }
358    q
359}
360
361// ---------------------------------------------------------------------------
362// Squarefree decomposition
363// ---------------------------------------------------------------------------
364
365/// Extract the squarefree part `p / gcd(p, p')`.
366fn squarefree_part(coeffs: &[Integer]) -> Vec<Integer> {
367    if coeffs.len() <= 1 {
368        return coeffs.to_vec();
369    }
370    let dp = derivative_coeffs(coeffs);
371    if dp.iter().all(|c| *c == 0) {
372        return coeffs.to_vec();
373    }
374    let g = poly_gcd(coeffs, &dp);
375    if g.len() <= 1 {
376        // GCD is a non-zero constant: polynomial is squarefree.
377        return coeffs.to_vec();
378    }
379    let result = poly_exact_div(coeffs, &g);
380    let mut r = remove_content(&result);
381    trim_trailing_zeros(&mut r);
382    // Normalise to positive leading coefficient.
383    if r.last().is_some_and(|v| *v < 0) {
384        for c in r.iter_mut() {
385            *c = Integer::from(0) - c.clone();
386        }
387    }
388    r
389}
390
391// ---------------------------------------------------------------------------
392// VAS CF lower bound
393// ---------------------------------------------------------------------------
394
395/// Integer lower bound on the smallest positive root of `p`.
396///
397/// Uses a doubling-then-binary-search over integer evaluation points.
398/// Precondition: `p(0) ≠ 0` (no root at the origin).
399/// Returns the largest integer `k ≥ 1` such that `p(k)` has the same sign
400/// as `p(0)` (implying all positive roots are `> k`), or `0` if the
401/// smallest positive root is in `(0, 1]`.
402fn cf_lower_bound_floor(coeffs: &[Integer]) -> u64 {
403    if coeffs.is_empty() {
404        return 0;
405    }
406    let n = coeffs.len() - 1;
407    if n == 0 {
408        return 0;
409    }
410
411    let p0 = &coeffs[0];
412    if *p0 == 0 {
413        return 0;
414    }
415    let sign = *p0 > 0;
416
417    // Horner evaluation at a non-negative integer point.
418    let eval_at = |k: u64| -> Integer {
419        let k_int = Integer::from(k);
420        coeffs
421            .iter()
422            .rev()
423            .fold(Integer::from(0), |acc, c| acc * k_int.clone() + c.clone())
424    };
425
426    // If p(1) has a different sign (or is zero), the root is in (0, 1].
427    let p1 = evaluate_at_1(coeffs);
428    if p1 == 0 || (p1 > 0) != sign {
429        return 0;
430    }
431
432    // Doubling search: find hi where sign changes.
433    let mut lo: u64 = 1;
434    let mut hi: u64 = 2;
435    let mut found_sign_change = false;
436    loop {
437        if hi > 1_000_000 {
438            break;
439        }
440        let pval = eval_at(hi);
441        if pval == 0 || (pval > 0) != sign {
442            found_sign_change = true;
443            break;
444        }
445        lo = hi;
446        hi = hi.saturating_mul(2);
447    }
448
449    // No sign change found → polynomial is positive for all integers in [1, limit],
450    // meaning all positive roots are in (0, 1).  No shift is useful.
451    if !found_sign_change {
452        return 0;
453    }
454
455    // Binary search for the transition.
456    while hi - lo > 1 {
457        let mid = lo + (hi - lo) / 2;
458        let pval = eval_at(mid);
459        if pval == 0 || (pval > 0) != sign {
460            hi = mid;
461        } else {
462            lo = mid;
463        }
464    }
465
466    lo
467}
468
469// ---------------------------------------------------------------------------
470// Main VAS bisection algorithm
471// ---------------------------------------------------------------------------
472
473/// Stack frame: polynomial together with the Möbius transform tracking which
474/// sub-interval of the original positive half-line this frame covers.
475///
476/// Invariant: the positive real roots of `poly(t)` biject with the roots of
477/// the original squarefree polynomial in `(b/d, a/c)` (or `(b/d, +∞)` when
478/// `c = 0`) via `x = (a·t + b)/(c·t + d)`.
479struct Frame {
480    poly: Vec<Integer>,
481    a: Integer,
482    b: Integer,
483    c: Integer,
484    d: Integer,
485    /// True immediately after a root-at-0 or root-at-1 deflation.
486    /// When set, skip the `sign_var == 1` shortcut and always bisect.
487    just_deflated: bool,
488}
489
490/// Compute both endpoints of the Möbius interval.
491///
492/// - `at_zero  = b/d`  (value at `t = 0`)
493/// - `at_inf   = a/c`  (value at `t → ∞`, or `None` when `c = 0`)
494///
495/// Returns `(lo, hi)` with `lo ≤ hi`.
496fn mobius_interval(
497    a: &Integer,
498    b: &Integer,
499    c: &Integer,
500    d: &Integer,
501) -> (rug::Rational, Option<rug::Rational>) {
502    let at_zero = rug::Rational::from((b.clone(), d.clone()));
503    let at_inf = if *c == 0 {
504        None
505    } else {
506        Some(rug::Rational::from((a.clone(), c.clone())))
507    };
508    match at_inf {
509        None => (at_zero, None),
510        Some(ai) => {
511            if at_zero <= ai {
512                (at_zero, Some(ai))
513            } else {
514                (ai, Some(at_zero))
515            }
516        }
517    }
518}
519
520/// Isolate all strictly-positive real roots of `coeffs` via VAS bisection.
521///
522/// The input polynomial must have a **non-zero constant term** (root at `x = 0`
523/// should be removed before calling this function).
524fn isolate_positive_roots(coeffs: Vec<Integer>) -> Vec<RootInterval> {
525    if coeffs.is_empty() || coeffs.iter().all(|c| *c == 0) {
526        return vec![];
527    }
528
529    let mut result = Vec::new();
530    let mut stack: Vec<Frame> = vec![Frame {
531        poly: coeffs,
532        a: Integer::from(1),
533        b: Integer::from(0),
534        c: Integer::from(0),
535        d: Integer::from(1),
536        just_deflated: false,
537    }];
538
539    let max_iters: usize = 500_000;
540    let mut iters = 0usize;
541
542    while let Some(mut frame) = stack.pop() {
543        iters += 1;
544        if iters > max_iters {
545            break;
546        }
547
548        trim_trailing_zeros(&mut frame.poly);
549        if frame.poly.is_empty() || frame.poly.iter().all(|c| *c == 0) {
550            continue;
551        }
552
553        // ---- Root at t = 0 (constant term = 0) --------------------------------
554        // t = 0 corresponds to x = b/d.
555        if frame.poly[0] == 0 {
556            let root_x = rug::Rational::from((frame.b.clone(), frame.d.clone()));
557            result.push(RootInterval::new(root_x.clone(), root_x));
558            frame.poly = divide_by_t(&frame.poly);
559            trim_trailing_zeros(&mut frame.poly);
560            if frame.poly.is_empty() {
561                continue;
562            }
563            // Push back with just_deflated=true so the sign_var=1 shortcut is
564            // suppressed (the remaining roots are strictly in (b/d, …), but the
565            // Möbius still starts at b/d, risking a half-open overlap).
566            frame.just_deflated = true;
567            stack.push(frame);
568            continue;
569        }
570
571        // ---- Root at t = 1 (p(1) = sum of coefficients = 0) ------------------
572        // t = 1 corresponds to x = (a+b)/(c+d).
573        let val_at_1 = evaluate_at_1(&frame.poly);
574        if val_at_1 == 0 {
575            let a_plus_b = frame.a.clone() + frame.b.clone();
576            let c_plus_d = frame.c.clone() + frame.d.clone();
577            if c_plus_d != 0 {
578                let root_x = rug::Rational::from((a_plus_b, c_plus_d));
579                result.push(RootInterval::new(root_x.clone(), root_x));
580            }
581            frame.poly = divide_by_t_minus_1(&frame.poly);
582            trim_trailing_zeros(&mut frame.poly);
583            if frame.poly.is_empty() {
584                continue;
585            }
586            // After deflation by (t−1) the remaining roots are NOT all in
587            // (1,∞); they could be anywhere in (0,∞).  Force a bisect pass
588            // so that the children's intervals are strictly disjoint from the
589            // just-recorded exact root at the split point.
590            frame.just_deflated = true;
591            stack.push(frame);
592            continue;
593        }
594
595        let v = sign_variations(&frame.poly);
596
597        match v {
598            0 => continue,
599            1 if !frame.just_deflated => {
600                // Exactly one root; if the tracking interval is bounded record it.
601                let (lo, hi_opt) = mobius_interval(&frame.a, &frame.b, &frame.c, &frame.d);
602                if let Some(hi) = hi_opt {
603                    result.push(RootInterval::new(lo, hi));
604                    continue;
605                }
606                // Unbounded interval (c = 0): fall through to CF + bisect to
607                // narrow down a finite upper bound.
608            }
609            _ => {
610                // v == 0 handled above; v ≥ 2 or v == 1 with just_deflated falls here.
611            }
612        }
613
614        // ---- VAS CF step: shift by integer lower bound k ----------------------
615        frame.just_deflated = false; // reset flag before bisection
616
617        let k = cf_lower_bound_floor(&frame.poly);
618        if k >= 1 {
619            let new_p = taylor_shift_by(&frame.poly, k);
620            let ki = Integer::from(k);
621            let new_b = frame.a.clone() * ki.clone() + frame.b.clone();
622            let new_d = frame.c.clone() * ki + frame.d.clone();
623            frame.b = new_b;
624            frame.d = new_d;
625            frame.poly = remove_content(&new_p);
626            trim_trailing_zeros(&mut frame.poly);
627            if frame.poly.is_empty() {
628                continue;
629            }
630            // Push back so the root-at-0 / root-at-1 checks fire before bisection.
631            stack.push(frame);
632            continue;
633        }
634
635        // ---- Bisect at t = 1 --------------------------------------------------
636
637        let a = frame.a.clone();
638        let b = frame.b.clone();
639        let c = frame.c.clone();
640        let d = frame.d.clone();
641
642        // Right child: roots of q in (1, ∞)  →  poly = q(t+1), Möbius (a, a+b, c, c+d).
643        {
644            let q_right_raw = taylor_shift_by_1(&frame.poly);
645            let mut q_right = remove_content(&q_right_raw);
646            trim_trailing_zeros(&mut q_right);
647            if !q_right.is_empty() && q_right.iter().any(|c| *c != 0) {
648                stack.push(Frame {
649                    poly: q_right,
650                    a: a.clone(),
651                    b: a.clone() + b.clone(),
652                    c: c.clone(),
653                    d: c.clone() + d.clone(),
654                    just_deflated: false,
655                });
656            }
657        }
658
659        // Left child: roots of q in (0, 1)  →  poly = (t+1)ⁿ·q(1/(t+1))
660        //            = taylor_shift_1(reverse(q)), Möbius (b, a+b, d, c+d).
661        {
662            let rev = reverse_coeffs(&frame.poly);
663            let q_left_raw = taylor_shift_by_1(&rev);
664            let mut q_left = remove_content(&q_left_raw);
665            trim_trailing_zeros(&mut q_left);
666            if !q_left.is_empty() && q_left.iter().any(|c| *c != 0) {
667                stack.push(Frame {
668                    poly: q_left,
669                    a: b.clone(),
670                    b: a + b,
671                    c: d.clone(),
672                    d: c + d,
673                    just_deflated: false,
674                });
675            }
676        }
677    }
678
679    result
680}
681
682// ---------------------------------------------------------------------------
683// Public entry points
684// ---------------------------------------------------------------------------
685
686/// Isolate all real roots of `poly`.
687///
688/// Returns a vector of [`RootInterval`]s sorted by lower endpoint.  Each
689/// interval contains exactly one real root of the squarefree part of `poly`.
690/// Repeated roots appear once each.
691///
692/// # Errors
693///
694/// - [`RealRootError::ZeroPolynomial`] — `poly` is the zero polynomial.
695pub fn real_roots(poly: &UniPoly) -> Result<Vec<RootInterval>, RealRootError> {
696    let mut coeffs: Vec<Integer> = poly.coefficients();
697    trim_trailing_zeros(&mut coeffs);
698
699    if coeffs.is_empty() {
700        return Err(RealRootError::ZeroPolynomial);
701    }
702    if coeffs.len() == 1 {
703        return Ok(vec![]); // Non-zero constant: no roots.
704    }
705
706    // Normalise to positive leading coefficient.
707    if coeffs.last().is_some_and(|v| *v < 0) {
708        for c in coeffs.iter_mut() {
709            *c = Integer::from(0) - c.clone();
710        }
711    }
712
713    // Squarefree part.
714    let sq = squarefree_part(&coeffs);
715
716    // Check for root at x = 0 (constant term = 0).
717    let mut result = Vec::new();
718    let working = if sq[0] == 0 {
719        result.push(RootInterval::new(
720            rug::Rational::from(0),
721            rug::Rational::from(0),
722        ));
723        sq[1..].to_vec()
724    } else {
725        sq.clone()
726    };
727
728    if working.len() <= 1 {
729        result.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap_or(std::cmp::Ordering::Equal));
730        return Ok(result);
731    }
732
733    // Positive roots.
734    result.extend(isolate_positive_roots(working.clone()));
735
736    // Negative roots: positive roots of p(−x), then negate.
737    let neg_coeffs: Vec<Integer> = working
738        .iter()
739        .enumerate()
740        .map(|(i, c)| {
741            if i % 2 == 1 {
742                Integer::from(0) - c.clone()
743            } else {
744                c.clone()
745            }
746        })
747        .collect();
748    let neg_pos = isolate_positive_roots(neg_coeffs);
749    for iv in neg_pos {
750        let neg_hi = rug::Rational::from((
751            Integer::from(0) - iv.lo.numer().clone(),
752            iv.lo.denom().clone(),
753        ));
754        let neg_lo = rug::Rational::from((
755            Integer::from(0) - iv.hi.numer().clone(),
756            iv.hi.denom().clone(),
757        ));
758        result.push(RootInterval::new(neg_lo, neg_hi));
759    }
760
761    result.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap_or(std::cmp::Ordering::Equal));
762    Ok(result)
763}
764
765/// Isolate all real roots of a symbolic expression in `var`.
766///
767/// # Errors
768///
769/// - [`RealRootError::NotAPolynomial`] if the expression cannot be converted.
770/// - [`RealRootError::ZeroPolynomial`] if the polynomial is identically zero.
771pub fn real_roots_symbolic(
772    expr: ExprId,
773    var: ExprId,
774    pool: &ExprPool,
775) -> Result<Vec<RootInterval>, RealRootError> {
776    let poly = UniPoly::from_symbolic(expr, var, pool).map_err(RealRootError::NotAPolynomial)?;
777    real_roots(&poly)
778}
779
780/// Narrow a [`RootInterval`] to at least `prec` bits of precision.
781///
782/// Uses bisection with floating-point Horner evaluation.  For exact roots
783/// (`lo == hi`), returns a zero-radius [`ArbBall`].
784pub fn refine_root(poly: &UniPoly, interval: &RootInterval, prec: u32) -> ArbBall {
785    if interval.lo == interval.hi {
786        return ArbBall::from_midpoint_radius(interval.lo.to_f64(), 0.0, prec.max(53));
787    }
788
789    let coeffs_f64: Vec<f64> = poly.coefficients().iter().map(|c| c.to_f64()).collect();
790    let eval = |x: f64| -> f64 { coeffs_f64.iter().rev().fold(0.0_f64, |acc, &c| acc * x + c) };
791
792    let target_width = 2.0_f64.powi(-(prec as i32));
793    let mut lo = interval.lo.to_f64();
794    let mut hi = interval.hi.to_f64();
795    let mut f_lo = eval(lo);
796
797    for _ in 0..300 {
798        if hi - lo <= target_width {
799            break;
800        }
801        let mid = (lo + hi) / 2.0;
802        let f_mid = eval(mid);
803        if f_lo * f_mid <= 0.0 {
804            hi = mid;
805        } else {
806            lo = mid;
807            f_lo = f_mid;
808        }
809    }
810
811    let center = (lo + hi) / 2.0;
812    let radius = (hi - lo) / 2.0;
813    ArbBall::from_midpoint_radius(center, radius, prec.max(53))
814}
815
816// ---------------------------------------------------------------------------
817// Unit tests
818// ---------------------------------------------------------------------------
819
820#[cfg(test)]
821mod tests {
822    use super::*;
823    use crate::flint::{FlintInteger, FlintPoly};
824    use crate::kernel::{Domain, ExprPool};
825
826    /// Build a `UniPoly` from a slice of `i64` coefficients (ascending degree).
827    fn make_poly(coeffs: &[i64]) -> UniPoly {
828        let p = ExprPool::new();
829        let x = p.symbol("x", Domain::Real);
830        let mut flint = FlintPoly::new();
831        for (i, &c) in coeffs.iter().enumerate() {
832            let fi = FlintInteger::from_i64(c);
833            flint.set_coeff_flint(i, &fi);
834        }
835        UniPoly {
836            var: x,
837            coeffs: flint,
838        }
839    }
840
841    // ---- sign_variations ----
842
843    #[test]
844    fn sv_all_positive() {
845        let c: Vec<Integer> = vec![1, 2, 3].into_iter().map(Integer::from).collect();
846        assert_eq!(sign_variations(&c), 0);
847    }
848
849    #[test]
850    fn sv_alternating() {
851        let c: Vec<Integer> = vec![1, -1, 1, -1i64]
852            .into_iter()
853            .map(Integer::from)
854            .collect();
855        assert_eq!(sign_variations(&c), 3);
856    }
857
858    #[test]
859    fn sv_with_zeros() {
860        // Zeros are ignored: [1, 0, -1] → one sign change.
861        let c: Vec<Integer> = vec![1i64, 0, -1].into_iter().map(Integer::from).collect();
862        assert_eq!(sign_variations(&c), 1);
863    }
864
865    // ---- taylor_shift_by_1 ----
866
867    #[test]
868    fn taylor_shift_quadratic() {
869        // p(x) = x² + 2x + 1 = [1,2,1]; p(x+1) = x² + 4x + 4 = [4,4,1].
870        let c: Vec<Integer> = vec![1, 2, 1i64].into_iter().map(Integer::from).collect();
871        let shifted = taylor_shift_by_1(&c);
872        let expected: Vec<Integer> = vec![4, 4, 1i64].into_iter().map(Integer::from).collect();
873        assert_eq!(shifted, expected);
874    }
875
876    #[test]
877    fn taylor_shift_linear() {
878        // p(x) = 3x + 2; p(x+1) = 3x + 5; [2,3] → [5,3].
879        let c: Vec<Integer> = vec![2, 3i64].into_iter().map(Integer::from).collect();
880        let shifted = taylor_shift_by_1(&c);
881        assert_eq!(shifted[0], Integer::from(5i64));
882        assert_eq!(shifted[1], Integer::from(3i64));
883    }
884
885    // ---- squarefree_part ----
886
887    #[test]
888    fn sqfree_linear_already_squarefree() {
889        let c: Vec<Integer> = vec![-1i64, 1].into_iter().map(Integer::from).collect();
890        let sq = squarefree_part(&c);
891        assert_eq!(sq.len(), 2);
892    }
893
894    #[test]
895    fn sqfree_removes_double_root() {
896        // (x-1)² = x² - 2x + 1 = [1,-2,1]; squarefree part = x - 1 (degree 1).
897        let c: Vec<Integer> = vec![1i64, -2, 1].into_iter().map(Integer::from).collect();
898        let sq = squarefree_part(&c);
899        assert_eq!(sq.len(), 2, "squarefree part of (x-1)² must be degree 1");
900    }
901
902    #[test]
903    fn sqfree_triple_root() {
904        // (x-2)³ = x³ - 6x² + 12x - 8 = [-8,12,-6,1]; squarefree part = x-2.
905        let c: Vec<Integer> = vec![-8i64, 12, -6, 1]
906            .into_iter()
907            .map(Integer::from)
908            .collect();
909        let sq = squarefree_part(&c);
910        assert_eq!(sq.len(), 2, "squarefree part of (x-2)³ must be degree 1");
911    }
912
913    // ---- divide_by_t_minus_1 ----
914
915    #[test]
916    fn div_t_minus_1_basic() {
917        // x² - 1 = (x-1)(x+1); dividing by (t-1) gives (x+1) = [1,1].
918        let c: Vec<Integer> = vec![-1i64, 0, 1].into_iter().map(Integer::from).collect();
919        assert_eq!(evaluate_at_1(&c), Integer::from(0i64));
920        let q = divide_by_t_minus_1(&c);
921        assert_eq!(q.len(), 2);
922        // x² - 1 = [-1, 0, 1]; divide by (t-1):
923        //   q[1] = coeffs[2] = 1
924        //   q[0] = coeffs[1] + q[1] = 0 + 1 = 1
925        // → q = [1, 1] = x + 1, ascending order.
926        assert_eq!(
927            q[0],
928            Integer::from(1i64),
929            "constant term of x+1 should be 1"
930        );
931        assert_eq!(
932            q[1],
933            Integer::from(1i64),
934            "x-coefficient of x+1 should be 1"
935        );
936    }
937
938    // ---- poly_pseudo_rem ----
939
940    #[test]
941    fn pseudo_rem_double_root() {
942        // prem(x² - 2x + 1, 2x - 2) should give 0 (since gcd = x-1).
943        let a: Vec<Integer> = vec![1i64, -2, 1].into_iter().map(Integer::from).collect();
944        let b: Vec<Integer> = vec![-2i64, 2].into_iter().map(Integer::from).collect();
945        let r = poly_pseudo_rem(&a, &b);
946        assert!(
947            r.iter().all(|c| *c == 0),
948            "prem of (x-1)² by 2(x-1) should be 0, got {:?}",
949            r
950        );
951    }
952
953    // ---- isolate_positive_roots ----
954
955    #[test]
956    fn isolate_x_minus_1() {
957        let c: Vec<Integer> = vec![-1i64, 1].into_iter().map(Integer::from).collect();
958        let roots = isolate_positive_roots(c);
959        assert_eq!(roots.len(), 1);
960        assert!(roots[0].lo <= 1);
961        assert!(roots[0].hi >= 1);
962    }
963
964    #[test]
965    fn isolate_x_squared_minus_1_positive() {
966        // x² - 1 = (x-1)(x+1); one positive root at x=1.
967        let c: Vec<Integer> = vec![-1i64, 0, 1].into_iter().map(Integer::from).collect();
968        let roots = isolate_positive_roots(c);
969        assert_eq!(roots.len(), 1);
970        assert!(roots[0].lo <= 1);
971        assert!(roots[0].hi >= 1);
972    }
973
974    #[test]
975    fn isolate_two_positive_roots() {
976        // (x-1)(x-2) = x² - 3x + 2; roots at 1 and 2.
977        let c: Vec<Integer> = vec![2i64, -3, 1].into_iter().map(Integer::from).collect();
978        let roots = isolate_positive_roots(c);
979        assert_eq!(roots.len(), 2, "expected 2 positive roots");
980        let mut sorted = roots;
981        sorted.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap());
982        // Intervals must be disjoint: sorted[0].hi ≤ sorted[1].lo.
983        assert!(
984            sorted[0].hi <= sorted[1].lo,
985            "intervals must be disjoint: [{},{}] and [{},{}]",
986            sorted[0].lo,
987            sorted[0].hi,
988            sorted[1].lo,
989            sorted[1].hi
990        );
991    }
992
993    // ---- real_roots ----
994
995    #[test]
996    fn real_roots_x_squared_minus_1() {
997        let poly = make_poly(&[-1, 0, 1]);
998        let roots = real_roots(&poly).unwrap();
999        assert_eq!(roots.len(), 2, "x² - 1 has 2 real roots");
1000        assert!(roots[0].lo < 0);
1001        assert!(roots[1].lo >= 0);
1002    }
1003
1004    #[test]
1005    fn real_roots_no_real_roots() {
1006        // x² + 1 has no real roots.
1007        let poly = make_poly(&[1, 0, 1]);
1008        let roots = real_roots(&poly).unwrap();
1009        assert_eq!(roots.len(), 0);
1010    }
1011
1012    #[test]
1013    fn real_roots_cluster_squarefree() {
1014        // (x-1)⁵·(x+1)³ has roots at ±1; squarefree part = x²-1.
1015        let poly = make_poly(&[-1, 0, 1]); // Use squarefree representative.
1016        let roots = real_roots(&poly).unwrap();
1017        assert_eq!(roots.len(), 2);
1018    }
1019
1020    #[test]
1021    fn real_roots_disjoint() {
1022        // (x-1)(x-2)(x-3) = x³ - 6x² + 11x - 6; roots at 1, 2, 3.
1023        let poly = make_poly(&[-6, 11, -6, 1]);
1024        let mut roots = real_roots(&poly).unwrap();
1025        assert_eq!(roots.len(), 3);
1026        roots.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap());
1027        for w in roots.windows(2) {
1028            assert!(
1029                w[0].hi <= w[1].lo,
1030                "intervals must be disjoint: {} and {}",
1031                w[0],
1032                w[1]
1033            );
1034        }
1035    }
1036
1037    #[test]
1038    fn real_roots_chebyshev_t4() {
1039        // T₄(x) = 8x⁴ - 8x² + 1; 4 roots in (-1, 1).
1040        let poly = make_poly(&[1, 0, -8, 0, 8]);
1041        let roots = real_roots(&poly).unwrap();
1042        assert_eq!(roots.len(), 4, "T₄ has 4 real roots");
1043        for r in &roots {
1044            assert!(r.lo >= -1);
1045            assert!(r.hi <= 1);
1046        }
1047    }
1048
1049    #[test]
1050    fn real_roots_constant_zero() {
1051        let poly = make_poly(&[0]);
1052        assert!(matches!(
1053            real_roots(&poly),
1054            Err(RealRootError::ZeroPolynomial)
1055        ));
1056    }
1057
1058    #[test]
1059    fn real_roots_constant_nonzero() {
1060        let poly = make_poly(&[5]);
1061        assert_eq!(real_roots(&poly).unwrap().len(), 0);
1062    }
1063
1064    #[test]
1065    fn real_roots_symbolic_x_squared_minus_4() {
1066        let p = ExprPool::new();
1067        let x = p.symbol("x", Domain::Real);
1068        let xsq = p.pow(x, p.integer(2_i32));
1069        let expr = p.add(vec![xsq, p.integer(-4_i32)]);
1070        let roots = real_roots_symbolic(expr, x, &p).unwrap();
1071        assert_eq!(roots.len(), 2);
1072        assert!(roots[0].lo <= -2);
1073        assert!(roots[0].hi >= -2);
1074        assert!(roots[1].lo <= 2);
1075        assert!(roots[1].hi >= 2);
1076    }
1077
1078    #[test]
1079    fn real_roots_five_distinct() {
1080        // (x-1)(x-2)(x-3)(x-4)(x-5) = x⁵ - 15x⁴ + 85x³ - 225x² + 274x - 120.
1081        let poly = make_poly(&[-120, 274, -225, 85, -15, 1]);
1082        let roots = real_roots(&poly).unwrap();
1083        assert_eq!(roots.len(), 5, "product (x-1)…(x-5) must have 5 real roots");
1084        // Each known root must be enclosed.
1085        for k in 1..=5i64 {
1086            let rk = rug::Rational::from(k);
1087            assert!(
1088                roots.iter().any(|iv| iv.lo <= rk && iv.hi >= rk),
1089                "root at x={k} not enclosed"
1090            );
1091        }
1092    }
1093
1094    #[test]
1095    fn real_roots_disjoint_five() {
1096        let poly = make_poly(&[-120, 274, -225, 85, -15, 1]);
1097        let mut roots = real_roots(&poly).unwrap();
1098        roots.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap());
1099        for w in roots.windows(2) {
1100            assert!(
1101                w[0].hi <= w[1].lo,
1102                "intervals overlap: {} and {}",
1103                w[0],
1104                w[1]
1105            );
1106        }
1107    }
1108
1109    #[test]
1110    fn refine_root_x_minus_3() {
1111        let poly = make_poly(&[-3, 1]);
1112        let roots = real_roots(&poly).unwrap();
1113        assert_eq!(roots.len(), 1);
1114        let ball = refine_root(&poly, &roots[0], 53);
1115        assert!(ball.contains(3.0), "refined ball must contain x=3");
1116    }
1117}