1use crate::kernel::{ExprId, ExprPool};
11use crate::poly::factor::factor_univariate_z;
12use crate::poly::groebner::{GbPoly, GroebnerBasis, MonomialOrder};
13use crate::poly::unipoly::UniPoly;
14use rug::Rational;
15use std::collections::BTreeMap;
16
17use super::{expr_to_gbpoly, SolverError};
18
19#[derive(Debug, Clone)]
22pub struct RegularChain {
23 pub n_vars: usize,
24 pub polys: Vec<GbPoly>,
25}
26
27impl RegularChain {
28 pub fn len(&self) -> usize {
30 self.polys.len()
31 }
32
33 pub fn is_empty(&self) -> bool {
34 self.polys.is_empty()
35 }
36}
37
38pub fn main_variable_recursive(poly: &GbPoly) -> Option<usize> {
40 let mut best: Option<usize> = None;
41 for exp in poly.terms.keys() {
42 for (i, &e) in exp.iter().enumerate() {
43 if e > 0 {
44 best = Some(best.map_or(i, |b| b.max(i)));
45 }
46 }
47 }
48 best
49}
50
51fn degree_in_var(poly: &GbPoly, var: usize) -> u32 {
52 poly.terms
53 .keys()
54 .map(|e| e.get(var).copied().unwrap_or(0))
55 .max()
56 .unwrap_or(0)
57}
58
59fn is_univariate_in(poly: &GbPoly, var: usize) -> bool {
61 !poly.is_zero()
62 && poly
63 .terms
64 .keys()
65 .all(|e| e.iter().enumerate().all(|(i, &exp)| i == var || exp == 0))
66}
67
68fn is_unit_ideal(gens: &[GbPoly], n_vars: usize) -> bool {
69 gens.len() == 1
70 && gens[0].terms.len() == 1
71 && gens[0].leading_exp(MonomialOrder::Lex) == Some(vec![0u32; n_vars])
72}
73
74pub fn extract_regular_chain_from_basis(
77 gens: &[GbPoly],
78 n_vars: usize,
79 order: MonomialOrder,
80) -> RegularChain {
81 let mut best: Vec<Option<(GbPoly, u32)>> = vec![None; n_vars];
82 for g in gens {
83 if let Some(mv) = main_variable_recursive(g) {
84 let d = degree_in_var(g, mv);
85 let replace = match &best[mv] {
86 None => true,
87 Some((_, deg)) => d < *deg,
88 };
89 if replace {
90 best[mv] = Some((g.clone().make_monic(order), d));
91 }
92 }
93 }
94 let polys: Vec<GbPoly> = best.into_iter().flatten().map(|(p, _)| p).collect();
95 RegularChain { n_vars, polys }
96}
97
98fn lcm_rational_denoms(coeffs: &[Rational]) -> rug::Integer {
99 let mut m = rug::Integer::from(1);
100 for c in coeffs {
101 m = m.lcm(c.denom());
102 }
103 m
104}
105
106fn gbpoly_to_unipoly_z(
108 p: &GbPoly,
109 var_idx: usize,
110 var_expr: ExprId,
111) -> Result<UniPoly, SolverError> {
112 let mut coeffs_map: BTreeMap<u32, Rational> = BTreeMap::new();
113 for (exp, c) in &p.terms {
114 let e = exp.get(var_idx).copied().unwrap_or(0);
115 if exp.iter().enumerate().any(|(i, &x)| i != var_idx && x > 0) {
116 return Err(SolverError::NotPolynomial(
117 "expected univariate polynomial for factor split".into(),
118 ));
119 }
120 coeffs_map.insert(e, c.clone());
121 }
122 let coeffs_rat: Vec<Rational> = (0..=*coeffs_map.keys().max().unwrap_or(&0))
123 .map(|d| {
124 coeffs_map
125 .get(&d)
126 .cloned()
127 .unwrap_or_else(|| Rational::from(0))
128 })
129 .collect();
130 let lcm = lcm_rational_denoms(&coeffs_rat);
131 let mut coeff_ints = Vec::new();
132 for r in coeffs_rat {
133 let t = r * Rational::from((lcm.clone(), 1));
134 let (n, d) = t.into_numer_denom();
135 debug_assert_eq!(d, 1);
136 coeff_ints.push(n);
137 }
138 while coeff_ints.len() > 1 && coeff_ints.last() == Some(&rug::Integer::from(0)) {
140 coeff_ints.pop();
141 }
142 let flint = crate::flint::FlintPoly::from_rug_coefficients(&coeff_ints);
143 Ok(UniPoly {
144 var: var_expr,
145 coeffs: flint,
146 })
147}
148
149fn unipoly_z_to_gbpoly_last(u: &UniPoly, n_vars: usize, var_idx: usize) -> GbPoly {
151 let mut terms = BTreeMap::new();
152 let deg = u.degree().max(0) as usize;
153 for d in 0..=deg {
154 let zi = u.coeffs.get_coeff_flint(d).to_rug();
155 if zi == 0 {
156 continue;
157 }
158 let mut exp = vec![0u32; n_vars];
159 exp[var_idx] = d as u32;
160 terms.insert(exp, Rational::from((zi, 1)));
161 }
162 GbPoly { terms, n_vars }
163}
164
165fn split_chain_at_bottom_univariate(
168 chain: RegularChain,
169 last_var: ExprId,
170) -> Result<Vec<RegularChain>, SolverError> {
171 let n = chain.n_vars;
172 if n == 0 {
173 return Ok(vec![chain]);
174 }
175 let last = n - 1;
176 let uni_entry = chain
178 .polys
179 .iter()
180 .enumerate()
181 .filter(|(_, p)| is_univariate_in(p, last))
182 .max_by_key(|(_, p)| degree_in_var(p, last));
183
184 let Some((idx, uni_poly)) = uni_entry else {
185 return Ok(vec![chain]);
186 };
187
188 let u_z = gbpoly_to_unipoly_z(uni_poly, last, last_var)?;
189 let sqf = u_z.squarefree_part();
190 if sqf.degree() <= 1 {
191 return Ok(vec![chain]);
192 }
193
194 let fac = factor_univariate_z(&sqf)
195 .map_err(|e| SolverError::NotPolynomial(format!("triangularize factorization: {e}")))?;
196
197 let nontrivial = fac.factors.iter().filter(|(f, _)| f.degree() >= 1).count();
198 if nontrivial <= 1 {
199 return Ok(vec![chain]);
200 }
201
202 let mut out = Vec::new();
203 for (factor, _) in fac.factors {
204 if factor.degree() < 1 {
205 continue;
206 }
207 let f_gbp = unipoly_z_to_gbpoly_last(&factor, n, last).make_monic(MonomialOrder::Lex);
208 let mut polys = chain.polys.clone();
209 polys[idx] = f_gbp;
210 out.push(RegularChain {
211 n_vars: chain.n_vars,
212 polys,
213 });
214 }
215
216 if out.is_empty() {
217 Ok(vec![chain])
218 } else {
219 Ok(out)
220 }
221}
222
223pub fn triangularize(
229 equations: Vec<ExprId>,
230 vars: Vec<ExprId>,
231 pool: &ExprPool,
232) -> Result<Vec<RegularChain>, SolverError> {
233 let n_vars = vars.len();
234 if n_vars == 0 {
235 return Ok(vec![]);
236 }
237 let last_var = *vars.last().expect("n_vars > 0");
238
239 let mut polys: Vec<GbPoly> = Vec::with_capacity(equations.len());
240 for eq in &equations {
241 polys.push(expr_to_gbpoly(*eq, &vars, pool)?);
242 }
243
244 let gb = GroebnerBasis::compute(polys, MonomialOrder::Lex);
245 let gens = gb.generators();
246
247 if is_unit_ideal(gens, n_vars) {
248 return Ok(vec![]);
249 }
250
251 let chain = extract_regular_chain_from_basis(gens, n_vars, MonomialOrder::Lex);
252 split_chain_at_bottom_univariate(chain, last_var)
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::Domain;

    /// The linear system `x + y - 1`, `x - y` produces exactly one non-empty
    /// chain.
    #[test]
    fn extract_chain_linear_system() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let minus_one = pool.integer(-1_i32);
        let sum_eq = pool.add(vec![x, y, minus_one]);
        let neg_y = pool.mul(vec![minus_one, y]);
        let diff_eq = pool.add(vec![x, neg_y]);

        let result = triangularize(vec![sum_eq, diff_eq], vec![x, y], &pool).unwrap();

        assert_eq!(result.len(), 1);
        assert!(!result[0].is_empty());
    }

    /// `x^2 - 1` splits into two chains, each holding one linear polynomial.
    #[test]
    fn split_univariate_square() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let two = pool.integer(2);
        let x_squared = pool.pow(x, two);
        let minus_one = pool.integer(-1);
        let neg_one_term = pool.mul(vec![minus_one, one]);
        let eq = pool.add(vec![x_squared, neg_one_term]);

        let result = triangularize(vec![eq], vec![x], &pool).unwrap();

        assert_eq!(result.len(), 2);
        for chain in result.iter() {
            assert_eq!(chain.len(), 1);
            assert_eq!(degree_in_var(&chain.polys[0], 0), 1);
        }
    }
}