1use crate::dae::{
20 differentiate_equation, extend_dae_for_derivative_symbols, pantelides, DaeError,
21 PantelidesResult, DAE,
22};
23use crate::errors::AlkahestError;
24use crate::kernel::{ExprData, ExprId, ExprPool};
25use crate::poly::groebner::{GbPoly, GroebnerBasis, MonomialOrder};
26use crate::solver::expr_to_gbpoly;
27use crate::solver::SolverError;
28use std::collections::{BTreeMap, HashSet};
29use std::fmt;
30
31const DEFAULT_MAX_PROLONG_ROUNDS: usize = 8;
33
34#[derive(Clone, Debug)]
37pub struct DifferentialRanking {
38 pub vars: Vec<ExprId>,
39}
40
41#[derive(Clone, Debug)]
44pub struct DifferentialIdeal {
45 pub generators: Vec<GbPoly>,
46}
47
48#[derive(Clone, Debug)]
50pub struct DifferentialRing {
51 pub time: ExprId,
52 pub ranked_indeterminates: Vec<ExprId>,
53}
54
55#[derive(Clone, Debug)]
58pub struct RegularDifferentialChain {
59 pub basis: GroebnerBasis,
60}
61
62#[derive(Clone, Debug)]
64pub struct RosenfeldGroebnerResult {
65 pub consistent: bool,
67 pub chains: Vec<RegularDifferentialChain>,
69 pub working_dae: DAE,
71 pub final_basis: Option<GroebnerBasis>,
73 pub prolongation_rounds: usize,
75 pub truncated: bool,
78}
79
80#[derive(Clone, Debug)]
83pub enum DaeIndexReduction {
84 Pantelides(PantelidesResult),
85 Rosenfeld(RosenfeldGroebnerResult),
86}
87
/// Errors produced by the differential-algebra routines in this module.
#[derive(Clone, Debug)]
pub enum DiffAlgError {
    /// Differentiating an equation failed; carries the underlying error message.
    DiffError(String),
    /// An expression could not be converted to a polynomial; carries the solver's message.
    NotPolynomial(String),
    /// The input contained no equations (or no generators).
    EmptySystem,
}
95
96impl fmt::Display for DiffAlgError {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 match self {
99 DiffAlgError::DiffError(s) => write!(f, "differentiation error: {s}"),
100 DiffAlgError::NotPolynomial(s) => write!(f, "not a polynomial: {s}"),
101 DiffAlgError::EmptySystem => write!(f, "empty equation system"),
102 }
103 }
104}
105
106impl std::error::Error for DiffAlgError {}
107
108impl AlkahestError for DiffAlgError {
109 fn code(&self) -> &'static str {
110 match self {
111 DiffAlgError::DiffError(_) => "E-DIFFALG-001",
112 DiffAlgError::NotPolynomial(_) => "E-DIFFALG-002",
113 DiffAlgError::EmptySystem => "E-DIFFALG-003",
114 }
115 }
116
117 fn remediation(&self) -> Option<&'static str> {
118 match self {
119 DiffAlgError::DiffError(_) => {
120 Some("ensure the DAE is polynomial in its state and derivative symbols")
121 }
122 DiffAlgError::NotPolynomial(_) => {
123 Some("declare all jet variables and parameters; remove transcendental functions")
124 }
125 DiffAlgError::EmptySystem => Some("pass at least one implicit equation"),
126 }
127 }
128}
129
130fn solver_err_to_diffalg(e: SolverError) -> DiffAlgError {
131 DiffAlgError::NotPolynomial(e.to_string())
132}
133
134fn is_unit_ideal_gb(gb: &GroebnerBasis) -> bool {
135 gb.generators().iter().any(|g| {
136 g.terms.len() == 1
137 && g.terms
138 .keys()
139 .next()
140 .is_some_and(|e| e.iter().all(|&x| x == 0))
141 && g.terms.values().next().is_some_and(|c| *c != 0)
142 })
143}
144
145fn pad_gbpoly(p: &GbPoly, new_n: usize) -> GbPoly {
146 if new_n == p.n_vars {
147 return p.clone();
148 }
149 assert!(new_n > p.n_vars);
150 let pad = new_n - p.n_vars;
151 let mut terms = BTreeMap::new();
152 for (e, c) in &p.terms {
153 let mut ne = e.clone();
154 ne.extend(std::iter::repeat(0u32).take(pad));
155 terms.insert(ne, c.clone());
156 }
157 GbPoly {
158 terms,
159 n_vars: new_n,
160 }
161}
162
163fn children(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
164 pool.with(expr, |data| match data {
165 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => args.clone(),
166 ExprData::Pow { base, exp } => vec![*base, *exp],
167 ExprData::BigO(inner) => vec![*inner],
168 _ => vec![],
169 })
170}
171
172fn collect_symbols(
173 expr: ExprId,
174 pool: &ExprPool,
175 seen: &mut HashSet<ExprId>,
176 out: &mut Vec<ExprId>,
177) {
178 let is_sym = pool.with(expr, |d| matches!(d, ExprData::Symbol { .. }));
179 if is_sym && seen.insert(expr) {
180 out.push(expr);
181 }
182 for c in children(expr, pool) {
183 collect_symbols(c, pool, seen, out);
184 }
185}
186
187fn vars_for_dae(dae: &DAE, scratch: &[ExprId], pool: &ExprPool) -> Vec<ExprId> {
188 let mut seen = HashSet::new();
189 let mut out = Vec::new();
190 let mut push = |id: ExprId| {
191 if seen.insert(id) {
192 out.push(id);
193 }
194 };
195 push(dae.time_var);
196 for i in 0..dae.variables.len() {
197 push(dae.variables[i]);
198 push(dae.derivatives[i]);
199 }
200 for &root in scratch {
201 collect_symbols(root, pool, &mut seen, &mut out);
202 }
203 out
204}
205
206fn polys_from_equations(
207 eqs: &[ExprId],
208 vars: &[ExprId],
209 pool: &ExprPool,
210) -> Result<Vec<GbPoly>, DiffAlgError> {
211 eqs.iter()
212 .map(|&eq| expr_to_gbpoly(eq, vars, pool).map_err(solver_err_to_diffalg))
213 .collect()
214}
215
216pub fn rosenfeld_groebner_algebraic(
219 gens: Vec<GbPoly>,
220 order: MonomialOrder,
221) -> Result<Vec<RegularDifferentialChain>, DiffAlgError> {
222 if gens.is_empty() {
223 return Err(DiffAlgError::EmptySystem);
224 }
225 let gb = GroebnerBasis::compute(gens, order);
226 if is_unit_ideal_gb(&gb) {
227 return Ok(vec![]);
228 }
229 Ok(vec![RegularDifferentialChain { basis: gb }])
230}
231
232pub fn rosenfeld_groebner_with_options(
238 dae: &DAE,
239 pool: &ExprPool,
240 order: MonomialOrder,
241 max_prolong_rounds: usize,
242) -> Result<RosenfeldGroebnerResult, DiffAlgError> {
243 if dae.equations.is_empty() {
244 return Err(DiffAlgError::EmptySystem);
245 }
246
247 let source_eqs = dae.equations.clone();
248 let mut work = dae.clone();
249 let mut scratch: Vec<ExprId> = source_eqs.clone();
250 let mut vars = vars_for_dae(&work, &scratch, pool);
251 let mut active = polys_from_equations(&work.equations, &vars, pool)?;
252
253 let mut prolong_exprs = source_eqs.clone();
254 let mut prolongation_rounds: usize = 0;
255
256 for round in 0..max_prolong_rounds {
257 let gb = GroebnerBasis::compute(active.clone(), order);
258 if is_unit_ideal_gb(&gb) {
259 return Ok(RosenfeldGroebnerResult {
260 consistent: false,
261 chains: vec![],
262 working_dae: work,
263 final_basis: None,
264 prolongation_rounds,
265 truncated: false,
266 });
267 }
268
269 let mut next_prolong = Vec::with_capacity(prolong_exprs.len());
270 for &eq in &prolong_exprs {
271 let d_eq =
272 differentiate_equation(eq, &work.variables, &work.derivatives, work.time_var, pool)
273 .map_err(|e| DiffAlgError::DiffError(e.to_string()))?;
274 extend_dae_for_derivative_symbols(&mut work, d_eq, pool);
275 next_prolong.push(d_eq);
276 }
277 prolong_exprs = next_prolong;
278 scratch = source_eqs
279 .iter()
280 .copied()
281 .chain(prolong_exprs.iter().copied())
282 .collect();
283 vars = vars_for_dae(&work, &scratch, pool);
284 let n = vars.len();
285 for p in &mut active {
286 *p = pad_gbpoly(p, n);
287 }
288
289 let gb_check = GroebnerBasis::compute(active.clone(), order);
290 let mut to_add: Vec<GbPoly> = Vec::new();
291 for &d_eq in &prolong_exprs {
292 let p = expr_to_gbpoly(d_eq, &vars, pool).map_err(solver_err_to_diffalg)?;
293 if !gb_check.contains(&p) {
294 to_add.push(p);
295 }
296 }
297
298 if to_add.is_empty() {
299 let final_basis = GroebnerBasis::compute(active, order);
300 let consistent = !is_unit_ideal_gb(&final_basis);
301 let chains = if consistent {
302 vec![RegularDifferentialChain {
303 basis: final_basis.clone(),
304 }]
305 } else {
306 vec![]
307 };
308 return Ok(RosenfeldGroebnerResult {
309 consistent,
310 chains,
311 working_dae: work,
312 final_basis: if consistent { Some(final_basis) } else { None },
313 prolongation_rounds,
314 truncated: false,
315 });
316 }
317
318 active.extend(to_add);
319 prolongation_rounds += 1;
320
321 if round + 1 == max_prolong_rounds {
322 let final_basis = GroebnerBasis::compute(active, order);
323 let consistent = !is_unit_ideal_gb(&final_basis);
324 let chains = if consistent {
325 vec![RegularDifferentialChain {
326 basis: final_basis.clone(),
327 }]
328 } else {
329 vec![]
330 };
331 return Ok(RosenfeldGroebnerResult {
332 consistent,
333 chains,
334 working_dae: work,
335 final_basis: if consistent { Some(final_basis) } else { None },
336 prolongation_rounds,
337 truncated: true,
338 });
339 }
340 }
341
342 let final_basis = GroebnerBasis::compute(active, order);
343 let consistent = !is_unit_ideal_gb(&final_basis);
344 Ok(RosenfeldGroebnerResult {
345 consistent,
346 chains: if consistent {
347 vec![RegularDifferentialChain {
348 basis: final_basis.clone(),
349 }]
350 } else {
351 vec![]
352 },
353 working_dae: work,
354 final_basis: if consistent { Some(final_basis) } else { None },
355 prolongation_rounds,
356 truncated: true,
357 })
358}
359
360pub fn rosenfeld_groebner(
362 dae: &DAE,
363 pool: &ExprPool,
364 order: MonomialOrder,
365) -> Result<RosenfeldGroebnerResult, DiffAlgError> {
366 rosenfeld_groebner_with_options(dae, pool, order, DEFAULT_MAX_PROLONG_ROUNDS)
367}
368
369pub fn dae_index_reduce(
371 dae: &DAE,
372 pool: &ExprPool,
373 order: MonomialOrder,
374) -> Result<DaeIndexReduction, DaeError> {
375 match pantelides(dae, pool) {
376 Ok(p) => Ok(DaeIndexReduction::Pantelides(p)),
377 Err(DaeError::IndexTooHigh) => {
378 let r = rosenfeld_groebner(dae, pool, order).map_err(|e| match e {
379 DiffAlgError::DiffError(s) | DiffAlgError::NotPolynomial(s) => {
380 DaeError::DiffError(s)
381 }
382 DiffAlgError::EmptySystem => DaeError::StructurallyInconsistent,
383 })?;
384 Ok(DaeIndexReduction::Rosenfeld(r))
385 }
386 Err(e) => Err(e),
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::Domain;

    fn pool() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn algebraic_inconsistent_unit_ideal() {
        // x + 1 and x - 1 have no common zero, so together they generate the unit ideal.
        let one = GbPoly::constant(rug::Rational::from(1), 1);
        let x = GbPoly::monomial(vec![1], rug::Rational::from(1));
        let f = x.add(&one);
        let g = x.sub(&one);
        let chains = rosenfeld_groebner_algebraic(vec![f, g], MonomialOrder::Lex).unwrap();
        assert!(chains.is_empty());
    }

    #[test]
    fn algebraic_empty_input_is_rejected() {
        let res = rosenfeld_groebner_algebraic(Vec::new(), MonomialOrder::Lex);
        assert!(matches!(res, Err(DiffAlgError::EmptySystem)));
    }

    #[test]
    fn lotka_volterra_first_order_consistent() {
        let p = pool();
        let t = p.symbol("t", Domain::Real);
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let dx = p.symbol("dx/dt", Domain::Real);
        let dy = p.symbol("dy/dt", Domain::Real);
        // dx/dt = x - x*y ; dy/dt = x*y - y
        let eq1 = p.add(vec![dx, p.mul(vec![p.integer(-1), x]), p.mul(vec![x, y])]);
        let eq2 = p.add(vec![dy, p.mul(vec![p.integer(-1), x, y]), y]);
        let dae = DAE::new(vec![eq1, eq2], vec![x, y], vec![dx, dy], t);
        // Zero rounds: no prolongation is attempted, so the result is marked truncated.
        let r = rosenfeld_groebner_with_options(&dae, &p, MonomialOrder::GRevLex, 0).unwrap();
        assert!(r.consistent && r.final_basis.is_some());
        assert!(r.truncated);
    }

    #[test]
    fn contradictory_linear_equations_inconsistent() {
        let p = pool();
        let t = p.symbol("t", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let dy = p.symbol("dy/dt", Domain::Real);
        // dy/dt = y and dy/dt = y + 1 cannot both hold.
        let eq1 = p.add(vec![dy, p.mul(vec![p.integer(-1), y])]);
        let eq2 = p.add(vec![dy, p.mul(vec![p.integer(-1), y]), p.integer(-1)]);
        let dae = DAE::new(vec![eq1, eq2], vec![y], vec![dy], t);
        let r = rosenfeld_groebner(&dae, &p, MonomialOrder::Lex).unwrap();
        assert!(!r.consistent);
    }

    #[test]
    fn textbook_library_runs() {
        // Cycle through three small systems (two consistent, one contradictory) ten
        // times, rebuilding each from a fresh pool.
        for case in 0..10 {
            let p = pool();
            let t = p.symbol("t", Domain::Real);
            let x = p.symbol("x", Domain::Real);
            let y = p.symbol("y", Domain::Real);
            let dx = p.symbol("dx/dt", Domain::Real);
            let dy = p.symbol("dy/dt", Domain::Real);
            let (eqs, v, d, consistent) = match case % 3 {
                // Harmonic oscillator: dx/dt = y, dy/dt = -x.
                0 => {
                    let e1 = p.add(vec![dx, p.mul(vec![p.integer(-1), y])]);
                    let e2 = p.add(vec![dy, x]);
                    (vec![e1, e2], vec![x, y], vec![dx, dy], true)
                }
                // Decoupled growth/decay: dx/dt = x, dy/dt = -y.
                1 => {
                    let e1 = p.add(vec![dx, p.mul(vec![p.integer(-1), x])]);
                    let e2 = p.add(vec![dy, y]);
                    (vec![e1, e2], vec![x, y], vec![dx, dy], true)
                }
                // Contradiction: dx/dt = x and dx/dt = x + 1.
                _ => {
                    let e1 = p.add(vec![dx, p.mul(vec![p.integer(-1), x])]);
                    let e2 = p.add(vec![dx, p.mul(vec![p.integer(-1), x]), p.integer(-1)]);
                    (vec![e1, e2], vec![x], vec![dx], false)
                }
            };
            let dae = DAE::new(eqs, v, d, t);
            let r = rosenfeld_groebner(&dae, &p, MonomialOrder::GRevLex).unwrap();
            assert_eq!(r.consistent, consistent, "case {case}");
        }
    }
}