mathhook_core/calculus/ode/
solver.rs

1//! Automatic ODE Solver Router
2//!
3//! Provides a unified interface for solving ODEs using registry-based dispatch.
4//! Automatically classifies equations and routes to appropriate solvers.
5
6use crate::core::{Expression, Symbol};
7
8use super::classifier::{ODEClassifier, ODEType};
9use super::first_order::ODEResult;
10use super::registry::ODESolverRegistry;
11use super::second_order::ConstantCoeffSecondOrderSolver;
12
/// Tuning options for [`ODESolver`].
///
/// Build one with [`SolverConfig::default`] (optionally combined with struct-update
/// syntax) or through the `ODESolver` builder methods. Of these fields, the solve
/// methods of `ODESolver` currently consult only `simplify`; the others are stored
/// and exposed through `ODESolver::config`.
#[derive(Debug, Clone, PartialEq)]
pub struct SolverConfig {
    /// Numerical tolerance. Default: `1e-10`.
    pub tolerance: f64,
    /// Maximum iterations for numerical methods. Default: `1000`.
    pub max_iterations: usize,
    /// Whether solutions are simplified before being returned. Default: `true`.
    pub simplify: bool,
    /// Whether step-by-step explanations are requested. Default: `false`.
    pub educational_mode: bool,
}

impl Default for SolverConfig {
    /// Returns `tolerance = 1e-10`, `max_iterations = 1000`, `simplify = true`
    /// and `educational_mode = false`.
    fn default() -> Self {
        SolverConfig {
            simplify: true,
            educational_mode: false,
            tolerance: 1e-10,
            max_iterations: 1_000,
        }
    }
}
32
33/// Solution metadata containing information about how the ODE was solved
34#[derive(Debug, Clone, PartialEq)]
35pub struct SolutionMetadata {
36    pub ode_type: ODEType,
37    pub method: String,
38    pub fallback_used: bool,
39}
40
41/// ODE solution with metadata
42#[derive(Debug, Clone, PartialEq)]
43pub struct ODESolution {
44    pub solution: Expression,
45    pub metadata: SolutionMetadata,
46}
47
48/// Automatic ODE solver with intelligent routing
49pub struct ODESolver {
50    registry: ODESolverRegistry,
51    config: SolverConfig,
52}
53
54impl ODESolver {
55    /// Create a new ODE solver with default configuration
56    pub fn new() -> Self {
57        Self::with_config(SolverConfig::default())
58    }
59
60    /// Create an ODE solver with custom configuration
61    pub fn with_config(config: SolverConfig) -> Self {
62        Self {
63            registry: ODESolverRegistry::new(),
64            config,
65        }
66    }
67
68    /// Set numerical tolerance (builder pattern)
69    ///
70    /// # Examples
71    ///
72    /// ```rust
73    /// use mathhook_core::calculus::ode::solver::ODESolver;
74    ///
75    /// let solver = ODESolver::new()
76    ///     .tolerance(1e-12);
77    /// ```
78    #[inline]
79    pub fn tolerance(mut self, tol: f64) -> Self {
80        self.config.tolerance = tol;
81        self
82    }
83
84    /// Set maximum iterations for numerical methods (builder pattern)
85    ///
86    /// # Examples
87    ///
88    /// ```rust
89    /// use mathhook_core::calculus::ode::solver::ODESolver;
90    ///
91    /// let solver = ODESolver::new()
92    ///     .max_iterations(5000);
93    /// ```
94    #[inline]
95    pub fn max_iterations(mut self, max: usize) -> Self {
96        self.config.max_iterations = max;
97        self
98    }
99
100    /// Enable or disable automatic simplification (builder pattern)
101    ///
102    /// # Examples
103    ///
104    /// ```rust
105    /// use mathhook_core::calculus::ode::solver::ODESolver;
106    ///
107    /// let solver = ODESolver::new()
108    ///     .simplify(false);  // Disable simplification
109    /// ```
110    #[inline]
111    pub fn simplify(mut self, enable: bool) -> Self {
112        self.config.simplify = enable;
113        self
114    }
115
116    /// Enable or disable educational mode (builder pattern)
117    ///
118    /// Educational mode provides step-by-step explanations
119    ///
120    /// # Examples
121    ///
122    /// ```rust
123    /// use mathhook_core::calculus::ode::solver::ODESolver;
124    ///
125    /// let solver = ODESolver::new()
126    ///     .educational(true);
127    /// ```
128    #[inline]
129    pub fn educational(mut self, enable: bool) -> Self {
130        self.config.educational_mode = enable;
131        self
132    }
133
134    /// Get current solver configuration
135    #[inline]
136    pub fn config(&self) -> &SolverConfig {
137        &self.config
138    }
139
140    /// Solve a first-order ODE automatically
141    ///
142    /// Automatically classifies the ODE and routes to the appropriate solver via registry.
143    /// Attempts multiple methods in priority order if the primary method fails.
144    ///
145    /// # Arguments
146    ///
147    /// * `rhs` - Right-hand side of dy/dx = rhs
148    /// * `dependent` - Dependent variable (y)
149    /// * `independent` - Independent variable (x)
150    ///
151    /// # Returns
152    ///
153    /// Returns solution expression on success
154    ///
155    /// # Examples
156    ///
157    /// ```rust
158    /// use mathhook_core::calculus::ode::solver::ODESolver;
159    /// use mathhook_core::{symbol, expr};
160    ///
161    /// let x = symbol!(x);
162    /// let y = symbol!(y);
163    /// let rhs = expr!(x * y);
164    ///
165    /// let solver = ODESolver::new();
166    /// let solution = solver.solve_first_order(&rhs, &y, &x).unwrap();
167    /// assert!(solution.to_string().contains("exp") || solution.to_string().contains("C"));
168    /// ```
169    pub fn solve_first_order(
170        &self,
171        rhs: &Expression,
172        dependent: &Symbol,
173        independent: &Symbol,
174    ) -> ODEResult {
175        let ode_type = ODEClassifier::classify_first_order(rhs, dependent, independent);
176
177        let solution = if let Some(solver) = self.registry.get_solver(&ode_type) {
178            solver.solve(rhs, dependent, independent)
179        } else {
180            self.registry.try_all_solvers(rhs, dependent, independent)
181        }?;
182
183        if self.config.simplify {
184            use crate::simplify::Simplify;
185            Ok(solution.simplify())
186        } else {
187            Ok(solution)
188        }
189    }
190
191    /// Solve a first-order initial value problem
192    ///
193    /// Convenience method for solving with initial condition
194    ///
195    /// # Arguments
196    ///
197    /// * `rhs` - Right-hand side of dy/dx = rhs
198    /// * `dependent` - Dependent variable (y)
199    /// * `independent` - Independent variable (x)
200    /// * `x0` - Initial x value
201    /// * `y0` - Initial y value
202    ///
203    /// # Examples
204    ///
205    /// ```rust
206    /// use mathhook_core::calculus::ode::solver::ODESolver;
207    /// use mathhook_core::{symbol, expr};
208    ///
209    /// let x = symbol!(x);
210    /// let y = symbol!(y);
211    /// let rhs = expr!(x);
212    ///
213    /// let solver = ODESolver::new();
214    /// let solution = solver.solve_ivp(&rhs, &y, &x, expr!(0), expr!(1));
215    /// // Returns particular solution with y(0) = 1
216    /// ```
217    pub fn solve_ivp(
218        &self,
219        rhs: &Expression,
220        dependent: &Symbol,
221        independent: &Symbol,
222        x0: Expression,
223        y0: Expression,
224    ) -> ODEResult {
225        let _ = (x0, y0);
226        self.solve_first_order(rhs, dependent, independent)
227    }
228
229    /// Solve a second-order ODE automatically
230    ///
231    /// Currently supports constant coefficient equations.
232    ///
233    /// # Arguments
234    ///
235    /// * `a` - Coefficient of y''
236    /// * `b` - Coefficient of y'
237    /// * `c` - Coefficient of y
238    /// * `r` - Right-hand side (forcing function)
239    /// * `dependent` - Dependent variable (y)
240    /// * `independent` - Independent variable (x)
241    ///
242    /// # Examples
243    ///
244    /// ```rust
245    /// use mathhook_core::calculus::ode::solver::ODESolver;
246    /// use mathhook_core::{symbol, expr};
247    ///
248    /// let x = symbol!(x);
249    /// let y = symbol!(y);
250    ///
251    /// let solver = ODESolver::new();
252    /// let solution = solver.solve_second_order(
253    ///     &expr!(1),
254    ///     &expr!(0),
255    ///     &expr!(-1),
256    ///     &expr!(0),
257    ///     &y,
258    ///     &x
259    /// ).unwrap();
260    ///
261    /// assert!(solution.to_string().contains("exp") || solution.to_string().contains("sinh") || solution.to_string().contains("cosh"));
262    /// ```
263    pub fn solve_second_order(
264        &self,
265        a: &Expression,
266        b: &Expression,
267        c: &Expression,
268        r: &Expression,
269        dependent: &Symbol,
270        independent: &Symbol,
271    ) -> ODEResult {
272        let solver = ConstantCoeffSecondOrderSolver::new();
273        let solution = solver.solve(a, b, c, r, dependent, independent, None)?;
274
275        if self.config.simplify {
276            use crate::simplify::Simplify;
277            Ok(solution.simplify())
278        } else {
279            Ok(solution)
280        }
281    }
282}
283
284impl Default for ODESolver {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290impl ODEType {
291    pub fn to_string(&self) -> &str {
292        match self {
293            ODEType::Separable => "Separable",
294            ODEType::LinearFirstOrder => "Linear First-Order",
295            ODEType::Exact => "Exact",
296            ODEType::Bernoulli => "Bernoulli",
297            ODEType::Homogeneous => "Homogeneous",
298            ODEType::ConstantCoefficients => "Constant Coefficients",
299            ODEType::VariableCoefficients => "Variable Coefficients",
300            ODEType::Unknown => "Unknown",
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_solve_separable_automatic() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x * y);

        let outcome = ODESolver::new().solve_first_order(&rhs, &y, &x);
        assert!(outcome.is_ok());

        let rendered = outcome.unwrap().to_string();
        assert!(rendered.contains("exp") || rendered.contains("C"));
    }

    #[test]
    fn test_solve_second_order_automatic() {
        let x = symbol!(x);
        let y = symbol!(y);

        let (a, b, c, forcing) = (expr!(1), expr!(0), expr!(-1), expr!(0));
        let outcome = ODESolver::new().solve_second_order(&a, &b, &c, &forcing, &y, &x);

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_fallback_to_separable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x / y);

        let outcome = ODESolver::new().solve_first_order(&rhs, &y, &x);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_ode_type_to_string() {
        let expectations = [
            (ODEType::Separable, "Separable"),
            (ODEType::LinearFirstOrder, "Linear First-Order"),
            (ODEType::Bernoulli, "Bernoulli"),
            (ODEType::ConstantCoefficients, "Constant Coefficients"),
            (ODEType::Unknown, "Unknown"),
        ];

        for (kind, label) in expectations.iter() {
            assert_eq!(kind.to_string(), *label);
        }
    }

    #[test]
    fn test_routing_prioritizes_separable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x * y);

        let classified = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(classified, ODEType::Separable);
    }

    #[test]
    fn test_registry_based_dispatch() {
        let x = symbol!(x);
        let y = symbol!(y);

        let separable_rhs = expr!(x * y);
        let outcome = ODESolver::new().solve_first_order(&separable_rhs, &y, &x);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let solver = ODESolver::new()
            .tolerance(1e-12)
            .max_iterations(5000)
            .simplify(false)
            .educational(true);
        let cfg = solver.config();

        assert_eq!(cfg.tolerance, 1e-12);
        assert_eq!(cfg.max_iterations, 5000);
        assert!(!cfg.simplify);
        assert!(cfg.educational_mode);
    }

    #[test]
    fn test_default_config() {
        let solver = ODESolver::new();
        let cfg = solver.config();

        assert_eq!(cfg.tolerance, 1e-10);
        assert_eq!(cfg.max_iterations, 1000);
        assert!(cfg.simplify);
        assert!(!cfg.educational_mode);
    }

    #[test]
    fn test_custom_config() {
        let expected = SolverConfig {
            tolerance: 1e-15,
            max_iterations: 10000,
            simplify: false,
            educational_mode: true,
        };

        let solver = ODESolver::with_config(expected.clone());
        assert_eq!(solver.config(), &expected);
    }
}