use crate::deriv::{DerivationLog, DerivedExpr, RewriteStep};
use crate::flint::integer::FlintInteger;
use crate::flint::mpoly::FlintMPolyCtx;
use crate::kernel::{ExprData, ExprId, ExprPool};
use crate::poly::error::ConversionError;
use crate::poly::multipoly::multi_to_flint_pub;
use crate::poly::multipoly::MultiPoly;
use crate::poly::unipoly::UniPoly;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
/// Errors produced by [`resultant`] and [`subresultant_prs`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResultantError {
    /// An input could not be converted into the polynomial representation
    /// required by the computation; wraps the underlying conversion error
    /// (diagnostic code `E-RES-001`).
    NotAPolynomial(ConversionError),
    /// FLINT's multivariate resultant routine reported failure
    /// (diagnostic code `E-RES-003`).
    FlintError,
}
impl From<ConversionError> for ResultantError {
fn from(e: ConversionError) -> Self {
ResultantError::NotAPolynomial(e)
}
}
impl fmt::Display for ResultantError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ResultantError::NotAPolynomial(e) => write!(f, "not a polynomial: {e}"),
ResultantError::FlintError => {
write!(f, "FLINT resultant computation failed (E-RES-003)")
}
}
}
}
impl std::error::Error for ResultantError {}
impl crate::errors::AlkahestError for ResultantError {
    /// Stable diagnostic code for each variant.
    fn code(&self) -> &'static str {
        match *self {
            Self::NotAPolynomial(..) => "E-RES-001",
            Self::FlintError => "E-RES-003",
        }
    }

    /// Suggested fix; only conversion failures carry a hint.
    fn remediation(&self) -> Option<&'static str> {
        const HINT: &str = "ensure both arguments are polynomial expressions with integer \
                            coefficients in the given variable";
        match *self {
            Self::NotAPolynomial(..) => Some(HINT),
            Self::FlintError => None,
        }
    }
}
/// Returns every symbol reachable from `expr`, de-duplicated and sorted by
/// `ExprId` order.
///
/// The traversal visits every child slot, including the variable slot of
/// `Forall`/`Exists`, so quantifier-bound symbols are reported too.
pub(crate) fn collect_free_vars(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
    let mut found: BTreeSet<ExprId> = BTreeSet::new();
    collect_vars_rec(expr, pool, &mut found);
    found.iter().copied().collect()
}
/// Inserts every `Symbol` node reachable from `expr` into `out`.
///
/// Despite its historical name, this walk is iterative. An explicit work
/// stack means deeply nested expressions cannot overflow the call stack. A
/// visited set means a subterm shared by several parents is expanded only
/// once, instead of once per path. Because `out` is a set, neither change
/// affects the result.
///
/// Every child slot is traversed, including the variable slot of
/// `Forall`/`Exists`, so quantifier-bound symbols are collected as well.
fn collect_vars_rec(expr: ExprId, pool: &ExprPool, out: &mut BTreeSet<ExprId>) {
    let mut visited: BTreeSet<ExprId> = BTreeSet::new();
    let mut pending: Vec<ExprId> = vec![expr];
    while let Some(node) = pending.pop() {
        // Skip nodes already expanded via another parent.
        if !visited.insert(node) {
            continue;
        }
        // Children are queued rather than visited here, so the pool is never
        // re-entered from inside its own `with` callback.
        pool.with(node, |data| match data {
            ExprData::Symbol { .. } => {
                out.insert(node);
            }
            ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_) => {}
            ExprData::Add(args) | ExprData::Mul(args) => pending.extend(args.iter().copied()),
            ExprData::Pow { base, exp } => pending.extend([*base, *exp]),
            ExprData::Func { args, .. } => pending.extend(args.iter().copied()),
            ExprData::Piecewise { branches, default } => {
                pending.extend(branches.iter().flat_map(|(c, v)| [*c, *v]));
                pending.push(*default);
            }
            ExprData::Predicate { args, .. } => pending.extend(args.iter().copied()),
            ExprData::Forall { var, body } | ExprData::Exists { var, body } => {
                pending.extend([*var, *body]);
            }
            ExprData::BigO(arg) => pending.push(*arg),
        });
    }
}
/// Resultant of `p` and `q` with respect to `var`.
///
/// The computation proceeds in three steps:
///
/// 1. Both inputs are converted with `MultiPoly::from_symbolic` over a shared
///    variable list: the sorted union of their free symbols plus `var`.
/// 2. FLINT computes the resultant in the `var` slot.
/// 3. That slot is dropped from the variable list, leaving a polynomial in the
///    remaining symbols (e.g. a plain integer when `var` was the only symbol).
///
/// The returned [`DerivedExpr`] records one `"Resultant"` step from `p` to the
/// result.
///
/// # Errors
///
/// * [`ResultantError::NotAPolynomial`] if either input fails
///   `MultiPoly::from_symbolic` (`p` is tried first).
/// * [`ResultantError::FlintError`] if FLINT's resultant returns no value.
pub fn resultant(
    p: ExprId,
    q: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<DerivedExpr<ExprId>, ResultantError> {
    // Sorted, de-duplicated variable ordering shared by both operands.
    let mut symbols: BTreeSet<ExprId> = collect_free_vars(p, pool).into_iter().collect();
    symbols.extend(collect_free_vars(q, pool));
    symbols.insert(var);
    let vars: Vec<ExprId> = symbols.into_iter().collect();
    let slots = vars.len().max(1);
    // `var` was inserted above, so it is always present.
    let elim = vars.iter().position(|&v| v == var).unwrap();

    let lhs = MultiPoly::from_symbolic(p, vars.clone(), pool)?;
    let rhs = MultiPoly::from_symbolic(q, vars.clone(), pool)?;

    let ctx = FlintMPolyCtx::new(slots);
    let flint_lhs = multi_to_flint_pub(&lhs, &ctx);
    let flint_rhs = multi_to_flint_pub(&rhs, &ctx);
    let flint_res = flint_lhs
        .resultant(&flint_rhs, elim, &ctx)
        .ok_or(ResultantError::FlintError)?;

    // Re-key every term without the eliminated exponent slot, merging any
    // terms whose keys coincide afterwards.
    let mut terms: BTreeMap<Vec<u32>, rug::Integer> = BTreeMap::new();
    for (exponents, coeff) in flint_res.terms(slots, &ctx) {
        let mut key: Vec<u32> = exponents
            .into_iter()
            .enumerate()
            .filter(|&(slot, _)| slot != elim)
            .map(|(_, e)| e)
            .collect();
        // Trailing zero exponents are stripped; presumably the canonical key
        // form expected by `MultiPoly` — confirm.
        while key.last().copied() == Some(0) {
            key.pop();
        }
        *terms.entry(key).or_insert_with(|| rug::Integer::from(0)) += &coeff;
    }
    terms.retain(|_, c| *c != 0);

    let mut remaining = vars;
    remaining.remove(elim);

    let eliminated = MultiPoly {
        vars: remaining,
        terms,
    }
    .to_expr(pool);
    let step = RewriteStep::simple("Resultant", p, eliminated);
    Ok(DerivedExpr::with_step(eliminated, step))
}
/// Subresultant polynomial remainder sequence of `p` and `q` in `var`.
///
/// Both inputs are converted with `UniPoly::from_symbolic`. The one of larger
/// degree leads the sequence; on a tie, `p` stays first. The remaining
/// elements are the scaled pseudo-remainders built by `sprs_inner`, each
/// converted back to an expression.
///
/// The log holds a single `"SubresultantPRS"` step from the first to the last
/// element of the sequence.
///
/// # Errors
///
/// [`ResultantError::NotAPolynomial`] if either input is rejected by
/// `UniPoly::from_symbolic` (for example, one containing another free symbol).
pub fn subresultant_prs(
    p: ExprId,
    q: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<DerivedExpr<Vec<ExprId>>, ResultantError> {
    let lhs = UniPoly::from_symbolic(p, var, pool)?;
    let rhs = UniPoly::from_symbolic(q, var, pool)?;
    // The PRS construction expects the higher-degree polynomial first.
    let (major, minor) = if lhs.degree() >= rhs.degree() {
        (lhs, rhs)
    } else {
        (rhs, lhs)
    };
    let chain: Vec<ExprId> = sprs_inner(major, minor)
        .into_iter()
        .map(|poly| poly.to_symbolic_expr(pool))
        .collect();

    let mut log = DerivationLog::new();
    if let Some((&head, &tail)) = chain.first().zip(chain.last()) {
        log.push(RewriteStep::simple("SubresultantPRS", head, tail));
    }
    Ok(DerivedExpr::with_log(chain, log))
}
/// Builds the subresultant polynomial remainder sequence of `p` and `q`.
///
/// Callers must ensure `deg(p) >= deg(q)` ([`subresultant_prs`] swaps its
/// inputs to guarantee this); otherwise the `as u32` degree differences below
/// would wrap. The result always starts with `p` and `q`. Each further element
/// is `F_{i+1} = prem(F_{i-1}, F_i) / β_i`. The sequence stops when `q` is
/// zero or a pseudo-remainder vanishes.
///
/// With `δ_i = deg F_{i-1} - deg F_i` and `f_i = lc(F_i)`, the scaling factors
/// follow the Brown–Collins subresultant recurrence:
///
/// * `β_1 = (-1)^(δ_1 + 1)`, `ψ_1 = -1`
/// * `ψ_{i+1} = (-f_i)^δ_i / ψ_i^(δ_i - 1)`, which is `ψ_i` itself when `δ_i = 0`
/// * `β_{i+1} = -f_i · ψ_{i+1}^δ_{i+1}`
fn sprs_inner(p: UniPoly, q: UniPoly) -> Vec<UniPoly> {
    let var = p.var;
    let mut sequence = vec![p.clone(), q.clone()];
    if q.is_zero() {
        return sequence;
    }
    let m = p.degree();
    let n = q.degree();
    if n < 0 {
        return sequence;
    }
    let delta0 = (m - n) as u32;
    // β_1 = (-1)^(δ_1 + 1)
    let mut beta_cur: rug::Integer = if (delta0 + 1) % 2 == 0 {
        rug::Integer::from(1)
    } else {
        rug::Integer::from(-1)
    };
    // ψ_1 = -1
    let mut psi_cur: rug::Integer = rug::Integer::from(-1);
    let mut a = p;
    let mut b = q;
    loop {
        if b.is_zero() {
            break;
        }
        let deg_a = a.degree();
        let deg_b = b.degree();
        if deg_b < 0 {
            break;
        }
        // δ_i for the current pair (a, b) = (F_{i-1}, F_i).
        let delta = (deg_a - deg_b) as u32;
        // NOTE(review): the third value (the exponent actually applied to
        // lc(b)) is ignored. If this wraps FLINT's `fmpz_poly_pseudo_divrem`,
        // that exponent can be smaller than `delta + 1`, because FLINT skips
        // the scaling when a step divides exactly. `r_flint` would then differ
        // from the textbook prem by lc(b)^(delta + 1 - d). Confirm the
        // wrapper's contract and scale `r_flint` if necessary.
        let (_, r_flint, _d) = a.coeffs.pseudo_divrem(&b.coeffs);
        if r_flint.is_zero() {
            break;
        }
        // F_{i+1} = prem(F_{i-1}, F_i) / β_i (exact for a full pseudo-remainder).
        let beta_fi = FlintInteger::from_rug(&beta_cur);
        let c_coeffs = r_flint.scalar_divexact_fmpz(&beta_fi);
        let c = UniPoly {
            var,
            coeffs: c_coeffs,
        };
        sequence.push(c.clone());
        let lc_b_fmpz = b.coeffs.leading_coeff_fmpz();
        let lc_b = lc_b_fmpz.to_rug();
        let neg_lc_b: rug::Integer = -lc_b;
        // ψ_{i+1} = (-f_i)^δ_i / ψ_i^(δ_i - 1)
        let psi_new = match delta {
            // Only reachable on the first step (equal input degrees). The
            // exponent δ_i - 1 is -1, so the quotient is ψ_i itself, not 1.
            0 => psi_cur.clone(),
            1 => neg_lc_b.clone(),
            _ => {
                let num = rug_pow(&neg_lc_b, delta);
                let den = rug_pow(&psi_cur, delta - 1);
                rug::Integer::from(num.div_exact_ref(&den))
            }
        };
        // β_{i+1} = -f_i · ψ_{i+1}^δ_{i+1}. Here δ_{i+1} = deg F_i - deg F_{i+1}
        // is the degree drop of the *next* division, which exceeds 1 in an
        // abnormal PRS.
        let next_delta = (deg_b - c.degree()) as u32;
        let beta_new = neg_lc_b * rug_pow(&psi_new, next_delta);
        a = b;
        b = c;
        psi_cur = psi_new;
        beta_cur = beta_new;
    }
    sequence
}
/// Returns `base^exp` as an exact integer; `exp == 0` yields 1 for any base.
fn rug_pow(base: &rug::Integer, exp: u32) -> rug::Integer {
    (0..exp).fold(rug::Integer::from(1), |acc, _| acc * base)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        (p, x, y)
    }

    #[test]
    fn free_vars_constant() {
        let p = ExprPool::new();
        let five = p.integer(5_i32);
        let vars = collect_free_vars(five, &p);
        assert!(vars.is_empty());
    }

    #[test]
    fn free_vars_symbol() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let vars = collect_free_vars(x, &p);
        assert_eq!(vars, vec![x]);
    }

    #[test]
    fn free_vars_polynomial() {
        let (p, x, y) = pool_xy();
        let xsq = p.pow(x, p.integer(2_i32));
        let expr = p.add(vec![xsq, y, p.integer(-1_i32)]);
        let vars = collect_free_vars(expr, &p);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&x));
        assert!(vars.contains(&y));
    }

    #[test]
    fn error_display_and_codes() {
        use crate::errors::AlkahestError;
        let err = ResultantError::FlintError;
        assert_eq!(
            err.to_string(),
            "FLINT resultant computation failed (E-RES-003)"
        );
        assert_eq!(err.code(), "E-RES-003");
        assert_eq!(err.remediation(), None);
    }

    #[test]
    fn resultant_common_root() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let five_x = p.mul(vec![p.integer(-5_i32), x]);
        let poly_p = p.add(vec![xsq, five_x, p.integer(6_i32)]);
        let poly_q = p.add(vec![x, p.integer(-2_i32)]);
        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0),
            _ => panic!("expected integer 0, got {:?}", p.get(dr.value)),
        }
        assert_eq!(dr.log.len(), 1);
        assert_eq!(dr.log.steps()[0].rule_name, "Resultant");
    }

    #[test]
    fn resultant_coprime() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 2),
            _ => panic!("expected integer 2, got {:?}", p.get(dr.value)),
        }
    }

    #[test]
    fn resultant_linear_linear() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q = p.add(vec![x, p.integer(-7_i32)]);
        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => {
                assert_eq!(
                    n.0.clone().abs(),
                    rug::Integer::from(4),
                    "magnitude should be 4"
                );
            }
            _ => panic!("expected integer, got {:?}", p.get(dr.value)),
        }
    }

    #[test]
    fn resultant_bivariate_eliminates_var() {
        let (p, x, y) = pool_xy();
        let xsq = p.pow(x, p.integer(2_i32));
        let ysq = p.pow(y, p.integer(2_i32));
        let circle = p.add(vec![xsq, ysq, p.integer(-1_i32)]);
        let line = p.add(vec![y, p.mul(vec![p.integer(-1_i32), x])]);
        let dr = resultant(circle, line, y, &p).unwrap();
        let res_expr = dr.value;
        let res_poly = UniPoly::from_symbolic(res_expr, x, &p).unwrap();
        assert_eq!(res_poly.degree(), 2, "expected degree-2 resultant in x");
        let coeffs = res_poly.coefficients_i64();
        assert_eq!(coeffs[0], -1, "constant term should be -1");
        assert_eq!(coeffs[2], 2, "leading coefficient should be 2");
    }

    #[test]
    fn resultant_implicitization_twisted_cubic() {
        let pool = ExprPool::new();
        let t = pool.symbol("t", Domain::Real);
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let t2 = pool.pow(t, pool.integer(2_i32));
        let p1 = pool.add(vec![x, pool.mul(vec![pool.integer(-1_i32), t2])]);
        let t3 = pool.pow(t, pool.integer(3_i32));
        let p2 = pool.add(vec![y, pool.mul(vec![pool.integer(-1_i32), t3])]);
        let dr = resultant(p1, p2, t, &pool).unwrap();
        let res_expr = dr.value;
        use crate::kernel::subs;
        use std::collections::HashMap;
        let one = pool.integer(1_i32);
        let two = pool.integer(2_i32);
        let four = pool.integer(4_i32);
        let eight = pool.integer(8_i32);
        let mut map_on = HashMap::new();
        map_on.insert(x, four);
        map_on.insert(y, eight);
        let at_4_8 = subs(res_expr, &map_on, &pool);
        let simplified_0 = crate::simplify::simplify(at_4_8, &pool);
        match pool.get(simplified_0.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0, "res at (4,8) should be 0"),
            _ => {
                panic!(
                    "expected integer 0 at (4,8), got {:?}",
                    pool.get(simplified_0.value)
                )
            }
        }
        let mut map_off = HashMap::new();
        map_off.insert(x, one);
        map_off.insert(y, two);
        let at_1_2 = subs(res_expr, &map_off, &pool);
        let simplified_nz = crate::simplify::simplify(at_1_2, &pool);
        // A non-integer result must fail the test rather than be skipped.
        match pool.get(simplified_nz.value) {
            ExprData::Integer(n) => assert_ne!(n.0, 0, "res at (1,2) should be non-zero"),
            _ => {
                panic!(
                    "expected integer at (1,2), got {:?}",
                    pool.get(simplified_nz.value)
                )
            }
        }
    }

    #[test]
    fn sprs_sequence_length() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let dr = subresultant_prs(poly_p, poly_q, x, &p).unwrap();
        let seq = &dr.value;
        assert!(seq.len() >= 2, "sequence must have at least [p, q]");
        let last_id = *seq.last().unwrap();
        match p.get(last_id) {
            ExprData::Integer(_) => {}
            _ => {
                let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
                assert_eq!(last_poly.degree(), 0, "last PRS element should be degree 0");
            }
        }
    }

    #[test]
    fn sprs_first_elements() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        let two_x = p.mul(vec![two, x]);
        let poly_q_expr = p.add(vec![two_x, p.integer(-2_i32)]);
        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        assert!(dr.value.len() >= 2);
    }

    #[test]
    fn sprs_gcd_from_sequence() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-1_i32)]);
        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let seq = &dr.value;
        assert!(seq.len() >= 2);
        let last_id = *seq.last().unwrap();
        let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
        assert_eq!(
            last_poly.degree(),
            1,
            "last PRS element should be degree-1 (matching GCD)"
        );
    }

    #[test]
    fn sprs_sylvester_consistency() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p_expr = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-7_i32)]);
        let dr_prs = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let dr_res = resultant(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let res_n = match p.get(dr_res.value) {
            ExprData::Integer(m) => m.0.clone(),
            _ => panic!("resultant not integer"),
        };
        let last = *dr_prs.value.last().unwrap();
        // The final PRS element is a constant. If it does not come back as an
        // `Integer` node, compare through its polynomial view instead of
        // silently passing.
        let last_n = match p.get(last) {
            ExprData::Integer(n) => n.0.clone(),
            _ => {
                let last_poly = UniPoly::from_symbolic(last, x, &p).unwrap();
                assert_eq!(last_poly.degree(), 0, "last PRS element should be constant");
                rug::Integer::from(last_poly.coefficients_i64()[0])
            }
        };
        assert_eq!(last_n.abs(), res_n.abs());
    }

    #[test]
    fn resultant_non_polynomial_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let sin_x = p.func("sin", vec![x]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = resultant(sin_x, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error"
        );
    }

    #[test]
    fn subresultant_prs_non_polynomial_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let poly_p = p.add(vec![x, y]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = subresultant_prs(poly_p, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error for multivariate input to subresultant_prs"
        );
    }
}