Skip to main content

alkahest_cas/poly/
resultant.rs

1//! Resultant and subresultant polynomial remainder sequence (V2-2).
2//!
3//! # Public API
4//!
5//! - [`resultant`] — compute `res(p, q, var)` using FLINT's multivariate
6//!   resultant.  Works for univariate (integer result) and multivariate
7//!   (polynomial result) inputs.
8//! - [`subresultant_prs`] — compute the full subresultant polynomial
9//!   remainder sequence for univariate polynomials with integer coefficients.
10//!
11//! # Derivation log
12//!
13//! Both functions record a single [`RewriteStep`] with rule name
14//! `"Resultant"` / `"SubresultantPRS"` and the Lean 4 theorem tag
15//! `Polynomial.resultant_eq_zero_iff_common_root`.
16
17use crate::deriv::{DerivationLog, DerivedExpr, RewriteStep};
18use crate::flint::integer::FlintInteger;
19use crate::flint::mpoly::FlintMPolyCtx;
20use crate::kernel::{ExprData, ExprId, ExprPool};
21use crate::poly::error::ConversionError;
22use crate::poly::multipoly::multi_to_flint_pub;
23use crate::poly::multipoly::MultiPoly;
24use crate::poly::unipoly::UniPoly;
25use std::collections::{BTreeMap, BTreeSet};
26use std::fmt;
27
28// ---------------------------------------------------------------------------
29// Error type
30// ---------------------------------------------------------------------------
31
32/// Error returned by [`resultant`] and [`subresultant_prs`].
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub enum ResultantError {
35    /// One or both expressions could not be parsed as polynomials in the
36    /// given variable(s).
37    NotAPolynomial(ConversionError),
38    /// FLINT's internal resultant computation failed (algorithm error).
39    FlintError,
40}
41
42impl From<ConversionError> for ResultantError {
43    fn from(e: ConversionError) -> Self {
44        ResultantError::NotAPolynomial(e)
45    }
46}
47
48impl fmt::Display for ResultantError {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            ResultantError::NotAPolynomial(e) => write!(f, "not a polynomial: {e}"),
52            ResultantError::FlintError => {
53                write!(f, "FLINT resultant computation failed (E-RES-003)")
54            }
55        }
56    }
57}
58
// Marker impl: only the default `std::error::Error` methods are used.  No
// `source()` override is provided, so the `ConversionError` wrapped by
// `NotAPolynomial` is not exposed as an error source.
impl std::error::Error for ResultantError {}
60
61impl crate::errors::AlkahestError for ResultantError {
62    fn code(&self) -> &'static str {
63        match self {
64            ResultantError::NotAPolynomial(_) => "E-RES-001",
65            ResultantError::FlintError => "E-RES-003",
66        }
67    }
68
69    fn remediation(&self) -> Option<&'static str> {
70        match self {
71            ResultantError::NotAPolynomial(_) => Some(
72                "ensure both arguments are polynomial expressions with integer \
73                 coefficients in the given variable",
74            ),
75            ResultantError::FlintError => None,
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// Free-variable collection
82// ---------------------------------------------------------------------------
83
84/// Walk the expression DAG and collect every distinct [`ExprId`] that
85/// corresponds to a `Symbol` node.  Result is sorted by `ExprId` for a
86/// deterministic variable ordering.
87pub(crate) fn collect_free_vars(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
88    let mut set = BTreeSet::new();
89    collect_vars_rec(expr, pool, &mut set);
90    set.into_iter().collect()
91}
92
93fn collect_vars_rec(expr: ExprId, pool: &ExprPool, out: &mut BTreeSet<ExprId>) {
94    // Collect sub-expression IDs to recurse into without holding the pool lock.
95    let children: Vec<ExprId> = pool.with(expr, |data| match data {
96        ExprData::Symbol { .. } => {
97            out.insert(expr);
98            vec![]
99        }
100        ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_) => vec![],
101        ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
102        ExprData::Pow { base, exp } => vec![*base, *exp],
103        ExprData::Func { args, .. } => args.clone(),
104        ExprData::Piecewise { branches, default } => {
105            let mut ids: Vec<ExprId> = branches.iter().flat_map(|(c, v)| [*c, *v]).collect();
106            ids.push(*default);
107            ids
108        }
109        ExprData::Predicate { args, .. } => args.clone(),
110        ExprData::Forall { var, body } | ExprData::Exists { var, body } => vec![*var, *body],
111        ExprData::BigO(arg) => vec![*arg],
112    });
113    for child in children {
114        collect_vars_rec(child, pool, out);
115    }
116}
117
118// ---------------------------------------------------------------------------
119// resultant
120// ---------------------------------------------------------------------------
121
122/// Compute the resultant of `p` and `q` with respect to `var`.
123///
124/// Both `p` and `q` must be polynomial expressions with integer coefficients
125/// in all the symbolic variables they contain.  Non-polynomial sub-expressions
126/// (transcendental functions, rational coefficients, symbolic exponents) are
127/// rejected with [`ResultantError::NotAPolynomial`].
128///
129/// The return value is the resultant polynomial as a symbolic expression:
130/// - In the **univariate** case (only `var` appears) the result is an integer
131///   constant.
132/// - In the **multivariate** case the result is a polynomial in the remaining
133///   variables (`var` has been eliminated).
134///
135/// # Derivation log
136///
137/// Records a single `"Resultant"` step tagged with the Lean 4 theorem
138/// `Polynomial.resultant_eq_zero_iff_common_root`.
139///
140/// # Errors
141///
142/// - [`ResultantError::NotAPolynomial`] — an input is not a polynomial with
143///   integer coefficients.
144/// - [`ResultantError::FlintError`] — FLINT's internal computation failed
145///   (extremely rare; indicates degenerate or overflow inputs).
146///
147/// # Examples
148///
149/// ```text
150/// // Univariate: res(x^2 - 5x + 6, x - 2, x) == 0  (common root x=2)
151/// // Bivariate:  res(x^2 + y^2 - 1, y - x, y) == 2*x^2 - 1
152/// ```
153pub fn resultant(
154    p: ExprId,
155    q: ExprId,
156    var: ExprId,
157    pool: &ExprPool,
158) -> Result<DerivedExpr<ExprId>, ResultantError> {
159    // Collect all free variables from both expressions; always include `var`.
160    let mut all: BTreeSet<ExprId> = BTreeSet::new();
161    for v in collect_free_vars(p, pool) {
162        all.insert(v);
163    }
164    for v in collect_free_vars(q, pool) {
165        all.insert(v);
166    }
167    all.insert(var);
168
169    let vars: Vec<ExprId> = all.into_iter().collect();
170    let nvars = vars.len();
171    let var_idx = vars.iter().position(|&v| v == var).unwrap();
172
173    // Convert both expressions to MultiPoly in the unified variable list.
174    let mp = MultiPoly::from_symbolic(p, vars.clone(), pool)?;
175    let mq = MultiPoly::from_symbolic(q, vars.clone(), pool)?;
176
177    // Build FLINT multivariate context and polynomials.
178    let ctx = FlintMPolyCtx::new(nvars.max(1));
179    let fp = multi_to_flint_pub(&mp, &ctx);
180    let fq = multi_to_flint_pub(&mq, &ctx);
181
182    // Call FLINT's resultant.
183    let fr = fp
184        .resultant(&fq, var_idx, &ctx)
185        .ok_or(ResultantError::FlintError)?;
186
187    // Extract terms from the FLINT result (all in the same nvars-dim context).
188    let res_raw = fr.terms(nvars.max(1), &ctx);
189
190    // Build a MultiPoly for the result, dropping the eliminated variable
191    // dimension (its exponent should be 0 in every term).
192    let remaining_vars: Vec<ExprId> = vars
193        .iter()
194        .enumerate()
195        .filter_map(|(i, &v)| if i == var_idx { None } else { Some(v) })
196        .collect();
197
198    let mut new_terms: BTreeMap<Vec<u32>, rug::Integer> = BTreeMap::new();
199    for (exp, coeff) in res_raw {
200        let mut new_exp: Vec<u32> = exp
201            .into_iter()
202            .enumerate()
203            .filter_map(|(i, e)| if i == var_idx { None } else { Some(e) })
204            .collect();
205        while new_exp.last() == Some(&0) {
206            new_exp.pop();
207        }
208        let entry = new_terms
209            .entry(new_exp)
210            .or_insert_with(|| rug::Integer::from(0));
211        *entry += &coeff;
212    }
213    new_terms.retain(|_, v| *v != 0);
214
215    let result_mp = MultiPoly {
216        vars: remaining_vars,
217        terms: new_terms,
218    };
219    let result_expr = result_mp.to_expr(pool);
220
221    let step = RewriteStep::simple("Resultant", p, result_expr);
222    Ok(DerivedExpr::with_step(result_expr, step))
223}
224
225// ---------------------------------------------------------------------------
226// subresultant_prs — pure-Rust, univariate, integer coefficients
227// ---------------------------------------------------------------------------
228
229/// Compute the subresultant polynomial remainder sequence of `p` and `q`
230/// with respect to `var`.
231///
232/// Both polynomials must be **univariate** in `var` with **integer**
233/// coefficients.  Multivariate inputs (coefficients involving other symbols)
234/// produce [`ResultantError::NotAPolynomial`].
235///
236/// Returns a [`DerivedExpr`] whose value is the full PRS as a
237/// `Vec<ExprId>`:
238/// `[p, q, S₂, S₃, …, Sₖ]`
239///
240/// The 0th subresultant — the resultant — can be extracted as the last
241/// element that is a constant (degree-0) polynomial, or from
242/// [`resultant`] directly.
243///
244/// # Algorithm
245///
246/// Classical Brown–Collins subresultant algorithm (1971/1967).  Computations
247/// stay in ℤ\[x\]; all coefficient scalings are exact integer divisions
248/// guaranteed by the subresultant theory.
249///
250/// # Derivation log
251///
252/// Records a single `"SubresultantPRS"` step.
253pub fn subresultant_prs(
254    p: ExprId,
255    q: ExprId,
256    var: ExprId,
257    pool: &ExprPool,
258) -> Result<DerivedExpr<Vec<ExprId>>, ResultantError> {
259    // Convert to UniPoly (rejects non-integer coefficients and other symbols).
260    let mut up = UniPoly::from_symbolic(p, var, pool)?;
261    let mut uq = UniPoly::from_symbolic(q, var, pool)?;
262
263    // Canonical orientation: deg(P) >= deg(Q).
264    if up.degree() < uq.degree() {
265        std::mem::swap(&mut up, &mut uq);
266    }
267
268    let prs_polys = sprs_inner(up, uq);
269
270    // Convert each polynomial in the sequence back to a symbolic expression.
271    let exprs: Vec<ExprId> = prs_polys
272        .into_iter()
273        .map(|poly| poly.to_symbolic_expr(pool))
274        .collect();
275
276    let mut log = DerivationLog::new();
277    if let (Some(&first), Some(&last)) = (exprs.first(), exprs.last()) {
278        log.push(RewriteStep::simple("SubresultantPRS", first, last));
279    }
280    Ok(DerivedExpr::with_log(exprs, log))
281}
282
283// ---------------------------------------------------------------------------
284// Internal: Brown–Collins subresultant PRS
285// ---------------------------------------------------------------------------
286
287/// Classical subresultant PRS (Brown 1971, Collins 1967).
288///
289/// Requires `deg(p) >= deg(q)`.  Returns the sequence `[P, Q, S₂, …, Sₖ]`.
290fn sprs_inner(p: UniPoly, q: UniPoly) -> Vec<UniPoly> {
291    let var = p.var;
292    let mut sequence = vec![p.clone(), q.clone()];
293
294    if q.is_zero() {
295        return sequence;
296    }
297
298    let m = p.degree();
299    let n = q.degree();
300    if n < 0 {
301        return sequence;
302    }
303
304    // β₁ = (-1)^(m - n + 1)
305    let delta0 = (m - n) as u32;
306    let beta: rug::Integer = if (delta0 + 1) % 2 == 0 {
307        rug::Integer::from(1)
308    } else {
309        rug::Integer::from(-1)
310    };
311
312    let mut beta_cur = beta;
313    let mut psi_cur: rug::Integer = rug::Integer::from(-1);
314
315    let mut a = p;
316    let mut b = q;
317
318    loop {
319        if b.is_zero() {
320            break;
321        }
322
323        let deg_a = a.degree();
324        let deg_b = b.degree();
325        if deg_b < 0 {
326            break;
327        }
328        let delta = (deg_a - deg_b) as u32;
329
330        // Pseudo-remainder: lc(b)^d * a = Q*b + R
331        let (_, r_flint, _d) = a.coeffs.pseudo_divrem(&b.coeffs);
332        if r_flint.is_zero() {
333            break;
334        }
335
336        // S_{i+1} = prem(S_{i-1}, S_i) / β_i  [exact scalar division]
337        let beta_fi = FlintInteger::from_rug(&beta_cur);
338        let c_coeffs = r_flint.scalar_divexact_fmpz(&beta_fi);
339        let c = UniPoly {
340            var,
341            coeffs: c_coeffs,
342        };
343        sequence.push(c.clone());
344
345        // Update ψ: ψ_{i+1} = (-lc(b))^δ / ψ_i^(δ-1)
346        let lc_b_fmpz = b.coeffs.leading_coeff_fmpz();
347        let lc_b = lc_b_fmpz.to_rug();
348        let neg_lc_b: rug::Integer = -lc_b;
349
350        let psi_new = if delta <= 1 {
351            // ψ^0 = 1, so result is just (-lc(b))^δ
352            rug_pow(&neg_lc_b, delta)
353        } else {
354            let num = rug_pow(&neg_lc_b, delta);
355            let den = rug_pow(&psi_cur, delta - 1);
356            rug::Integer::from(num.div_exact_ref(&den))
357        };
358
359        // β_{i+1} = -lc(b) · ψ_{i+1}
360        let beta_new = neg_lc_b * &psi_new;
361
362        a = b;
363        b = c;
364        psi_cur = psi_new;
365        beta_cur = beta_new;
366    }
367
368    sequence
369}
370
371/// Integer exponentiation for [`rug::Integer`] (non-negative exponent).
372fn rug_pow(base: &rug::Integer, exp: u32) -> rug::Integer {
373    if exp == 0 {
374        return rug::Integer::from(1);
375    }
376    let mut r = base.clone();
377    for _ in 1..exp {
378        r *= base;
379    }
380    r
381}
382
383// ---------------------------------------------------------------------------
384// Unit tests
385// ---------------------------------------------------------------------------
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh pool with two real symbols `x` and `y`.
    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        (p, x, y)
    }

    // --- collect_free_vars ---

    #[test]
    fn free_vars_constant() {
        // An integer literal contains no symbols.
        let p = ExprPool::new();
        let five = p.integer(5_i32);
        let vars = collect_free_vars(five, &p);
        assert!(vars.is_empty());
    }

    #[test]
    fn free_vars_symbol() {
        // A bare symbol is its own only free variable.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let vars = collect_free_vars(x, &p);
        assert_eq!(vars, vec![x]);
    }

    #[test]
    fn free_vars_polynomial() {
        let (p, x, y) = pool_xy();
        // x^2 + y - 1
        let xsq = p.pow(x, p.integer(2_i32));
        let expr = p.add(vec![xsq, y, p.integer(-1_i32)]);
        let vars = collect_free_vars(expr, &p);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&x));
        assert!(vars.contains(&y));
    }

    // --- resultant: univariate cases ---

    #[test]
    fn resultant_common_root() {
        // res(x^2 - 5x + 6, x - 2, x) == 0  (both vanish at x=2)
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        // p = x^2 - 5x + 6
        let xsq = p.pow(x, p.integer(2_i32));
        let five_x = p.mul(vec![p.integer(-5_i32), x]);
        let poly_p = p.add(vec![xsq, five_x, p.integer(6_i32)]);
        // q = x - 2
        let poly_q = p.add(vec![x, p.integer(-2_i32)]);

        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        // Result should be the integer 0
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0),
            _ => panic!("expected integer 0, got {:?}", p.get(dr.value)),
        }
        // Derivation log records one step
        assert_eq!(dr.log.len(), 1);
        assert_eq!(dr.log.steps()[0].rule_name, "Resultant");
    }

    #[test]
    fn resultant_coprime() {
        // res(x^2 + 1, x - 1, x) == 2.  x - 1 has the single root 1 while
        // x^2 + 1 has roots ±i, so the two share no root and the resultant
        // is non-zero.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        // x^2 + 1
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        // x - 1
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 2),
            _ => panic!("expected integer 2, got {:?}", p.get(dr.value)),
        }
    }

    #[test]
    fn resultant_linear_linear() {
        // res(x - a, x - b, x) = a - b  (resultant = lc(f)^deg(g) * g(roots of f))
        // Concretely: res(x - 3, x - 7, x) = g(3) = 3 - 7 = -4
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q = p.add(vec![x, p.integer(-7_i32)]);
        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => {
                // res(x-3, x-7) = (3 - 7) = -4; only the magnitude is
                // asserted here, the sign is not checked.
                assert_eq!(
                    n.0.clone().abs(),
                    rug::Integer::from(4),
                    "magnitude should be 4"
                );
            }
            _ => panic!("expected integer, got {:?}", p.get(dr.value)),
        }
    }

    // --- resultant: bivariate (implicitization) ---

    #[test]
    fn resultant_bivariate_eliminates_var() {
        // res(x^2 + y^2 - 1, y - x, y) should equal 2x^2 - 1
        // We verify the degree in x and the constant and leading coefficients.
        let (p, x, y) = pool_xy();

        // x^2 + y^2 - 1
        let xsq = p.pow(x, p.integer(2_i32));
        let ysq = p.pow(y, p.integer(2_i32));
        let circle = p.add(vec![xsq, ysq, p.integer(-1_i32)]);

        // y - x
        let line = p.add(vec![y, p.mul(vec![p.integer(-1_i32), x])]);

        let dr = resultant(circle, line, y, &p).unwrap();
        let res_expr = dr.value;

        // The result should be a polynomial in x of degree 2.
        // Verify by converting to UniPoly in x.
        let res_poly = UniPoly::from_symbolic(res_expr, x, &p).unwrap();
        assert_eq!(res_poly.degree(), 2, "expected degree-2 resultant in x");
        // Coefficients should be [-1, 0, 2] i.e. -1 + 0*x + 2*x^2
        // (the middle coefficient is not asserted).
        let coeffs = res_poly.coefficients_i64();
        assert_eq!(coeffs[0], -1, "constant term should be -1");
        assert_eq!(coeffs[2], 2, "leading coefficient should be 2");
    }

    // --- implicitization: parametric curve (t^2, t^3) ---

    #[test]
    fn resultant_implicitization_twisted_cubic() {
        // Parametrically: x = t^2, y = t^3.
        // Eliminate t: res(x - t^2, y - t^3, t) == y^2 - x^3
        let pool = ExprPool::new();
        let t = pool.symbol("t", Domain::Real);
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);

        // p1 = x - t^2
        let t2 = pool.pow(t, pool.integer(2_i32));
        let p1 = pool.add(vec![x, pool.mul(vec![pool.integer(-1_i32), t2])]);

        // p2 = y - t^3
        let t3 = pool.pow(t, pool.integer(3_i32));
        let p2 = pool.add(vec![y, pool.mul(vec![pool.integer(-1_i32), t3])]);

        let dr = resultant(p1, p2, t, &pool).unwrap();
        let res_expr = dr.value;

        // The result should be y^2 - x^3 (or a scalar multiple).
        // Verify by evaluating at (x=4, y=8): 64 - 64 = 0 (point on the curve).
        // And at (x=1, y=2): 4 - 1 = 3 ≠ 0 (not on the curve).
        use crate::kernel::subs;
        use std::collections::HashMap;
        let one = pool.integer(1_i32);
        let two = pool.integer(2_i32);
        let four = pool.integer(4_i32);
        let eight = pool.integer(8_i32);

        // Substitute (x=4, y=8) → should give 0
        let mut map_on = HashMap::new();
        map_on.insert(x, four);
        map_on.insert(y, eight);
        let at_4_8 = subs(res_expr, &map_on, &pool);
        let simplified_0 = crate::simplify::simplify(at_4_8, &pool);
        match pool.get(simplified_0.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0, "res at (4,8) should be 0"),
            _ => {
                panic!(
                    "expected integer 0 at (4,8), got {:?}",
                    pool.get(simplified_0.value)
                )
            }
        }

        // Substitute (x=1, y=2) → should give nonzero
        let mut map_off = HashMap::new();
        map_off.insert(x, one);
        map_off.insert(y, two);
        let at_1_2 = subs(res_expr, &map_off, &pool);
        let simplified_nz = crate::simplify::simplify(at_1_2, &pool);
        if let ExprData::Integer(n) = pool.get(simplified_nz.value) {
            assert_ne!(n.0, 0, "res at (1,2) should be non-zero");
        } // NOTE(review): a non-integer simplified form passes without any check
    }

    // --- subresultant_prs ---

    #[test]
    fn sprs_sequence_length() {
        // For coprime polynomials, PRS terminates at degree 0.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        // x^2 + 1  (irreducible over ℤ)
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        // x - 1
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);

        let dr = subresultant_prs(poly_p, poly_q, x, &p).unwrap();
        // Sequence starts with [p, q, ...]; for this coprime pair it should
        // end with a non-zero constant.
        let seq = &dr.value;
        assert!(seq.len() >= 2, "sequence must have at least [p, q]");
        // First element is p or q (may have been swapped by degree).
        // Last element should be a constant (degree 0) for coprime polynomials.
        let last_id = *seq.last().unwrap();
        match p.get(last_id) {
            ExprData::Integer(_) => {} // scalar: good
            _ => {
                // Try parsing as UniPoly and check degree.
                let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
                assert_eq!(last_poly.degree(), 0, "last PRS element should be degree 0");
            }
        }
    }

    #[test]
    fn sprs_first_elements() {
        // The first two elements of the PRS are p and q (possibly swapped).
        // NOTE(review): only the sequence length is asserted below; the
        // identity of the first two elements is not actually checked.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, p.integer(2_i32));
        // p = x^2 - 1
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        // q = 2x - 2  (to test: gcd = x - 1)
        let two_x = p.mul(vec![two, x]);
        let poly_q_expr = p.add(vec![two_x, p.integer(-2_i32)]);

        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        assert!(dr.value.len() >= 2);
    }

    #[test]
    fn sprs_gcd_from_sequence() {
        // The last non-zero element of the PRS (up to content) is the GCD.
        // gcd(x^2 - 1, x - 1) = x - 1
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-1_i32)]);

        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let seq = &dr.value;
        assert!(seq.len() >= 2);
        // Convert the last element to UniPoly.
        let last_id = *seq.last().unwrap();
        let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
        // Should have degree 1 (matching gcd x - 1 up to scalar).
        assert_eq!(
            last_poly.degree(),
            1,
            "last PRS element should be degree-1 (matching GCD)"
        );
    }

    #[test]
    fn sprs_sylvester_consistency() {
        // The resultant is the last constant element of the subresultant PRS.
        // For x - 3 and x - 7, |res| = 4; the comparison below is up to sign.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p_expr = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-7_i32)]);

        let dr_prs = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let dr_res = resultant(poly_p_expr, poly_q_expr, x, &p).unwrap();

        // The resultant should match the constant at the end of the PRS.
        let last = *dr_prs.value.last().unwrap();
        match p.get(last) {
            ExprData::Integer(n) => {
                let res_n = match p.get(dr_res.value) {
                    ExprData::Integer(m) => m.0.clone(),
                    _ => panic!("resultant not integer"),
                };
                // They should match up to sign.
                assert_eq!(n.0.clone().abs(), res_n.abs());
            }
            _ => {
                // Degree-0 polynomial stored as a mul/add — tolerate this form.
                // NOTE(review): this arm makes the test pass without
                // comparing anything.
            }
        }
    }

    // --- error cases ---

    #[test]
    fn resultant_non_polynomial_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        // sin(x) is not a polynomial
        let sin_x = p.func("sin", vec![x]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = resultant(sin_x, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error"
        );
    }

    #[test]
    fn subresultant_prs_non_polynomial_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        // y appears as a free variable — not polynomial in x alone
        let poly_p = p.add(vec![x, y]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = subresultant_prs(poly_p, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error for multivariate input to subresultant_prs"
        );
    }
}