1use crate::diff::diff;
44use crate::integrate::engine::integrate;
45use crate::kernel::eval_const::try_expr_f64;
46use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
47use crate::simplify::engine::{simplify, simplify_expanded};
48use std::collections::HashMap;
49use std::fmt;
50
51mod constant_coeff;
52mod first_order;
53mod verify;
54
55pub(crate) use verify::residual_is_zero;
56
57#[derive(Clone, Debug)]
73pub struct OdeInput {
74 pub x: ExprId,
76 pub y: ExprId,
78 pub derivs: Vec<ExprId>,
80 pub equation: ExprId,
82}
83
84impl OdeInput {
85 fn deriv_symbol(y: ExprId, k: usize, pool: &ExprPool) -> ExprId {
86 let base = pool.with(y, |d| match d {
87 ExprData::Symbol { name, .. } => name.clone(),
88 _ => "y".to_string(),
89 });
90 let primes = "'".repeat(k);
91 pool.symbol(format!("{base}{primes}"), Domain::Real)
92 }
93
94 pub fn first_order(x: ExprId, y: ExprId, pool: &ExprPool) -> (Self, ExprId) {
99 let yp = Self::deriv_symbol(y, 1, pool);
100 (
101 OdeInput {
102 x,
103 y,
104 derivs: vec![yp],
105 equation: pool.integer(0_i32),
106 },
107 yp,
108 )
109 }
110
111 pub fn second_order(x: ExprId, y: ExprId, pool: &ExprPool) -> (Self, ExprId, ExprId) {
115 let yp = Self::deriv_symbol(y, 1, pool);
116 let ypp = Self::deriv_symbol(y, 2, pool);
117 (
118 OdeInput {
119 x,
120 y,
121 derivs: vec![yp, ypp],
122 equation: pool.integer(0_i32),
123 },
124 yp,
125 ypp,
126 )
127 }
128
129 pub fn higher_order(
132 x: ExprId,
133 y: ExprId,
134 order: usize,
135 pool: &ExprPool,
136 ) -> (Self, Vec<ExprId>) {
137 assert!(order >= 1, "ODE order must be ≥ 1");
138 let derivs: Vec<ExprId> = (1..=order)
139 .map(|k| Self::deriv_symbol(y, k, pool))
140 .collect();
141 (
142 OdeInput {
143 x,
144 y,
145 derivs: derivs.clone(),
146 equation: pool.integer(0_i32),
147 },
148 derivs,
149 )
150 }
151
152 pub fn with_equation(mut self, equation: ExprId) -> Self {
154 self.equation = equation;
155 self
156 }
157
158 pub fn order(&self) -> usize {
160 self.derivs.len()
161 }
162}
163
164#[derive(Clone, Debug)]
166pub struct DsolveSolution {
167 pub y_of_x: ExprId,
170 pub constants: Vec<ExprId>,
172 pub method: &'static str,
174}
175
176#[derive(Clone, Debug)]
178pub struct DsolveResult {
179 pub solutions: Vec<DsolveSolution>,
181}
182
/// Failure modes of the ODE solver. Each variant carries a free-form detail message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DsolveError {
    /// The ODE (or an integral needed to solve it) is not handled.
    Unsupported(String),
    /// A candidate solution was found but did not pass verification.
    VerificationFailed(String),
    /// Differentiation failed while working on the equation.
    DiffError(String),
}

impl fmt::Display for DsolveError {
    /// Renders the error as `dsolve: <category>: <detail>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (category, detail) = match self {
            Self::Unsupported(detail) => ("unsupported ODE", detail),
            Self::VerificationFailed(detail) => ("candidate failed verification", detail),
            Self::DiffError(detail) => ("differentiation error", detail),
        };
        write!(f, "dsolve: {category}: {detail}")
    }
}

impl std::error::Error for DsolveError {}
209
210impl crate::errors::AlkahestError for DsolveError {
211 fn code(&self) -> &'static str {
212 match self {
213 DsolveError::Unsupported(_) => "E-ODE-010",
214 DsolveError::VerificationFailed(_) => "E-ODE-011",
215 DsolveError::DiffError(_) => "E-ODE-012",
216 }
217 }
218
219 fn remediation(&self) -> Option<&'static str> {
220 match self {
221 DsolveError::Unsupported(_) => Some(
222 "the ODE is outside the implemented classical classes, or a required \
223 integral is non-elementary; check the equation form",
224 ),
225 DsolveError::VerificationFailed(_) => Some(
226 "the solver found a candidate that did not verify by substitution; \
227 this is reported rather than returned as a (possibly wrong) answer",
228 ),
229 DsolveError::DiffError(_) => {
230 Some("ensure the equation only contains differentiable functions")
231 }
232 }
233 }
234}
235
236pub fn dsolve(input: &OdeInput, pool: &ExprPool) -> Result<DsolveResult, DsolveError> {
253 let mut gen = ConstGen::new(input, pool);
254 match input.order() {
255 1 => first_order::solve(input, &mut gen, pool),
256 2 => constant_coeff::solve_second_order(input, &mut gen, pool),
257 n if n >= 3 => constant_coeff::solve_higher_order(input, n, &mut gen, pool),
258 _ => Err(DsolveError::Unsupported("order 0 ODE".to_string())),
259 }
260}
261
262pub(crate) struct ConstGen {
269 next: usize,
270 used: std::collections::HashSet<String>,
271}
272
273impl ConstGen {
274 fn new(input: &OdeInput, pool: &ExprPool) -> Self {
275 let mut used = std::collections::HashSet::new();
276 collect_symbol_names(input.equation, pool, &mut used);
277 ConstGen { next: 1, used }
278 }
279
280 pub(crate) fn fresh(&mut self, pool: &ExprPool) -> ExprId {
282 loop {
283 let name = format!("C{}", self.next);
284 self.next += 1;
285 if !self.used.contains(&name) {
286 self.used.insert(name.clone());
287 return pool.symbol(name, Domain::Real);
288 }
289 }
290 }
291}
292
293fn collect_symbol_names(
294 expr: ExprId,
295 pool: &ExprPool,
296 out: &mut std::collections::HashSet<String>,
297) {
298 pool.with(expr, |d| match d {
299 ExprData::Symbol { name, .. } => {
300 out.insert(name.clone());
301 }
302 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
303 for &a in args {
304 collect_symbol_names(a, pool, out);
305 }
306 }
307 ExprData::Pow { base, exp } => {
308 collect_symbol_names(*base, pool, out);
309 collect_symbol_names(*exp, pool, out);
310 }
311 _ => {}
312 });
313}
314
315pub(crate) fn simp(expr: ExprId, pool: &ExprPool) -> ExprId {
324 simplify_expanded(expr, pool).value
325}
326
327pub(crate) fn simp_plain(expr: ExprId, pool: &ExprPool) -> ExprId {
330 simplify(expr, pool).value
331}
332
333pub(crate) fn ddx(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<ExprId, DsolveError> {
335 diff(expr, var, pool)
336 .map(|d| d.value)
337 .map_err(|e| DsolveError::DiffError(e.to_string()))
338}
339
340pub(crate) fn integrate_or_decline(
343 expr: ExprId,
344 var: ExprId,
345 pool: &ExprPool,
346) -> Result<ExprId, DsolveError> {
347 match integrate(expr, var, pool) {
348 Ok(d) => Ok(simp(d.value, pool)),
349 Err(e) => {
350 if let Some(f) = integrate_pexp_trig(expr, var, pool) {
354 return Ok(f);
355 }
356 Err(DsolveError::Unsupported(format!(
357 "required integral did not close: {e}"
358 )))
359 }
360 }
361}
362
363pub(crate) fn integrate_pexp_trig(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
369 let factors: Vec<ExprId> = match pool.get(expr) {
371 ExprData::Mul(args) => args,
372 _ => vec![expr],
373 };
374 let mut exp_rate = 0.0_f64;
375 let mut trig: Option<(bool, f64)> = None; let mut poly_factors: Vec<ExprId> = Vec::new();
377 for f in factors {
378 match pool.get(f) {
379 ExprData::Func { name, args } if name == "exp" && args.len() == 1 => {
380 exp_rate += linear_rate_of(args[0], var, pool)?;
381 }
382 ExprData::Func { name, args }
383 if (name == "cos" || name == "sin") && args.len() == 1 =>
384 {
385 if trig.is_some() {
386 return None;
387 }
388 trig = Some((name == "sin", linear_rate_of(args[0], var, pool)?));
389 }
390 _ => {
391 if contains(f, var, pool) && poly_degree_in(f, var, pool).is_none() {
392 return None;
393 }
394 poly_factors.push(f);
395 }
396 }
397 }
398 let poly = if poly_factors.is_empty() {
399 pool.integer(1_i32)
400 } else {
401 simp(pool.mul(poly_factors), pool)
402 };
403 let deg = poly_degree_in(poly, var, pool)?;
404 if exp_rate == 0.0 && trig.is_none() {
405 return None; }
407
408 let exp_factor = if exp_rate != 0.0 {
411 Some(simp(
412 pool.func("exp", vec![mul_c(exp_rate, var, pool)]),
413 pool,
414 ))
415 } else {
416 None
417 };
418 let mut mods: Vec<ExprId> = Vec::new();
419 if let Some((_, b)) = trig {
420 let bx = mul_c(b, var, pool);
421 mods.push(pool.func("cos", vec![bx]));
422 mods.push(pool.func("sin", vec![bx]));
423 } else {
424 mods.push(pool.integer(1_i32));
425 }
426 let mut terms: Vec<ExprId> = Vec::new();
427 for k in 0..=deg {
428 let xk = if k == 0 {
429 pool.integer(1_i32)
430 } else {
431 pool.pow(var, pool.integer(k as i32))
432 };
433 for &m in &mods {
434 let mut fac = vec![xk, m];
435 if let Some(e) = exp_factor {
436 fac.push(e);
437 }
438 terms.push(simp(pool.mul(fac), pool));
439 }
440 }
441 let k = terms.len();
442 let mut dterms: Vec<ExprId> = Vec::with_capacity(k);
444 for &t in &terms {
445 dterms.push(simp(diff(t, var, pool).ok()?.value, pool));
446 }
447 let samples: Vec<f64> = (0..k).map(|i| 0.41 + 0.47 * i as f64).collect();
448 let mut mat = vec![vec![0.0; k]; k];
449 let mut rhs = vec![0.0; k];
450 for (i, &xv) in samples.iter().enumerate() {
451 let mut env = HashMap::new();
452 env.insert(var, xv);
453 for (j, &dt) in dterms.iter().enumerate() {
454 mat[i][j] = verify::eval(dt, &env, pool)?;
455 }
456 rhs[i] = verify::eval(expr, &env, pool)?;
457 }
458 let sol = gaussian_solve(&mut mat, &mut rhs)?;
459 let mut out = Vec::new();
460 for (j, &t) in terms.iter().enumerate() {
461 if sol[j].abs() < 1e-12 {
462 continue;
463 }
464 out.push(pool.mul(vec![f64_rational(sol[j], pool), t]));
465 }
466 let f = simp(pool.add(out), pool);
467 let df = simp(diff(f, var, pool).ok()?.value, pool);
469 for xv in [0.23_f64, 0.61, 1.07, 1.53] {
470 let mut env = HashMap::new();
471 env.insert(var, xv);
472 let lhs = verify::eval(df, &env, pool)?;
473 let rhsv = verify::eval(expr, &env, pool)?;
474 if (lhs - rhsv).abs() > 1e-6 {
475 return None;
476 }
477 }
478 Some(f)
479}
480
481fn linear_rate_of(arg: ExprId, var: ExprId, pool: &ExprPool) -> Option<f64> {
483 let d = diff(arg, var, pool).ok()?.value;
484 if contains(d, var, pool) {
485 return None;
486 }
487 let dx = simp(pool.mul(vec![d, var]), pool);
488 if !is_zero(sub(arg, dx, pool), pool) {
489 return None;
490 }
491 try_expr_f64(simp(d, pool), pool)
492}
493
494fn poly_degree_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<usize> {
495 if !contains(expr, var, pool) {
496 return Some(0);
497 }
498 match pool.get(expr) {
499 ExprData::Symbol { .. } => Some(1),
500 ExprData::Add(args) => args
501 .iter()
502 .map(|&a| poly_degree_in(a, var, pool))
503 .try_fold(0usize, |acc, d| Some(acc.max(d?))),
504 ExprData::Mul(args) => args
505 .iter()
506 .map(|&a| poly_degree_in(a, var, pool))
507 .try_fold(0usize, |acc, d| Some(acc + d?)),
508 ExprData::Pow { base, exp } if base == var => {
509 if let ExprData::Integer(k) = pool.get(exp) {
510 let k = k.0.to_i64()?;
511 if k >= 0 {
512 return Some(k as usize);
513 }
514 }
515 None
516 }
517 _ => None,
518 }
519}
520
521fn mul_c(c: f64, var: ExprId, pool: &ExprPool) -> ExprId {
522 simp(pool.mul(vec![f64_rational(c, pool), var]), pool)
523}
524
525fn f64_rational(v: f64, pool: &ExprPool) -> ExprId {
526 if v == v.round() {
527 return pool.integer(v as i64);
528 }
529 for den in 2..=24_i64 {
530 let num = v * den as f64;
531 if (num - num.round()).abs() < 1e-9 {
532 return pool.rational(num.round() as i64, den);
533 }
534 }
535 pool.float(v, 53)
536}
537
/// Solves the square linear system `mat * x = rhs` by Gauss-Jordan elimination
/// with partial pivoting. Both arguments are used as scratch space and are
/// overwritten.
///
/// The system size is `rhs.len()`; only the leading `n x n` part of `mat` is
/// used. An empty system yields `Some(vec![])`.
///
/// Returns `None` when `mat` is smaller than `n x n`, when a pivot is smaller
/// than `1e-12` in magnitude (singular or badly conditioned system) or not
/// finite, or when any component of the solution is not finite.
// Index loops are kept on purpose: each elimination step reads row `col` while
// writing row `r` of the same matrix, which iterators cannot borrow at once.
#[allow(clippy::needless_range_loop)]
fn gaussian_solve(mat: &mut [Vec<f64>], rhs: &mut [f64]) -> Option<Vec<f64>> {
    const PIVOT_EPS: f64 = 1e-12;

    let n = rhs.len();
    // Reject undersized input instead of panicking on an out-of-bounds index.
    if mat.len() < n || mat[..n].iter().any(|row| row.len() < n) {
        return None;
    }

    for col in 0..n {
        // Partial pivoting: pick the row with the largest entry in this column.
        let mut piv = col;
        for r in (col + 1)..n {
            if mat[r][col].abs() > mat[piv][col].abs() {
                piv = r;
            }
        }
        let pivot = mat[piv][col];
        // NaN compares false against the threshold, so test finiteness explicitly.
        if !pivot.is_finite() || pivot.abs() < PIVOT_EPS {
            return None;
        }
        mat.swap(col, piv);
        rhs.swap(col, piv);

        // Eliminate the column from every other row (above and below).
        for r in 0..n {
            if r == col {
                continue;
            }
            let factor = mat[r][col] / mat[col][col];
            for c in col..n {
                mat[r][c] -= factor * mat[col][c];
            }
            rhs[r] -= factor * rhs[col];
        }
    }

    // The matrix is diagonal now; divide through to read off the solution.
    let solution: Vec<f64> = (0..n).map(|i| rhs[i] / mat[i][i]).collect();
    if solution.iter().all(|x| x.is_finite()) {
        Some(solution)
    } else {
        None
    }
}
567
568pub(crate) fn contains(expr: ExprId, needle: ExprId, pool: &ExprPool) -> bool {
570 if expr == needle {
571 return true;
572 }
573 pool.with(expr, |d| match d {
574 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
575 args.iter().any(|&a| contains(a, needle, pool))
576 }
577 ExprData::Pow { base, exp } => {
578 contains(*base, needle, pool) || contains(*exp, needle, pool)
579 }
580 _ => false,
581 })
582}
583
584pub(crate) fn sub(a: ExprId, b: ExprId, pool: &ExprPool) -> ExprId {
586 let neg_b = pool.mul(vec![pool.integer(-1_i32), b]);
587 simp(pool.add(vec![a, neg_b]), pool)
588}
589
590pub(crate) fn div(a: ExprId, b: ExprId, pool: &ExprPool) -> ExprId {
592 let inv_b = pool.pow(b, pool.integer(-1_i32));
593 simp(pool.mul(vec![a, inv_b]), pool)
594}
595
596pub(crate) fn subs1(expr: ExprId, from: ExprId, to: ExprId, pool: &ExprPool) -> ExprId {
598 let mut m = HashMap::new();
599 m.insert(from, to);
600 simp(crate::kernel::subs::subs(expr, &m, pool), pool)
601}
602
603pub(crate) fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
605 let s = simp(expr, pool);
606 matches!(pool.get(s), ExprData::Integer(n) if n.0 == 0)
607 || matches!(try_expr_f64(s, pool), Some(v) if v == 0.0)
608}
609
610#[cfg(test)]
611mod tests;