1use crate::deriv::{DerivationLog, DerivedExpr, RewriteStep};
18use crate::flint::integer::FlintInteger;
19use crate::flint::mpoly::FlintMPolyCtx;
20use crate::kernel::{ExprData, ExprId, ExprPool};
21use crate::poly::error::ConversionError;
22use crate::poly::multipoly::multi_to_flint_pub;
23use crate::poly::multipoly::MultiPoly;
24use crate::poly::unipoly::UniPoly;
25use std::collections::{BTreeMap, BTreeSet};
26use std::fmt;
27
28#[derive(Debug, Clone, PartialEq, Eq)]
34pub enum ResultantError {
35 NotAPolynomial(ConversionError),
38 FlintError,
40}
41
42impl From<ConversionError> for ResultantError {
43 fn from(e: ConversionError) -> Self {
44 ResultantError::NotAPolynomial(e)
45 }
46}
47
48impl fmt::Display for ResultantError {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 ResultantError::NotAPolynomial(e) => write!(f, "not a polynomial: {e}"),
52 ResultantError::FlintError => {
53 write!(f, "FLINT resultant computation failed (E-RES-003)")
54 }
55 }
56 }
57}
58
59impl std::error::Error for ResultantError {}
60
61impl crate::errors::AlkahestError for ResultantError {
62 fn code(&self) -> &'static str {
63 match self {
64 ResultantError::NotAPolynomial(_) => "E-RES-001",
65 ResultantError::FlintError => "E-RES-003",
66 }
67 }
68
69 fn remediation(&self) -> Option<&'static str> {
70 match self {
71 ResultantError::NotAPolynomial(_) => Some(
72 "ensure both arguments are polynomial expressions with integer \
73 coefficients in the given variable",
74 ),
75 ResultantError::FlintError => None,
76 }
77 }
78}
79
80pub(crate) fn collect_free_vars(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
88 let mut set = BTreeSet::new();
89 collect_vars_rec(expr, pool, &mut set);
90 set.into_iter().collect()
91}
92
93fn collect_vars_rec(expr: ExprId, pool: &ExprPool, out: &mut BTreeSet<ExprId>) {
94 let children: Vec<ExprId> = pool.with(expr, |data| match data {
96 ExprData::Symbol { .. } => {
97 out.insert(expr);
98 vec![]
99 }
100 ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_) => vec![],
101 ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
102 ExprData::Pow { base, exp } => vec![*base, *exp],
103 ExprData::Func { args, .. } => args.clone(),
104 ExprData::Piecewise { branches, default } => {
105 let mut ids: Vec<ExprId> = branches.iter().flat_map(|(c, v)| [*c, *v]).collect();
106 ids.push(*default);
107 ids
108 }
109 ExprData::Predicate { args, .. } => args.clone(),
110 ExprData::Forall { var, body } | ExprData::Exists { var, body } => vec![*var, *body],
111 ExprData::BigO(arg) => vec![*arg],
112 });
113 for child in children {
114 collect_vars_rec(child, pool, out);
115 }
116}
117
118pub fn resultant(
154 p: ExprId,
155 q: ExprId,
156 var: ExprId,
157 pool: &ExprPool,
158) -> Result<DerivedExpr<ExprId>, ResultantError> {
159 let mut all: BTreeSet<ExprId> = BTreeSet::new();
161 for v in collect_free_vars(p, pool) {
162 all.insert(v);
163 }
164 for v in collect_free_vars(q, pool) {
165 all.insert(v);
166 }
167 all.insert(var);
168
169 let vars: Vec<ExprId> = all.into_iter().collect();
170 let nvars = vars.len();
171 let var_idx = vars.iter().position(|&v| v == var).unwrap();
172
173 let mp = MultiPoly::from_symbolic(p, vars.clone(), pool)?;
175 let mq = MultiPoly::from_symbolic(q, vars.clone(), pool)?;
176
177 let ctx = FlintMPolyCtx::new(nvars.max(1));
179 let fp = multi_to_flint_pub(&mp, &ctx);
180 let fq = multi_to_flint_pub(&mq, &ctx);
181
182 let fr = fp
184 .resultant(&fq, var_idx, &ctx)
185 .ok_or(ResultantError::FlintError)?;
186
187 let res_raw = fr.terms(nvars.max(1), &ctx);
189
190 let remaining_vars: Vec<ExprId> = vars
193 .iter()
194 .enumerate()
195 .filter_map(|(i, &v)| if i == var_idx { None } else { Some(v) })
196 .collect();
197
198 let mut new_terms: BTreeMap<Vec<u32>, rug::Integer> = BTreeMap::new();
199 for (exp, coeff) in res_raw {
200 let mut new_exp: Vec<u32> = exp
201 .into_iter()
202 .enumerate()
203 .filter_map(|(i, e)| if i == var_idx { None } else { Some(e) })
204 .collect();
205 while new_exp.last() == Some(&0) {
206 new_exp.pop();
207 }
208 let entry = new_terms
209 .entry(new_exp)
210 .or_insert_with(|| rug::Integer::from(0));
211 *entry += &coeff;
212 }
213 new_terms.retain(|_, v| *v != 0);
214
215 let result_mp = MultiPoly {
216 vars: remaining_vars,
217 terms: new_terms,
218 };
219 let result_expr = result_mp.to_expr(pool);
220
221 let step = RewriteStep::simple("Resultant", p, result_expr);
222 Ok(DerivedExpr::with_step(result_expr, step))
223}
224
225pub fn subresultant_prs(
254 p: ExprId,
255 q: ExprId,
256 var: ExprId,
257 pool: &ExprPool,
258) -> Result<DerivedExpr<Vec<ExprId>>, ResultantError> {
259 let mut up = UniPoly::from_symbolic(p, var, pool)?;
261 let mut uq = UniPoly::from_symbolic(q, var, pool)?;
262
263 if up.degree() < uq.degree() {
265 std::mem::swap(&mut up, &mut uq);
266 }
267
268 let prs_polys = sprs_inner(up, uq);
269
270 let exprs: Vec<ExprId> = prs_polys
272 .into_iter()
273 .map(|poly| poly.to_symbolic_expr(pool))
274 .collect();
275
276 let mut log = DerivationLog::new();
277 if let (Some(&first), Some(&last)) = (exprs.first(), exprs.last()) {
278 log.push(RewriteStep::simple("SubresultantPRS", first, last));
279 }
280 Ok(DerivedExpr::with_log(exprs, log))
281}
282
283fn sprs_inner(p: UniPoly, q: UniPoly) -> Vec<UniPoly> {
291 let var = p.var;
292 let mut sequence = vec![p.clone(), q.clone()];
293
294 if q.is_zero() {
295 return sequence;
296 }
297
298 let m = p.degree();
299 let n = q.degree();
300 if n < 0 {
301 return sequence;
302 }
303
304 let delta0 = (m - n) as u32;
306 let beta: rug::Integer = if (delta0 + 1) % 2 == 0 {
307 rug::Integer::from(1)
308 } else {
309 rug::Integer::from(-1)
310 };
311
312 let mut beta_cur = beta;
313 let mut psi_cur: rug::Integer = rug::Integer::from(-1);
314
315 let mut a = p;
316 let mut b = q;
317
318 loop {
319 if b.is_zero() {
320 break;
321 }
322
323 let deg_a = a.degree();
324 let deg_b = b.degree();
325 if deg_b < 0 {
326 break;
327 }
328 let delta = (deg_a - deg_b) as u32;
329
330 let (_, r_flint, _d) = a.coeffs.pseudo_divrem(&b.coeffs);
332 if r_flint.is_zero() {
333 break;
334 }
335
336 let beta_fi = FlintInteger::from_rug(&beta_cur);
338 let c_coeffs = r_flint.scalar_divexact_fmpz(&beta_fi);
339 let c = UniPoly {
340 var,
341 coeffs: c_coeffs,
342 };
343 sequence.push(c.clone());
344
345 let lc_b_fmpz = b.coeffs.leading_coeff_fmpz();
347 let lc_b = lc_b_fmpz.to_rug();
348 let neg_lc_b: rug::Integer = -lc_b;
349
350 let psi_new = if delta <= 1 {
351 rug_pow(&neg_lc_b, delta)
353 } else {
354 let num = rug_pow(&neg_lc_b, delta);
355 let den = rug_pow(&psi_cur, delta - 1);
356 rug::Integer::from(num.div_exact_ref(&den))
357 };
358
359 let beta_new = neg_lc_b * &psi_new;
361
362 a = b;
363 b = c;
364 psi_cur = psi_new;
365 beta_cur = beta_new;
366 }
367
368 sequence
369}
370
371fn rug_pow(base: &rug::Integer, exp: u32) -> rug::Integer {
373 if exp == 0 {
374 return rug::Integer::from(1);
375 }
376 let mut r = base.clone();
377 for _ in 1..exp {
378 r *= base;
379 }
380 r
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh pool containing the real symbols `x` and `y`.
    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        (p, x, y)
    }

    // ---- collect_free_vars ----

    #[test]
    fn free_vars_constant() {
        let p = ExprPool::new();
        let five = p.integer(5_i32);
        let vars = collect_free_vars(five, &p);
        assert!(vars.is_empty());
    }

    #[test]
    fn free_vars_symbol() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let vars = collect_free_vars(x, &p);
        assert_eq!(vars, vec![x]);
    }

    #[test]
    fn free_vars_polynomial() {
        // x^2 + y - 1 mentions exactly x and y.
        let (p, x, y) = pool_xy();
        let xsq = p.pow(x, p.integer(2_i32));
        let expr = p.add(vec![xsq, y, p.integer(-1_i32)]);
        let vars = collect_free_vars(expr, &p);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&x));
        assert!(vars.contains(&y));
    }

    // ---- resultant ----

    #[test]
    fn resultant_common_root() {
        // x^2 - 5x + 6 = (x - 2)(x - 3) shares the root 2 with x - 2, so the
        // resultant vanishes.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let five_x = p.mul(vec![p.integer(-5_i32), x]);
        let poly_p = p.add(vec![xsq, five_x, p.integer(6_i32)]);
        let poly_q = p.add(vec![x, p.integer(-2_i32)]);

        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0),
            _ => panic!("expected integer 0, got {:?}", p.get(dr.value)),
        }
        assert_eq!(dr.log.len(), 1);
        assert_eq!(dr.log.steps()[0].rule_name, "Resultant");
    }

    #[test]
    fn resultant_coprime() {
        // res(x^2 + 1, x - 1) = (x^2 + 1) evaluated at x = 1 = 2.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);

        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 2),
            _ => panic!("expected integer 2, got {:?}", p.get(dr.value)),
        }
    }

    #[test]
    fn resultant_linear_linear() {
        // |res(x - 3, x - 7)| = |3 - 7| = 4; only the magnitude is pinned.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q = p.add(vec![x, p.integer(-7_i32)]);

        let dr = resultant(poly_p, poly_q, x, &p).unwrap();
        match p.get(dr.value) {
            ExprData::Integer(n) => {
                assert_eq!(
                    n.0.clone().abs(),
                    rug::Integer::from(4),
                    "magnitude should be 4"
                );
            }
            _ => panic!("expected integer, got {:?}", p.get(dr.value)),
        }
    }

    #[test]
    fn resultant_bivariate_eliminates_var() {
        // Eliminating y from x^2 + y^2 - 1 and y - x leaves 2x^2 - 1.
        let (p, x, y) = pool_xy();

        let xsq = p.pow(x, p.integer(2_i32));
        let ysq = p.pow(y, p.integer(2_i32));
        let circle = p.add(vec![xsq, ysq, p.integer(-1_i32)]);

        let line = p.add(vec![y, p.mul(vec![p.integer(-1_i32), x])]);

        let dr = resultant(circle, line, y, &p).unwrap();
        let res_expr = dr.value;

        let res_poly = UniPoly::from_symbolic(res_expr, x, &p).unwrap();
        assert_eq!(res_poly.degree(), 2, "expected degree-2 resultant in x");
        let coeffs = res_poly.coefficients_i64();
        assert_eq!(coeffs[0], -1, "constant term should be -1");
        assert_eq!(coeffs[2], 2, "leading coefficient should be 2");
    }

    #[test]
    fn resultant_implicitization_twisted_cubic() {
        // Eliminating t from x - t^2 and y - t^3 gives (up to sign) y^2 - x^3.
        // It vanishes at (4, 8) = (t^2, t^3) for t = 2, but not at (1, 2).
        use crate::kernel::subs;
        use std::collections::HashMap;

        let pool = ExprPool::new();
        let t = pool.symbol("t", Domain::Real);
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);

        let t2 = pool.pow(t, pool.integer(2_i32));
        let p1 = pool.add(vec![x, pool.mul(vec![pool.integer(-1_i32), t2])]);

        let t3 = pool.pow(t, pool.integer(3_i32));
        let p2 = pool.add(vec![y, pool.mul(vec![pool.integer(-1_i32), t3])]);

        let dr = resultant(p1, p2, t, &pool).unwrap();
        let res_expr = dr.value;

        let one = pool.integer(1_i32);
        let two = pool.integer(2_i32);
        let four = pool.integer(4_i32);
        let eight = pool.integer(8_i32);

        let mut map_on = HashMap::new();
        map_on.insert(x, four);
        map_on.insert(y, eight);
        let at_4_8 = subs(res_expr, &map_on, &pool);
        let simplified_0 = crate::simplify::simplify(at_4_8, &pool);
        match pool.get(simplified_0.value) {
            ExprData::Integer(n) => assert_eq!(n.0, 0, "res at (4,8) should be 0"),
            _ => panic!(
                "expected integer 0 at (4,8), got {:?}",
                pool.get(simplified_0.value)
            ),
        }

        let mut map_off = HashMap::new();
        map_off.insert(x, one);
        map_off.insert(y, two);
        let at_1_2 = subs(res_expr, &map_off, &pool);
        let simplified_nz = crate::simplify::simplify(at_1_2, &pool);
        // A non-integer result here means evaluation failed; it must not pass
        // silently.
        match pool.get(simplified_nz.value) {
            ExprData::Integer(n) => assert_ne!(n.0, 0, "res at (1,2) should be non-zero"),
            _ => panic!(
                "expected an integer at (1,2), got {:?}",
                pool.get(simplified_nz.value)
            ),
        }
    }

    // ---- subresultant_prs ----

    #[test]
    fn sprs_sequence_length() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p = p.add(vec![xsq, p.integer(1_i32)]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);

        let dr = subresultant_prs(poly_p, poly_q, x, &p).unwrap();
        let seq = &dr.value;
        assert!(seq.len() >= 2, "sequence must have at least [p, q]");
        let last_id = *seq.last().unwrap();
        match p.get(last_id) {
            ExprData::Integer(_) => {}
            _ => {
                let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
                assert_eq!(last_poly.degree(), 0, "last PRS element should be degree 0");
            }
        }
    }

    #[test]
    fn sprs_first_elements() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        let two_x = p.mul(vec![two, x]);
        let poly_q_expr = p.add(vec![two_x, p.integer(-2_i32)]);

        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        assert!(dr.value.len() >= 2);
    }

    #[test]
    fn sprs_gcd_from_sequence() {
        // x - 1 divides x^2 - 1, so the sequence ends at the GCD x - 1.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let xsq = p.pow(x, p.integer(2_i32));
        let poly_p_expr = p.add(vec![xsq, p.integer(-1_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-1_i32)]);

        let dr = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let seq = &dr.value;
        assert!(seq.len() >= 2);
        let last_id = *seq.last().unwrap();
        let last_poly = UniPoly::from_symbolic(last_id, x, &p).unwrap();
        assert_eq!(
            last_poly.degree(),
            1,
            "last PRS element should be degree-1 (matching GCD)"
        );
    }

    #[test]
    fn sprs_sylvester_consistency() {
        // For coprime inputs, the last PRS element is the resultant (compared
        // up to sign).
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let poly_p_expr = p.add(vec![x, p.integer(-3_i32)]);
        let poly_q_expr = p.add(vec![x, p.integer(-7_i32)]);

        let dr_prs = subresultant_prs(poly_p_expr, poly_q_expr, x, &p).unwrap();
        let dr_res = resultant(poly_p_expr, poly_q_expr, x, &p).unwrap();

        let res_n = match p.get(dr_res.value) {
            ExprData::Integer(m) => m.0.clone(),
            other => panic!("resultant not integer, got {other:?}"),
        };

        // The last element may come back as an Integer node or as a constant
        // polynomial. Either way it must be compared, never skipped.
        let last = *dr_prs.value.last().unwrap();
        let last_n: rug::Integer = match p.get(last) {
            ExprData::Integer(n) => n.0.clone(),
            _ => {
                let last_poly = UniPoly::from_symbolic(last, x, &p).unwrap();
                assert_eq!(last_poly.degree(), 0, "last PRS element should be a constant");
                rug::Integer::from(last_poly.coefficients_i64()[0])
            }
        };
        assert_eq!(last_n.abs(), res_n.abs());
    }

    // ---- error paths ----

    #[test]
    fn resultant_non_polynomial_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let sin_x = p.func("sin", vec![x]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = resultant(sin_x, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error"
        );
    }

    #[test]
    fn subresultant_prs_non_polynomial_error() {
        // x + y is not univariate in x, so the PRS rejects it.
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let poly_p = p.add(vec![x, y]);
        let poly_q = p.add(vec![x, p.integer(-1_i32)]);
        let err = subresultant_prs(poly_p, poly_q, x, &p);
        assert!(
            matches!(err, Err(ResultantError::NotAPolynomial(_))),
            "expected NotAPolynomial error for multivariate input to subresultant_prs"
        );
    }
}