Skip to main content

alkahest_cas/ode/
dsolve.rs

//! Classical symbolic ODE solver (`dsolve`).
//!
//! Returns closed-form *general* solutions to ordinary differential equations,
//! introducing integration constants `C1, C2, …` as fresh symbols.
//!
//! # Covered classes
//!
//! **First order** (`y' = …` written as `F(x, y, y') = 0`):
//! - separable `y' = g(x)·h(y)`
//! - linear `y' + p(x)·y = q(x)` (integrating-factor)
//! - Bernoulli `y' + p(x)·y = q(x)·yⁿ`
//! - exact `M dx + N dy = 0` with `∂M/∂y = ∂N/∂x`
//! - homogeneous of degree zero `y' = G(y/x)` (substitution `v = y/x`)
//! - Clairaut `y = x·y' + f(y')`
//! - Riccati `y' = q₀(x) + q₁(x)·y + q₂(x)·y²` **with a polynomial particular
//!   solution** found by ansatz (declined otherwise)
//!
//! **Second order** (`F(x, y, y', y'') = 0`):
//! - constant coefficients `a·y'' + b·y' + c·y = r(x)` (real distinct / repeated
//!   / complex roots), including non-homogeneous RHS via undetermined
//!   coefficients (polynomial × exp × sin/cos) and variation of parameters
//! - Euler–Cauchy `a·x²·y'' + b·x·y' + c·y = 0`
//!
//! **Higher order**: constant-coefficient `Σ aₖ y^(k) = 0`, solved through the
//! characteristic polynomial (rational + quadratic factorization; irreducible
//! factors of degree ≥ 3 are declined).
//!
//! # Verification gate
//!
//! *Every* returned solution is verified by substitution: the candidate `y(x)`
//! (and its derivatives) are substituted into the original equation, the
//! residual is simplified, and accepted only when it is the symbolic zero or
//! numerically `≈ 0` at several sample points over random constant values.  A
//! candidate that fails verification causes [`dsolve`] to decline (it never
//! returns an unverified solution).
//!
//! # Quadratures
//!
//! Closed forms that require an integral defer to the existing
//! [`mod@crate::integrate`] engine.  If a required integral does not close in
//! elementary form, the class is declined (no unevaluated-integral output).
42
43use crate::diff::diff;
44use crate::integrate::engine::integrate;
45use crate::kernel::eval_const::try_expr_f64;
46use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
47use crate::simplify::engine::{simplify, simplify_expanded};
48use std::collections::HashMap;
49use std::fmt;
50
51mod constant_coeff;
52mod first_order;
53mod verify;
54
55pub(crate) use verify::residual_is_zero;
56
57// ---------------------------------------------------------------------------
58// Public API
59// ---------------------------------------------------------------------------
60
61/// Input description of a scalar ODE for [`dsolve`].
62///
63/// The equation is supplied as a single expression `equation` that is taken to
64/// equal zero, written in terms of the symbols `x` (independent variable), `y`
65/// (the unknown `y(x)`), and the derivative symbols in `derivs`
66/// (`derivs[0] = y'`, `derivs[1] = y''`, …).  The `order` equals
67/// `derivs.len()`.
68///
69/// Use [`OdeInput::first_order`] / [`OdeInput::second_order`] /
70/// [`OdeInput::higher_order`] to build instances; they allocate the derivative
71/// symbols with the conventional names `y'`, `y''`, ….
72#[derive(Clone, Debug)]
73pub struct OdeInput {
74    /// Independent variable, e.g. `x`.
75    pub x: ExprId,
76    /// Dependent variable `y` (representing `y(x)`).
77    pub y: ExprId,
78    /// Derivative symbols `[y', y'', …]`.
79    pub derivs: Vec<ExprId>,
80    /// The equation, interpreted as `equation = 0`.
81    pub equation: ExprId,
82}
83
84impl OdeInput {
85    fn deriv_symbol(y: ExprId, k: usize, pool: &ExprPool) -> ExprId {
86        let base = pool.with(y, |d| match d {
87            ExprData::Symbol { name, .. } => name.clone(),
88            _ => "y".to_string(),
89        });
90        let primes = "'".repeat(k);
91        pool.symbol(format!("{base}{primes}"), Domain::Real)
92    }
93
94    /// Build a first-order input `equation(x, y, y') = 0`.
95    ///
96    /// Returns `(input, y')` so the caller can build the equation referring to
97    /// the freshly created derivative symbol.
98    pub fn first_order(x: ExprId, y: ExprId, pool: &ExprPool) -> (Self, ExprId) {
99        let yp = Self::deriv_symbol(y, 1, pool);
100        (
101            OdeInput {
102                x,
103                y,
104                derivs: vec![yp],
105                equation: pool.integer(0_i32),
106            },
107            yp,
108        )
109    }
110
111    /// Build a second-order input `equation(x, y, y', y'') = 0`.
112    ///
113    /// Returns `(input, y', y'')`.
114    pub fn second_order(x: ExprId, y: ExprId, pool: &ExprPool) -> (Self, ExprId, ExprId) {
115        let yp = Self::deriv_symbol(y, 1, pool);
116        let ypp = Self::deriv_symbol(y, 2, pool);
117        (
118            OdeInput {
119                x,
120                y,
121                derivs: vec![yp, ypp],
122                equation: pool.integer(0_i32),
123            },
124            yp,
125            ypp,
126        )
127    }
128
129    /// Build an `order`-th order input.  Returns `(input, derivs)` where
130    /// `derivs[k]` is the `(k+1)`-th derivative symbol.
131    pub fn higher_order(
132        x: ExprId,
133        y: ExprId,
134        order: usize,
135        pool: &ExprPool,
136    ) -> (Self, Vec<ExprId>) {
137        assert!(order >= 1, "ODE order must be ≥ 1");
138        let derivs: Vec<ExprId> = (1..=order)
139            .map(|k| Self::deriv_symbol(y, k, pool))
140            .collect();
141        (
142            OdeInput {
143                x,
144                y,
145                derivs: derivs.clone(),
146                equation: pool.integer(0_i32),
147            },
148            derivs,
149        )
150    }
151
152    /// Replace the equation expression.
153    pub fn with_equation(mut self, equation: ExprId) -> Self {
154        self.equation = equation;
155        self
156    }
157
158    /// ODE order.
159    pub fn order(&self) -> usize {
160        self.derivs.len()
161    }
162}
163
164/// A general solution returned by [`dsolve`].
165#[derive(Clone, Debug)]
166pub struct DsolveSolution {
167    /// The solution expression for `y(x)` (the right-hand side of `y(x) = …`),
168    /// containing the integration constants in [`Self::constants`].
169    pub y_of_x: ExprId,
170    /// The fresh constant symbols `C1, C2, …` appearing in [`Self::y_of_x`].
171    pub constants: Vec<ExprId>,
172    /// Short label of the solving method (e.g. `"separable"`).
173    pub method: &'static str,
174}
175
176/// The result of [`dsolve`]: zero or more general-solution branches.
177#[derive(Clone, Debug)]
178pub struct DsolveResult {
179    /// General-solution branches.  Most classes return exactly one branch.
180    pub solutions: Vec<DsolveSolution>,
181}
182
/// Ways in which [`dsolve`] can fail or decline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DsolveError {
    /// No implemented class matched the ODE, or an integral the method needs
    /// has no elementary closed form.
    Unsupported(String),
    /// A closed form was found but did not pass the substitution check, so it
    /// is withheld instead of being returned as a possibly wrong answer.
    VerificationFailed(String),
    /// An intermediate expression could not be differentiated.
    DiffError(String),
}

impl fmt::Display for DsolveError {
    /// Writes `dsolve: <category>: <detail>` for the variant at hand.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DiffError(m) => write!(f, "dsolve: differentiation error: {m}"),
            Self::VerificationFailed(m) => {
                write!(f, "dsolve: candidate failed verification: {m}")
            }
            Self::Unsupported(m) => write!(f, "dsolve: unsupported ODE: {m}"),
        }
    }
}

impl std::error::Error for DsolveError {}
209
210impl crate::errors::AlkahestError for DsolveError {
211    fn code(&self) -> &'static str {
212        match self {
213            DsolveError::Unsupported(_) => "E-ODE-010",
214            DsolveError::VerificationFailed(_) => "E-ODE-011",
215            DsolveError::DiffError(_) => "E-ODE-012",
216        }
217    }
218
219    fn remediation(&self) -> Option<&'static str> {
220        match self {
221            DsolveError::Unsupported(_) => Some(
222                "the ODE is outside the implemented classical classes, or a required \
223                 integral is non-elementary; check the equation form",
224            ),
225            DsolveError::VerificationFailed(_) => Some(
226                "the solver found a candidate that did not verify by substitution; \
227                 this is reported rather than returned as a (possibly wrong) answer",
228            ),
229            DsolveError::DiffError(_) => {
230                Some("ensure the equation only contains differentiable functions")
231            }
232        }
233    }
234}
235
236// ---------------------------------------------------------------------------
237// Entry point
238// ---------------------------------------------------------------------------
239
240/// Solve a scalar ODE in closed form, returning the general solution(s).
241///
242/// Dispatches on the ODE order and structure to the implemented classical
243/// methods.  Every returned solution is verified by substitution (see the
244/// [module docs](self)); unverifiable candidates are withheld and the relevant
245/// class declines.
246///
247/// # Errors
248///
249/// Returns [`DsolveError::Unsupported`] when the equation is outside the
250/// implemented classes or a required quadrature is non-elementary, and
251/// [`DsolveError::VerificationFailed`] when a candidate could not be verified.
252pub fn dsolve(input: &OdeInput, pool: &ExprPool) -> Result<DsolveResult, DsolveError> {
253    let mut gen = ConstGen::new(input, pool);
254    match input.order() {
255        1 => first_order::solve(input, &mut gen, pool),
256        2 => constant_coeff::solve_second_order(input, &mut gen, pool),
257        n if n >= 3 => constant_coeff::solve_higher_order(input, n, &mut gen, pool),
258        _ => Err(DsolveError::Unsupported("order 0 ODE".to_string())),
259    }
260}
261
262// ---------------------------------------------------------------------------
263// Fresh-constant generator (collision-free with user symbols)
264// ---------------------------------------------------------------------------
265
266/// Allocates fresh integration-constant symbols `C1, C2, …`, skipping any name
267/// already present in the input equation so user symbols never collide.
268pub(crate) struct ConstGen {
269    next: usize,
270    used: std::collections::HashSet<String>,
271}
272
273impl ConstGen {
274    fn new(input: &OdeInput, pool: &ExprPool) -> Self {
275        let mut used = std::collections::HashSet::new();
276        collect_symbol_names(input.equation, pool, &mut used);
277        ConstGen { next: 1, used }
278    }
279
280    /// Return a fresh constant symbol whose name (`C{n}`) is not already used.
281    pub(crate) fn fresh(&mut self, pool: &ExprPool) -> ExprId {
282        loop {
283            let name = format!("C{}", self.next);
284            self.next += 1;
285            if !self.used.contains(&name) {
286                self.used.insert(name.clone());
287                return pool.symbol(name, Domain::Real);
288            }
289        }
290    }
291}
292
293fn collect_symbol_names(
294    expr: ExprId,
295    pool: &ExprPool,
296    out: &mut std::collections::HashSet<String>,
297) {
298    pool.with(expr, |d| match d {
299        ExprData::Symbol { name, .. } => {
300            out.insert(name.clone());
301        }
302        ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
303            for &a in args {
304                collect_symbol_names(a, pool, out);
305            }
306        }
307        ExprData::Pow { base, exp } => {
308            collect_symbol_names(*base, pool, out);
309            collect_symbol_names(*exp, pool, out);
310        }
311        _ => {}
312    });
313}
314
315// ---------------------------------------------------------------------------
316// Shared small helpers (used across submodules)
317// ---------------------------------------------------------------------------
318
319/// Simplify with distribution (expanded normal form).  The classification
320/// logic relies on polynomial-in-`x`/`y` terms being flattened (e.g.
321/// `−1·(−3y−x)` becoming `3y + x`) so coefficient extraction by structural
322/// inspection works.
323pub(crate) fn simp(expr: ExprId, pool: &ExprPool) -> ExprId {
324    simplify_expanded(expr, pool).value
325}
326
327/// Plain (non-expanding) simplify, for the final residual zero-check where
328/// expansion is not required.
329pub(crate) fn simp_plain(expr: ExprId, pool: &ExprPool) -> ExprId {
330    simplify(expr, pool).value
331}
332
333/// `diff(expr, var).value`, mapping `DiffError` into `DsolveError`.
334pub(crate) fn ddx(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<ExprId, DsolveError> {
335    diff(expr, var, pool)
336        .map(|d| d.value)
337        .map_err(|e| DsolveError::DiffError(e.to_string()))
338}
339
340/// Integrate `expr` in `var`; map any decline to `Unsupported` so the caller
341/// declines the whole class (we never emit unevaluated-integral output).
342pub(crate) fn integrate_or_decline(
343    expr: ExprId,
344    var: ExprId,
345    pool: &ExprPool,
346) -> Result<ExprId, DsolveError> {
347    match integrate(expr, var, pool) {
348        Ok(d) => Ok(simp(d.value, pool)),
349        Err(e) => {
350            // Fallback: closed-form ∫ p(x)·e^{a x}·{1,cos b x,sin b x} dx via an
351            // undetermined-coefficients ansatz (the engine declines some of these
352            // products, e.g. ∫ x·e^{−3x}).
353            if let Some(f) = integrate_pexp_trig(expr, var, pool) {
354                return Ok(f);
355            }
356            Err(DsolveError::Unsupported(format!(
357                "required integral did not close: {e}"
358            )))
359        }
360    }
361}
362
363/// Antiderivative of `p(x)·e^{a x}·{1 | cos(b x) | sin(b x)}` where `p` is a
364/// polynomial and `a,b` are constants.  Builds an ansatz of the same shape with
365/// undetermined polynomial coefficients and solves by numeric sampling, then
366/// returns the symbolic antiderivative (verified by `d/dx`).  Returns `None`
367/// when the integrand is not of this form or the solve is singular.
368pub(crate) fn integrate_pexp_trig(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
369    // Decompose factors into polynomial part, exp rate a, trig rate b.
370    let factors: Vec<ExprId> = match pool.get(expr) {
371        ExprData::Mul(args) => args,
372        _ => vec![expr],
373    };
374    let mut exp_rate = 0.0_f64;
375    let mut trig: Option<(bool, f64)> = None; // (is_sin, rate)
376    let mut poly_factors: Vec<ExprId> = Vec::new();
377    for f in factors {
378        match pool.get(f) {
379            ExprData::Func { name, args } if name == "exp" && args.len() == 1 => {
380                exp_rate += linear_rate_of(args[0], var, pool)?;
381            }
382            ExprData::Func { name, args }
383                if (name == "cos" || name == "sin") && args.len() == 1 =>
384            {
385                if trig.is_some() {
386                    return None;
387                }
388                trig = Some((name == "sin", linear_rate_of(args[0], var, pool)?));
389            }
390            _ => {
391                if contains(f, var, pool) && poly_degree_in(f, var, pool).is_none() {
392                    return None;
393                }
394                poly_factors.push(f);
395            }
396        }
397    }
398    let poly = if poly_factors.is_empty() {
399        pool.integer(1_i32)
400    } else {
401        simp(pool.mul(poly_factors), pool)
402    };
403    let deg = poly_degree_in(poly, var, pool)?;
404    if exp_rate == 0.0 && trig.is_none() {
405        return None; // pure polynomial — the engine already handles this
406    }
407
408    // Ansatz: F = e^{a x}·Σ_{k≤deg} (A_k x^k cos b x + B_k x^k sin b x)  (cos&sin
409    // only when trig present; otherwise just e^{a x}·Σ A_k x^k).
410    let exp_factor = if exp_rate != 0.0 {
411        Some(simp(
412            pool.func("exp", vec![mul_c(exp_rate, var, pool)]),
413            pool,
414        ))
415    } else {
416        None
417    };
418    let mut mods: Vec<ExprId> = Vec::new();
419    if let Some((_, b)) = trig {
420        let bx = mul_c(b, var, pool);
421        mods.push(pool.func("cos", vec![bx]));
422        mods.push(pool.func("sin", vec![bx]));
423    } else {
424        mods.push(pool.integer(1_i32));
425    }
426    let mut terms: Vec<ExprId> = Vec::new();
427    for k in 0..=deg {
428        let xk = if k == 0 {
429            pool.integer(1_i32)
430        } else {
431            pool.pow(var, pool.integer(k as i32))
432        };
433        for &m in &mods {
434            let mut fac = vec![xk, m];
435            if let Some(e) = exp_factor {
436                fac.push(e);
437            }
438            terms.push(simp(pool.mul(fac), pool));
439        }
440    }
441    let k = terms.len();
442    // Solve Σ A_j (d/dx term_j) = integrand by sampling at k points.
443    let mut dterms: Vec<ExprId> = Vec::with_capacity(k);
444    for &t in &terms {
445        dterms.push(simp(diff(t, var, pool).ok()?.value, pool));
446    }
447    let samples: Vec<f64> = (0..k).map(|i| 0.41 + 0.47 * i as f64).collect();
448    let mut mat = vec![vec![0.0; k]; k];
449    let mut rhs = vec![0.0; k];
450    for (i, &xv) in samples.iter().enumerate() {
451        let mut env = HashMap::new();
452        env.insert(var, xv);
453        for (j, &dt) in dterms.iter().enumerate() {
454            mat[i][j] = verify::eval(dt, &env, pool)?;
455        }
456        rhs[i] = verify::eval(expr, &env, pool)?;
457    }
458    let sol = gaussian_solve(&mut mat, &mut rhs)?;
459    let mut out = Vec::new();
460    for (j, &t) in terms.iter().enumerate() {
461        if sol[j].abs() < 1e-12 {
462            continue;
463        }
464        out.push(pool.mul(vec![f64_rational(sol[j], pool), t]));
465    }
466    let f = simp(pool.add(out), pool);
467    // Verify d/dx f == expr numerically before trusting it.
468    let df = simp(diff(f, var, pool).ok()?.value, pool);
469    for xv in [0.23_f64, 0.61, 1.07, 1.53] {
470        let mut env = HashMap::new();
471        env.insert(var, xv);
472        let lhs = verify::eval(df, &env, pool)?;
473        let rhsv = verify::eval(expr, &env, pool)?;
474        if (lhs - rhsv).abs() > 1e-6 {
475            return None;
476        }
477    }
478    Some(f)
479}
480
481/// `arg = c·var` (through the origin) → `c`.
482fn linear_rate_of(arg: ExprId, var: ExprId, pool: &ExprPool) -> Option<f64> {
483    let d = diff(arg, var, pool).ok()?.value;
484    if contains(d, var, pool) {
485        return None;
486    }
487    let dx = simp(pool.mul(vec![d, var]), pool);
488    if !is_zero(sub(arg, dx, pool), pool) {
489        return None;
490    }
491    try_expr_f64(simp(d, pool), pool)
492}
493
494fn poly_degree_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<usize> {
495    if !contains(expr, var, pool) {
496        return Some(0);
497    }
498    match pool.get(expr) {
499        ExprData::Symbol { .. } => Some(1),
500        ExprData::Add(args) => args
501            .iter()
502            .map(|&a| poly_degree_in(a, var, pool))
503            .try_fold(0usize, |acc, d| Some(acc.max(d?))),
504        ExprData::Mul(args) => args
505            .iter()
506            .map(|&a| poly_degree_in(a, var, pool))
507            .try_fold(0usize, |acc, d| Some(acc + d?)),
508        ExprData::Pow { base, exp } if base == var => {
509            if let ExprData::Integer(k) = pool.get(exp) {
510                let k = k.0.to_i64()?;
511                if k >= 0 {
512                    return Some(k as usize);
513                }
514            }
515            None
516        }
517        _ => None,
518    }
519}
520
521fn mul_c(c: f64, var: ExprId, pool: &ExprPool) -> ExprId {
522    simp(pool.mul(vec![f64_rational(c, pool), var]), pool)
523}
524
525fn f64_rational(v: f64, pool: &ExprPool) -> ExprId {
526    if v == v.round() {
527        return pool.integer(v as i64);
528    }
529    for den in 2..=24_i64 {
530        let num = v * den as f64;
531        if (num - num.round()).abs() < 1e-9 {
532            return pool.rational(num.round() as i64, den);
533        }
534    }
535    pool.float(v, 53)
536}
537
/// Solve the linear system `mat · x = rhs` by Gauss–Jordan elimination with
/// partial pivoting.
///
/// The system size `n` is `rhs.len()`; the leading `n × n` part of `mat` is
/// used.  Both `mat` and `rhs` are overwritten with intermediate values.
///
/// Returns `None` when
/// * `mat` has fewer than `n` rows or a row shorter than `n`,
/// * a pivot is smaller than `1e-12` in magnitude (singular system) or is not
///   finite, or
/// * the computed solution contains a NaN or infinite entry.
///
/// An empty system (`n == 0`) yields `Some(vec![])`.
// Index loops are kept on purpose: the elimination step reads row `col`
// while writing row `r` of the same matrix.
#[allow(clippy::needless_range_loop)]
fn gaussian_solve(mat: &mut [Vec<f64>], rhs: &mut [f64]) -> Option<Vec<f64>> {
    let n = rhs.len();
    // Reject shapes that would otherwise index out of bounds below.
    if mat.len() < n || mat.iter().take(n).any(|row| row.len() < n) {
        return None;
    }
    for col in 0..n {
        // Partial pivoting: take the row at or below the diagonal holding the
        // largest magnitude in this column.
        let mut piv = col;
        for r in (col + 1)..n {
            if mat[r][col].abs() > mat[piv][col].abs() {
                piv = r;
            }
        }
        let pivot = mat[piv][col];
        // A NaN pivot would pass a bare `< 1e-12` test, so check finiteness too.
        if !pivot.is_finite() || pivot.abs() < 1e-12 {
            return None;
        }
        mat.swap(col, piv);
        rhs.swap(col, piv);
        // Clear this column from every other row (Gauss–Jordan), so that no
        // back-substitution pass is needed afterwards.
        for r in 0..n {
            if r == col {
                continue;
            }
            let factor = mat[r][col] / pivot;
            for c in col..n {
                mat[r][c] -= factor * mat[col][c];
            }
            rhs[r] -= factor * rhs[col];
        }
    }
    let solution: Vec<f64> = (0..n).map(|i| rhs[i] / mat[i][i]).collect();
    if solution.iter().all(|v| v.is_finite()) {
        Some(solution)
    } else {
        None
    }
}
567
568/// Does `expr` contain `needle` as a sub-expression?
569pub(crate) fn contains(expr: ExprId, needle: ExprId, pool: &ExprPool) -> bool {
570    if expr == needle {
571        return true;
572    }
573    pool.with(expr, |d| match d {
574        ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
575            args.iter().any(|&a| contains(a, needle, pool))
576        }
577        ExprData::Pow { base, exp } => {
578            contains(*base, needle, pool) || contains(*exp, needle, pool)
579        }
580        _ => false,
581    })
582}
583
584/// `a - b`, simplified.
585pub(crate) fn sub(a: ExprId, b: ExprId, pool: &ExprPool) -> ExprId {
586    let neg_b = pool.mul(vec![pool.integer(-1_i32), b]);
587    simp(pool.add(vec![a, neg_b]), pool)
588}
589
590/// `a / b`, simplified.
591pub(crate) fn div(a: ExprId, b: ExprId, pool: &ExprPool) -> ExprId {
592    let inv_b = pool.pow(b, pool.integer(-1_i32));
593    simp(pool.mul(vec![a, inv_b]), pool)
594}
595
596/// Substitute a single symbol → replacement, simplifying the result.
597pub(crate) fn subs1(expr: ExprId, from: ExprId, to: ExprId, pool: &ExprPool) -> ExprId {
598    let mut m = HashMap::new();
599    m.insert(from, to);
600    simp(crate::kernel::subs::subs(expr, &m, pool), pool)
601}
602
603/// Is `expr` the literal zero after simplification?
604pub(crate) fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
605    let s = simp(expr, pool);
606    matches!(pool.get(s), ExprData::Integer(n) if n.0 == 0)
607        || matches!(try_expr_f64(s, pool), Some(v) if v == 0.0)
608}
609
610#[cfg(test)]
611mod tests;