mathhook_core/calculus/ode/
registry.rs

1//! ODE Solver Registry
2//!
3//! Registry-based dispatch system for ODE solvers, eliminating hardcoded match patterns.
4
5use super::classifier::ODEType;
6use super::first_order::{
7    HomogeneousODESolver, LinearFirstOrderSolver, ODEError, ODEResult, SeparableODESolver,
8};
9use crate::core::{Expression, Symbol};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// Check if an expression contains a given symbol
14fn contains_symbol(expr: &Expression, sym: &Symbol) -> bool {
15    match expr {
16        Expression::Symbol(s) => s == sym,
17        Expression::Add(terms) | Expression::Mul(terms) => {
18            terms.iter().any(|t| contains_symbol(t, sym))
19        }
20        Expression::Pow(base, exp) => contains_symbol(base, sym) || contains_symbol(exp, sym),
21        Expression::Function { args, .. } => args.iter().any(|a| contains_symbol(a, sym)),
22        _ => false,
23    }
24}
25
26/// Trait for first-order ODE solvers
27pub trait FirstOrderSolver: Send + Sync {
28    fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult;
29
30    fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool;
31
32    fn name(&self) -> &'static str;
33    fn description(&self) -> &'static str;
34}
35
36struct SeparableSolverAdapter;
37
38impl FirstOrderSolver for SeparableSolverAdapter {
39    fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
40        let solver = SeparableODESolver::new();
41        solver.solve(rhs, dependent, independent, None)
42    }
43
44    #[inline]
45    fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
46        SeparableODESolver::new().is_separable(rhs, dependent, independent)
47    }
48
49    #[inline]
50    fn name(&self) -> &'static str {
51        "Separable"
52    }
53
54    #[inline]
55    fn description(&self) -> &'static str {
56        "Solves separable ODEs of the form dy/dx = g(x)h(y)"
57    }
58}
59
60struct LinearFirstOrderSolverAdapter;
61
62impl FirstOrderSolver for LinearFirstOrderSolverAdapter {
63    fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
64        let (p, q) = extract_linear_coefficients(rhs, dependent, independent)?;
65        let solver = LinearFirstOrderSolver;
66        LinearFirstOrderSolver::solve(&solver, &p, &q, dependent, independent, None)
67    }
68
69    #[inline]
70    fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
71        extract_linear_coefficients(rhs, dependent, independent).is_ok()
72    }
73
74    #[inline]
75    fn name(&self) -> &'static str {
76        "Linear First-Order"
77    }
78
79    #[inline]
80    fn description(&self) -> &'static str {
81        "Solves linear first-order ODEs using integrating factor method"
82    }
83}
84
85struct HomogeneousSolverAdapter;
86
87impl FirstOrderSolver for HomogeneousSolverAdapter {
88    fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
89        let solver = HomogeneousODESolver;
90        solver.solve(rhs, dependent, independent)
91    }
92
93    #[inline]
94    fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
95        HomogeneousODESolver.is_homogeneous(rhs, dependent, independent)
96    }
97
98    #[inline]
99    fn name(&self) -> &'static str {
100        "Homogeneous"
101    }
102
103    #[inline]
104    fn description(&self) -> &'static str {
105        "Solves homogeneous ODEs of the form dy/dx = f(y/x)"
106    }
107}
108
109pub struct ODESolverRegistry {
110    solvers: HashMap<ODEType, Arc<dyn FirstOrderSolver>>,
111    priority_order: Vec<ODEType>,
112}
113
114impl ODESolverRegistry {
115    pub fn new() -> Self {
116        let mut solvers: HashMap<ODEType, Arc<dyn FirstOrderSolver>> = HashMap::new();
117
118        solvers.insert(ODEType::Separable, Arc::new(SeparableSolverAdapter));
119        solvers.insert(
120            ODEType::LinearFirstOrder,
121            Arc::new(LinearFirstOrderSolverAdapter),
122        );
123        solvers.insert(ODEType::Homogeneous, Arc::new(HomogeneousSolverAdapter));
124
125        let priority_order = vec![
126            ODEType::Separable,
127            ODEType::LinearFirstOrder,
128            ODEType::Homogeneous,
129        ];
130
131        Self {
132            solvers,
133            priority_order,
134        }
135    }
136
137    #[inline]
138    pub fn get_solver(&self, ode_type: &ODEType) -> Option<&Arc<dyn FirstOrderSolver>> {
139        self.solvers.get(ode_type)
140    }
141
142    pub fn try_all_solvers(
143        &self,
144        rhs: &Expression,
145        dependent: &Symbol,
146        independent: &Symbol,
147    ) -> ODEResult {
148        for ode_type in &self.priority_order {
149            if let Some(solver) = self.solvers.get(ode_type) {
150                if solver.can_solve(rhs, dependent, independent) {
151                    return solver.solve(rhs, dependent, independent);
152                }
153            }
154        }
155
156        Err(ODEError::UnknownType {
157            equation: rhs.clone(),
158            reason: "No suitable solver found after trying all registered methods".to_owned(),
159        })
160    }
161}
162
163impl Default for ODESolverRegistry {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
169fn extract_linear_coefficients(
170    rhs: &Expression,
171    dependent: &Symbol,
172    _independent: &Symbol,
173) -> Result<(Expression, Expression), ODEError> {
174    use crate::expr;
175
176    match rhs {
177        Expression::Add(terms) => {
178            let mut p_terms = Vec::new();
179            let mut q_terms = Vec::new();
180
181            for term in terms.iter() {
182                if contains_symbol(term, dependent) {
183                    if let Some(_coeff) = extract_y_coefficient(term, dependent) {
184                        p_terms.push(expr!((-1) * _coeff));
185                    } else {
186                        return Err(ODEError::NotLinearForm {
187                            reason: "Cannot extract coefficient from term containing y".to_owned(),
188                        });
189                    }
190                } else {
191                    q_terms.push(term.clone());
192                }
193            }
194
195            let p = if p_terms.is_empty() {
196                expr!(0)
197            } else {
198                Expression::add(p_terms)
199            };
200
201            let q = if q_terms.is_empty() {
202                expr!(0)
203            } else {
204                Expression::add(q_terms)
205            };
206
207            Ok((p, q))
208        }
209        Expression::Mul(factors) => {
210            let mut y_factor = None;
211            let mut other_factors = Vec::new();
212
213            for factor in factors.iter() {
214                if contains_symbol(factor, dependent) {
215                    if matches!(factor, Expression::Symbol(s) if s == dependent) {
216                        y_factor = Some(expr!(1));
217                    } else {
218                        return Err(ODEError::NotLinearForm {
219                            reason: "Complex y term in product".to_owned(),
220                        });
221                    }
222                } else {
223                    other_factors.push(factor.clone());
224                }
225            }
226
227            if y_factor.is_some() {
228                let _coeff = if other_factors.is_empty() {
229                    expr!(1)
230                } else {
231                    Expression::mul(other_factors)
232                };
233
234                Ok((expr!((-1) * _coeff), expr!(0)))
235            } else {
236                Ok((expr!(0), rhs.clone()))
237            }
238        }
239        _ => {
240            if contains_symbol(rhs, dependent) {
241                if matches!(rhs, Expression::Symbol(s) if s == dependent) {
242                    Ok((expr!(-1), expr!(0)))
243                } else {
244                    Err(ODEError::NotLinearForm {
245                        reason: "Cannot extract linear form".to_owned(),
246                    })
247                }
248            } else {
249                Ok((expr!(0), rhs.clone()))
250            }
251        }
252    }
253}
254
255fn extract_y_coefficient(term: &Expression, y: &Symbol) -> Option<Expression> {
256    use crate::expr;
257
258    match term {
259        Expression::Symbol(s) if s == y => Some(expr!(1)),
260        Expression::Mul(factors) => {
261            let mut coeff_factors = Vec::new();
262            let mut found_y = false;
263
264            for factor in factors.iter() {
265                if matches!(factor, Expression::Symbol(s) if s == y) {
266                    found_y = true;
267                } else {
268                    coeff_factors.push(factor.clone());
269                }
270            }
271
272            if found_y {
273                Some(if coeff_factors.is_empty() {
274                    expr!(1)
275                } else {
276                    Expression::mul(coeff_factors)
277                })
278            } else {
279                None
280            }
281        }
282        _ => None,
283    }
284}