1pub mod diophantine;
24pub mod homotopy;
25pub mod regular_chains;
26
27pub use regular_chains::{
28 extract_regular_chain_from_basis, main_variable_recursive, triangularize, RegularChain,
29};
30
31pub use homotopy::{solve_numerical, CertifiedPoint, HomotopyError, HomotopyOpts};
32
33pub use diophantine::{diophantine, DiophantineError, DiophantineSolution};
34
35use crate::errors::AlkahestError;
36use crate::kernel::{ExprData, ExprId, ExprPool};
37use crate::poly::groebner::{GbPoly, GroebnerBasis, MonomialOrder};
38use rug::{ops::NegAssign, Rational};
39use std::collections::BTreeMap;
40use std::fmt;
41
/// One point of a finite solution set: the value assigned to each unknown, in the
/// same order as the `vars` passed to [`solve_polynomial_system`].
pub type Solution = Vec<ExprId>;
50
/// Outcome of [`solve_polynomial_system`].
pub enum SolutionSet {
    /// Finitely many solutions found by back-substitution; roots are symbolic
    /// expressions and may contain `sqrt(...)`.
    Finite(Vec<Solution>),
    /// Back-substitution could not produce explicit solutions, so the
    /// lexicographic Gröbner basis is returned for further inspection.
    Parametric(GroebnerBasis),
    /// The equations have no common solution.
    NoSolution,
}
60
/// Errors raised while converting equations to polynomials or back-solving them.
#[derive(Debug, Clone)]
pub enum SolverError {
    /// An equation is not a polynomial in the declared variables; the payload
    /// describes the offending sub-expression.
    NotPolynomial(String),
    /// Back-substitution needed a univariate polynomial of this degree, but only
    /// degree ≤ 2 is supported.
    HighDegree(usize),
    /// The number of equations does not equal the number of variables.
    /// NOTE(review): not constructed anywhere in this module — confirm whether a
    /// caller elsewhere produces it.
    ShapeMismatch,
}
72
73impl fmt::Display for SolverError {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 SolverError::NotPolynomial(s) => write!(f, "not a polynomial: {s}"),
77 SolverError::HighDegree(d) => write!(
78 f,
79 "back-substitution requires solving a degree-{d} univariate polynomial \
80 (only degree ≤ 2 is currently supported)"
81 ),
82 SolverError::ShapeMismatch => write!(
83 f,
84 "number of equations must equal number of variables for zero-dimensional solving"
85 ),
86 }
87 }
88}
89
90impl std::error::Error for SolverError {}
91
92impl AlkahestError for SolverError {
93 fn code(&self) -> &'static str {
94 match self {
95 SolverError::NotPolynomial(_) => "E-SOLVE-001",
96 SolverError::HighDegree(_) => "E-SOLVE-002",
97 SolverError::ShapeMismatch => "E-SOLVE-003",
98 }
99 }
100
101 fn remediation(&self) -> Option<&'static str> {
102 match self {
103 SolverError::NotPolynomial(_) => Some(
104 "ensure all equations are polynomial in the declared variables; \
105 transcendental functions are not supported",
106 ),
107 SolverError::HighDegree(_) => Some(
108 "degree > 2 univariate solving is not yet implemented; \
109 the Gröbner basis is still returned for manual inspection",
110 ),
111 SolverError::ShapeMismatch => {
112 Some("provide one equation per variable for zero-dimensional system solving")
113 }
114 }
115 }
116}
117
118pub fn expr_to_gbpoly(
126 expr: ExprId,
127 vars: &[ExprId],
128 pool: &ExprPool,
129) -> Result<GbPoly, SolverError> {
130 let n = vars.len();
131 expr_to_gbpoly_rec(expr, vars, n, pool)
132}
133
134fn expr_to_gbpoly_rec(
135 expr: ExprId,
136 vars: &[ExprId],
137 n_vars: usize,
138 pool: &ExprPool,
139) -> Result<GbPoly, SolverError> {
140 if let Some(idx) = vars.iter().position(|&v| v == expr) {
141 let mut exp = vec![0u32; n_vars];
142 exp[idx] = 1;
143 let mut terms = BTreeMap::new();
144 terms.insert(exp, rug::Rational::from(1));
145 return Ok(GbPoly { terms, n_vars });
146 }
147
148 enum Node {
149 Var(usize),
150 IntConst(rug::Integer),
151 RatConst(Rational),
152 FloatConst(f64),
153 FreeSymbol(String),
154 Add(Vec<ExprId>),
155 Mul(Vec<ExprId>),
156 Pow(ExprId, ExprId),
157 Func(String),
158 Other,
159 }
160
161 let node = pool.with(expr, |data| match data {
162 ExprData::Integer(n) => Node::IntConst(n.0.clone()),
163 ExprData::Rational(r) => Node::RatConst(r.0.clone()),
164 ExprData::Float(f) => Node::FloatConst(f.inner.to_f64()),
165 ExprData::Symbol { name, .. } => {
166 if let Some(idx) = vars.iter().position(|&v| v == expr) {
167 Node::Var(idx)
168 } else {
169 Node::FreeSymbol(name.clone())
170 }
171 }
172 ExprData::Add(args) => Node::Add(args.clone()),
173 ExprData::Mul(args) => Node::Mul(args.clone()),
174 ExprData::Pow { base, exp } => Node::Pow(*base, *exp),
175 ExprData::Func { name, .. } => Node::Func(name.clone()),
176 _ => Node::Other,
177 });
178
179 match node {
180 Node::Var(idx) => {
181 let mut exp = vec![0u32; n_vars];
182 exp[idx] = 1;
183 let mut terms = BTreeMap::new();
184 terms.insert(exp, Rational::from(1));
185 Ok(GbPoly { terms, n_vars })
186 }
187 Node::IntConst(n) => Ok(GbPoly::constant(Rational::from(n), n_vars)),
188 Node::RatConst(r) => Ok(GbPoly::constant(r, n_vars)),
189 Node::FloatConst(v) => {
190 let r = Rational::from_f64(v).unwrap_or_else(|| Rational::from(0));
191 Ok(GbPoly::constant(r, n_vars))
192 }
193 Node::FreeSymbol(name) => Err(SolverError::NotPolynomial(format!(
194 "free symbol '{name}' not in variable list"
195 ))),
196 Node::Add(args) => {
197 let mut result = GbPoly::zero(n_vars);
198 for a in args {
199 let p = expr_to_gbpoly_rec(a, vars, n_vars, pool)?;
200 result = result.add(&p);
201 }
202 Ok(result)
203 }
204 Node::Mul(args) => {
205 let mut result = GbPoly::constant(Rational::from(1), n_vars);
206 for a in args {
207 let p = expr_to_gbpoly_rec(a, vars, n_vars, pool)?;
208 result = result.mul(&p);
209 }
210 Ok(result)
211 }
212 Node::Pow(base, exp_id) => {
213 let exp_node = pool.with(exp_id, |d| match d {
214 ExprData::Integer(n) => Some(n.0.clone()),
215 _ => None,
216 });
217 match exp_node {
218 Some(n) => {
219 let n_val = n.to_i64().unwrap_or(-1);
220 if n_val < 0 {
221 return Err(SolverError::NotPolynomial(format!(
222 "negative exponent {n_val} in polynomial"
223 )));
224 }
225 let base_poly = expr_to_gbpoly_rec(base, vars, n_vars, pool)?;
226 let mut result = GbPoly::constant(Rational::from(1), n_vars);
227 let mut cur = base_poly;
228 let mut rem = n_val as u64;
229 while rem > 0 {
230 if rem & 1 == 1 {
231 result = result.mul(&cur);
232 }
233 let cur2 = cur.clone();
234 cur = cur.mul(&cur2);
235 rem >>= 1;
236 }
237 Ok(result)
238 }
239 None => Err(SolverError::NotPolynomial(
240 "symbolic or non-integer exponent".to_string(),
241 )),
242 }
243 }
244 Node::Func(name) => Err(SolverError::NotPolynomial(format!(
245 "function '{name}' is not a polynomial"
246 ))),
247 Node::Other => Err(SolverError::NotPolynomial(
248 "unsupported expression node".to_string(),
249 )),
250 }
251}
252
253fn rational_to_expr(r: &Rational, pool: &ExprPool) -> ExprId {
258 let (num, den) = r.clone().into_numer_denom();
259 if den == 1 {
260 pool.integer(num)
261 } else {
262 pool.rational(num, den)
263 }
264}
265
266fn neg_expr(e: ExprId, pool: &ExprPool) -> ExprId {
267 let neg_one = pool.integer(rug::Integer::from(-1));
268 pool.mul(vec![neg_one, e])
269}
270
271fn div_expr(num: ExprId, den: ExprId, pool: &ExprPool) -> ExprId {
272 let neg_one = pool.integer(rug::Integer::from(-1));
274 let inv_den = pool.pow(den, neg_one);
275 pool.mul(vec![num, inv_den])
276}
277
278fn is_syntactic_zero(e: ExprId, pool: &ExprPool) -> bool {
280 pool.with(e, |d| matches!(d, ExprData::Integer(n) if n.0 == 0))
281}
282
283fn extract_coeff_in_var(
290 poly: &GbPoly,
291 var_idx: usize,
292 k: u32,
293 vars: &[ExprId],
294 assigned: &[Option<ExprId>],
295 pool: &ExprPool,
296) -> ExprId {
297 let mut sum_terms: Vec<ExprId> = Vec::new();
298 for (exp, coeff) in &poly.terms {
299 let e_k = exp.get(var_idx).copied().unwrap_or(0);
300 if e_k != k {
301 continue;
302 }
303 let mut factors: Vec<ExprId> = Vec::new();
304 if *coeff != 1 {
305 factors.push(rational_to_expr(coeff, pool));
306 }
307 for (i, &e) in exp.iter().enumerate() {
308 if i == var_idx || e == 0 {
309 continue;
310 }
311 let base = assigned
312 .get(i)
313 .and_then(|o| o.as_ref())
314 .copied()
315 .unwrap_or(vars[i]);
316 if e == 1 {
317 factors.push(base);
318 } else {
319 let exp_id = pool.integer(rug::Integer::from(e));
320 factors.push(pool.pow(base, exp_id));
321 }
322 }
323 let term = match factors.len() {
324 0 => pool.integer(rug::Integer::from(1)),
325 1 => factors[0],
326 _ => pool.mul(factors),
327 };
328 let signed = if *coeff == 1 {
330 term
331 } else {
332 term
334 };
335 sum_terms.push(signed);
336 }
337 match sum_terms.len() {
338 0 => pool.integer(rug::Integer::from(0)),
339 1 => sum_terms[0],
340 _ => pool.add(sum_terms),
341 }
342}
343
344fn solve_univariate_symbolic(
355 coeffs: &[ExprId],
356 pool: &ExprPool,
357) -> Result<Vec<ExprId>, SolverError> {
358 let mut degree = 0usize;
359 for (i, &c) in coeffs.iter().enumerate() {
360 if !is_syntactic_zero(c, pool) {
361 degree = i;
362 }
363 }
364 match degree {
365 0 => {
366 if coeffs.is_empty() || is_syntactic_zero(coeffs[0], pool) {
371 Ok(vec![])
372 } else {
373 Ok(vec![])
374 }
375 }
376 1 => {
377 let a = coeffs[1];
378 let b = coeffs[0];
379 let neg_b = neg_expr(b, pool);
380 Ok(vec![div_expr(neg_b, a, pool)])
381 }
382 2 => {
383 let a = coeffs[2];
384 let b = coeffs[1];
385 let c = coeffs[0];
386 let two = pool.integer(rug::Integer::from(2));
387 let four = pool.integer(rug::Integer::from(4));
388 let b2 = pool.pow(b, two);
389 let four_ac = pool.mul(vec![four, a, c]);
390 let neg_four_ac = neg_expr(four_ac, pool);
391 let disc = pool.add(vec![b2, neg_four_ac]);
392 let sqrt_disc = pool.func("sqrt", vec![disc]);
393 let two_b = pool.integer(rug::Integer::from(2));
394 let two_a = pool.mul(vec![two_b, a]);
395 let neg_b = neg_expr(b, pool);
396 let root_plus = div_expr(pool.add(vec![neg_b, sqrt_disc]), two_a, pool);
397 let neg_sqrt = neg_expr(sqrt_disc, pool);
398 let root_minus = div_expr(pool.add(vec![neg_b, neg_sqrt]), two_a, pool);
399 Ok(vec![root_plus, root_minus])
400 }
401 d => Err(SolverError::HighDegree(d)),
402 }
403}
404
/// Attempts to find the exact rational roots of `p`, viewed as a polynomial in
/// variable slot `var_idx`.
///
/// Coefficients of terms with the same degree in `var_idx` are summed, so other
/// variables occurring in `p` are ignored. NOTE(review): presumably callers pass
/// a polynomial that is univariate in `var_idx` — confirm.
///
/// Returns:
/// * `Some(vec![])` for degree 0, and for a quadratic with negative discriminant.
///   Degree 0 includes the zero polynomial, whose roots are "everything".
/// * `Some(roots)` for degree 1, and for a quadratic whose discriminant is a
///   perfect rational square. A double root is reported once.
/// * `None` when the quadratic's roots are irrational, or when the degree exceeds 2.
// The `dead_code` allowance indicates this helper currently has no caller.
#[allow(dead_code)]
fn try_solve_univariate_rational(p: &GbPoly, var_idx: usize) -> Option<Vec<Rational>> {
    // Collapse `p` into degree -> coefficient in the chosen variable.
    let mut coeffs: BTreeMap<u32, Rational> = BTreeMap::new();
    for (exp, coeff) in &p.terms {
        let deg = exp.get(var_idx).copied().unwrap_or(0);
        let entry = coeffs.entry(deg).or_insert_with(|| Rational::from(0));
        *entry += coeff.clone();
    }
    // Drop degrees whose summed coefficient cancelled, so `degree` is the true degree.
    coeffs.retain(|_, v| *v != 0);
    let degree = coeffs.keys().max().copied().unwrap_or(0);
    match degree {
        0 => Some(vec![]),
        1 => {
            // a*t + b = 0  =>  t = -b / a  (a != 0 thanks to `retain`).
            let a = coeffs.get(&1).cloned().unwrap_or_else(|| Rational::from(0));
            let b = coeffs.get(&0).cloned().unwrap_or_else(|| Rational::from(0));
            let mut neg_b = b;
            neg_b.neg_assign();
            Some(vec![Rational::from(neg_b / a)])
        }
        2 => {
            let a = coeffs.get(&2).cloned().unwrap_or_else(|| Rational::from(0));
            let b = coeffs.get(&1).cloned().unwrap_or_else(|| Rational::from(0));
            let c = coeffs.get(&0).cloned().unwrap_or_else(|| Rational::from(0));
            let b2 = Rational::from(&b * &b);
            let four_ac = Rational::from(Rational::from(4) * &a * &c);
            let disc = Rational::from(b2 - four_ac);
            // No real roots.
            if disc < 0 {
                return Some(vec![]);
            }
            // sqrt(disc) is rational when numerator and denominator are both perfect
            // squares (assuming rug keeps `disc` in lowest terms).
            let disc_numer = disc.numer().clone();
            let disc_denom = disc.denom().clone();
            let (sn, rem_n) = disc_numer.sqrt_rem(rug::Integer::new());
            let (sd, rem_d) = disc_denom.sqrt_rem(rug::Integer::new());
            if rem_n == 0 && rem_d == 0 {
                let sqrt_disc = Rational::from((sn, sd));
                let two_a = Rational::from(Rational::from(2) * &a);
                let mut neg_b = b;
                neg_b.neg_assign();
                let root1 = Rational::from((Rational::from(&neg_b + &sqrt_disc)) / &two_a);
                let root2 = Rational::from((Rational::from(neg_b - sqrt_disc)) / &two_a);
                // Zero discriminant: report the double root once.
                if root1 == root2 {
                    Some(vec![root1])
                } else {
                    Some(vec![root1, root2])
                }
            } else {
                // Irrational roots cannot be represented as `Rational`.
                None
            }
        }
        _ => None,
    }
}
460
461fn active_vars(poly: &GbPoly, n_vars: usize) -> Vec<usize> {
468 (0..n_vars)
469 .filter(|&i| {
470 poly.terms
471 .keys()
472 .any(|e| e.get(i).copied().unwrap_or(0) > 0)
473 })
474 .collect()
475}
476
477fn max_degree_in_var(poly: &GbPoly, var_idx: usize) -> u32 {
478 poly.terms
479 .keys()
480 .map(|e| e.get(var_idx).copied().unwrap_or(0))
481 .max()
482 .unwrap_or(0)
483}
484
485fn find_solvable<'a>(
488 gens: &'a [GbPoly],
489 assigned: &[Option<ExprId>],
490 n_vars: usize,
491) -> Option<(usize, &'a GbPoly, u32)> {
492 for g in gens {
493 let active = active_vars(g, n_vars);
494 let unsolved: Vec<usize> = active
495 .iter()
496 .copied()
497 .filter(|&i| assigned[i].is_none())
498 .collect();
499 if unsolved.len() == 1 {
500 let var_idx = unsolved[0];
501 let max_deg = max_degree_in_var(g, var_idx);
502 if max_deg > 0 {
503 return Some((var_idx, g, max_deg));
504 }
505 }
506 }
507 None
508}
509
/// Result of one back-substitution pass by `try_backsolve_generators`.
enum BacksolveOutcome {
    /// Every variable received a value along every surviving branch.
    Finite(Vec<Solution>),
    /// Some branch reached a state where no generator had exactly one unassigned
    /// variable.
    Stuck,
    /// Every branch was eliminated because its univariate equation yielded no roots.
    NoSolution,
}
518
519fn try_backsolve_generators(
520 gens: &[GbPoly],
521 vars: &[ExprId],
522 pool: &ExprPool,
523) -> Result<BacksolveOutcome, SolverError> {
524 let n_vars = vars.len();
525 let mut partials: Vec<Vec<Option<ExprId>>> = vec![vec![None; n_vars]];
526
527 for _ in 0..n_vars {
528 let mut new_partials = Vec::new();
529 for partial in &partials {
530 let solvable = find_solvable(gens, partial, n_vars);
531 let (var_idx, gen, max_deg) = match solvable {
532 Some(t) => t,
533 None => return Ok(BacksolveOutcome::Stuck),
534 };
535 if max_deg > 2 {
536 return Err(SolverError::HighDegree(max_deg as usize));
537 }
538 let coeffs: Vec<ExprId> = (0..=max_deg)
539 .map(|k| extract_coeff_in_var(gen, var_idx, k, vars, partial, pool))
540 .collect();
541 let roots = solve_univariate_symbolic(&coeffs, pool)?;
542 if roots.is_empty() {
543 continue;
544 }
545 for root in roots {
546 let mut np = partial.clone();
547 np[var_idx] = Some(root);
548 new_partials.push(np);
549 }
550 }
551 partials = new_partials;
552 if partials.is_empty() {
553 return Ok(BacksolveOutcome::NoSolution);
554 }
555 }
556
557 let solutions: Vec<Solution> = partials
558 .into_iter()
559 .map(|p| {
560 p.into_iter()
561 .map(|o| o.expect("all vars assigned"))
562 .collect()
563 })
564 .collect();
565
566 Ok(BacksolveOutcome::Finite(solutions))
567}
568
569pub fn solve_polynomial_system(
576 equations: Vec<ExprId>,
577 vars: Vec<ExprId>,
578 pool: &ExprPool,
579) -> Result<SolutionSet, SolverError> {
580 let n_vars = vars.len();
581
582 let mut polys: Vec<GbPoly> = Vec::with_capacity(equations.len());
583 for eq in &equations {
584 polys.push(expr_to_gbpoly(*eq, &vars, pool)?);
585 }
586
587 let gb = GroebnerBasis::compute(polys, MonomialOrder::Lex);
588 let gens = gb.generators();
589
590 if gens.len() == 1
592 && gens[0].terms.len() == 1
593 && gens[0].leading_exp(MonomialOrder::Lex) == Some(vec![0u32; n_vars])
594 {
595 return Ok(SolutionSet::NoSolution);
596 }
597
598 match try_backsolve_generators(gens, &vars, pool)? {
599 BacksolveOutcome::Finite(solutions) => Ok(SolutionSet::Finite(solutions)),
600 BacksolveOutcome::NoSolution => Ok(SolutionSet::NoSolution),
601 BacksolveOutcome::Stuck => {
602 let chain = extract_regular_chain_from_basis(gens, n_vars, MonomialOrder::Lex);
603 if chain.polys.is_empty() {
604 return Ok(SolutionSet::Parametric(gb));
605 }
606 match try_backsolve_generators(&chain.polys, &vars, pool)? {
607 BacksolveOutcome::Finite(solutions) => Ok(SolutionSet::Finite(solutions)),
608 _ => Ok(SolutionSet::Parametric(gb)),
609 }
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    //! End-to-end checks for `solve_polynomial_system`: symbolic roots are evaluated
    //! numerically and compared against known solutions.
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::{Domain, ExprPool};
    use std::collections::HashMap;

    /// Absolute tolerance when comparing evaluated roots with expected values.
    const TOL: f64 = 1e-10;

    /// Numerically evaluates a closed-form expression (no free symbols).
    fn eval_closed(e: ExprId, pool: &ExprPool) -> f64 {
        eval_interp(e, &HashMap::new(), pool).expect("numeric eval")
    }

    /// True when every expected `(x, y)` pair is matched, within `TOL`, by some
    /// two-variable solution.
    fn has_numeric_pair(sols: &[Solution], pool: &ExprPool, expected: &[(f64, f64)]) -> bool {
        let matches = |sol: &Solution, ex: f64, ey: f64| {
            let x = eval_closed(sol[0], pool);
            let y = eval_closed(sol[1], pool);
            (x - ex).abs() < TOL && (y - ey).abs() < TOL
        };
        expected
            .iter()
            .all(|&(ex, ey)| sols.iter().any(|sol| matches(sol, ex, ey)))
    }

    #[test]
    fn linear_system() {
        // x + y = 1, x - y = 0  =>  (1/2, 1/2)
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_one = pool.integer(-1_i32);
        let eq1 = pool.add(vec![x, y, neg_one]);
        let minus_y = pool.mul(vec![neg_one, y]);
        let eq2 = pool.add(vec![x, minus_y]);
        match solve_polynomial_system(vec![eq1, eq2], vec![x, y], &pool).unwrap() {
            SolutionSet::Finite(sols) => {
                assert!(has_numeric_pair(&sols, &pool, &[(0.5, 0.5)]));
            }
            _ => panic!("expected finite solution set"),
        }
    }

    #[test]
    fn univariate_quadratic() {
        // x^2 - 1 = 0  =>  x = ±1
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let neg_one = pool.integer(-1_i32);
        let x_squared = pool.pow(x, pool.integer(2_i32));
        let eq = pool.add(vec![x_squared, neg_one]);
        match solve_polynomial_system(vec![eq], vec![x], &pool).unwrap() {
            SolutionSet::Finite(sols) => {
                let vals: Vec<f64> = sols.iter().map(|s| eval_closed(s[0], &pool)).collect();
                assert!(vals.iter().any(|v| (v - 1.0).abs() < TOL));
                assert!(vals.iter().any(|v| (v + 1.0).abs() < TOL));
            }
            _ => panic!("expected finite solution set"),
        }
    }

    #[test]
    fn circle_line_intersection() {
        // x^2 + y^2 = 1, y = x  =>  ±(1/√2, 1/√2)
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_one = pool.integer(-1_i32);
        let two = pool.integer(2_i32);
        let x_squared = pool.pow(x, two);
        let y_squared = pool.pow(y, two);
        let eq1 = pool.add(vec![x_squared, y_squared, neg_one]);
        let minus_x = pool.mul(vec![neg_one, x]);
        let eq2 = pool.add(vec![y, minus_x]);
        match solve_polynomial_system(vec![eq1, eq2], vec![x, y], &pool).unwrap() {
            SolutionSet::Finite(sols) => {
                assert_eq!(
                    sols.len(),
                    2,
                    "expected exactly 2 solutions, got {}",
                    sols.len()
                );
                let root = 0.5_f64.sqrt();
                assert!(has_numeric_pair(
                    &sols,
                    &pool,
                    &[(root, root), (-root, -root)]
                ));
            }
            _ => panic!("expected finite solution set"),
        }
    }

    #[test]
    fn no_solution_inconsistent() {
        // x = 0 and x = 1 cannot hold simultaneously.
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let neg_one = pool.integer(-1_i32);
        let eq1 = x;
        let eq2 = pool.add(vec![x, neg_one]);
        let result = solve_polynomial_system(vec![eq1, eq2], vec![x], &pool).unwrap();
        assert!(matches!(result, SolutionSet::NoSolution));
    }

    #[test]
    fn parabola_and_line() {
        // y = x^2, y = x  =>  (0, 0) and (1, 1)
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_one = pool.integer(-1_i32);
        let two = pool.integer(2_i32);
        let x_squared = pool.pow(x, two);
        let minus_x_squared = pool.mul(vec![neg_one, x_squared]);
        let eq1 = pool.add(vec![y, minus_x_squared]);
        let minus_x = pool.mul(vec![neg_one, x]);
        let eq2 = pool.add(vec![y, minus_x]);
        match solve_polynomial_system(vec![eq1, eq2], vec![x, y], &pool).unwrap() {
            SolutionSet::Finite(sols) => {
                assert_eq!(sols.len(), 2);
                assert!(has_numeric_pair(&sols, &pool, &[(0.0, 0.0), (1.0, 1.0)]));
            }
            _ => panic!("expected finite solution set"),
        }
    }
}