mathhook_core/calculus/ode/
registry.rs1use super::classifier::ODEType;
6use super::first_order::{
7 HomogeneousODESolver, LinearFirstOrderSolver, ODEError, ODEResult, SeparableODESolver,
8};
9use crate::core::{Expression, Symbol};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13fn contains_symbol(expr: &Expression, sym: &Symbol) -> bool {
15 match expr {
16 Expression::Symbol(s) => s == sym,
17 Expression::Add(terms) | Expression::Mul(terms) => {
18 terms.iter().any(|t| contains_symbol(t, sym))
19 }
20 Expression::Pow(base, exp) => contains_symbol(base, sym) || contains_symbol(exp, sym),
21 Expression::Function { args, .. } => args.iter().any(|a| contains_symbol(a, sym)),
22 _ => false,
23 }
24}
25
26pub trait FirstOrderSolver: Send + Sync {
28 fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult;
29
30 fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool;
31
32 fn name(&self) -> &'static str;
33 fn description(&self) -> &'static str;
34}
35
36struct SeparableSolverAdapter;
37
38impl FirstOrderSolver for SeparableSolverAdapter {
39 fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
40 let solver = SeparableODESolver::new();
41 solver.solve(rhs, dependent, independent, None)
42 }
43
44 #[inline]
45 fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
46 SeparableODESolver::new().is_separable(rhs, dependent, independent)
47 }
48
49 #[inline]
50 fn name(&self) -> &'static str {
51 "Separable"
52 }
53
54 #[inline]
55 fn description(&self) -> &'static str {
56 "Solves separable ODEs of the form dy/dx = g(x)h(y)"
57 }
58}
59
60struct LinearFirstOrderSolverAdapter;
61
62impl FirstOrderSolver for LinearFirstOrderSolverAdapter {
63 fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
64 let (p, q) = extract_linear_coefficients(rhs, dependent, independent)?;
65 let solver = LinearFirstOrderSolver;
66 LinearFirstOrderSolver::solve(&solver, &p, &q, dependent, independent, None)
67 }
68
69 #[inline]
70 fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
71 extract_linear_coefficients(rhs, dependent, independent).is_ok()
72 }
73
74 #[inline]
75 fn name(&self) -> &'static str {
76 "Linear First-Order"
77 }
78
79 #[inline]
80 fn description(&self) -> &'static str {
81 "Solves linear first-order ODEs using integrating factor method"
82 }
83}
84
85struct HomogeneousSolverAdapter;
86
87impl FirstOrderSolver for HomogeneousSolverAdapter {
88 fn solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> ODEResult {
89 let solver = HomogeneousODESolver;
90 solver.solve(rhs, dependent, independent)
91 }
92
93 #[inline]
94 fn can_solve(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
95 HomogeneousODESolver.is_homogeneous(rhs, dependent, independent)
96 }
97
98 #[inline]
99 fn name(&self) -> &'static str {
100 "Homogeneous"
101 }
102
103 #[inline]
104 fn description(&self) -> &'static str {
105 "Solves homogeneous ODEs of the form dy/dx = f(y/x)"
106 }
107}
108
109pub struct ODESolverRegistry {
110 solvers: HashMap<ODEType, Arc<dyn FirstOrderSolver>>,
111 priority_order: Vec<ODEType>,
112}
113
114impl ODESolverRegistry {
115 pub fn new() -> Self {
116 let mut solvers: HashMap<ODEType, Arc<dyn FirstOrderSolver>> = HashMap::new();
117
118 solvers.insert(ODEType::Separable, Arc::new(SeparableSolverAdapter));
119 solvers.insert(
120 ODEType::LinearFirstOrder,
121 Arc::new(LinearFirstOrderSolverAdapter),
122 );
123 solvers.insert(ODEType::Homogeneous, Arc::new(HomogeneousSolverAdapter));
124
125 let priority_order = vec![
126 ODEType::Separable,
127 ODEType::LinearFirstOrder,
128 ODEType::Homogeneous,
129 ];
130
131 Self {
132 solvers,
133 priority_order,
134 }
135 }
136
137 #[inline]
138 pub fn get_solver(&self, ode_type: &ODEType) -> Option<&Arc<dyn FirstOrderSolver>> {
139 self.solvers.get(ode_type)
140 }
141
142 pub fn try_all_solvers(
143 &self,
144 rhs: &Expression,
145 dependent: &Symbol,
146 independent: &Symbol,
147 ) -> ODEResult {
148 for ode_type in &self.priority_order {
149 if let Some(solver) = self.solvers.get(ode_type) {
150 if solver.can_solve(rhs, dependent, independent) {
151 return solver.solve(rhs, dependent, independent);
152 }
153 }
154 }
155
156 Err(ODEError::UnknownType {
157 equation: rhs.clone(),
158 reason: "No suitable solver found after trying all registered methods".to_owned(),
159 })
160 }
161}
162
163impl Default for ODESolverRegistry {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
169fn extract_linear_coefficients(
170 rhs: &Expression,
171 dependent: &Symbol,
172 _independent: &Symbol,
173) -> Result<(Expression, Expression), ODEError> {
174 use crate::expr;
175
176 match rhs {
177 Expression::Add(terms) => {
178 let mut p_terms = Vec::new();
179 let mut q_terms = Vec::new();
180
181 for term in terms.iter() {
182 if contains_symbol(term, dependent) {
183 if let Some(_coeff) = extract_y_coefficient(term, dependent) {
184 p_terms.push(expr!((-1) * _coeff));
185 } else {
186 return Err(ODEError::NotLinearForm {
187 reason: "Cannot extract coefficient from term containing y".to_owned(),
188 });
189 }
190 } else {
191 q_terms.push(term.clone());
192 }
193 }
194
195 let p = if p_terms.is_empty() {
196 expr!(0)
197 } else {
198 Expression::add(p_terms)
199 };
200
201 let q = if q_terms.is_empty() {
202 expr!(0)
203 } else {
204 Expression::add(q_terms)
205 };
206
207 Ok((p, q))
208 }
209 Expression::Mul(factors) => {
210 let mut y_factor = None;
211 let mut other_factors = Vec::new();
212
213 for factor in factors.iter() {
214 if contains_symbol(factor, dependent) {
215 if matches!(factor, Expression::Symbol(s) if s == dependent) {
216 y_factor = Some(expr!(1));
217 } else {
218 return Err(ODEError::NotLinearForm {
219 reason: "Complex y term in product".to_owned(),
220 });
221 }
222 } else {
223 other_factors.push(factor.clone());
224 }
225 }
226
227 if y_factor.is_some() {
228 let _coeff = if other_factors.is_empty() {
229 expr!(1)
230 } else {
231 Expression::mul(other_factors)
232 };
233
234 Ok((expr!((-1) * _coeff), expr!(0)))
235 } else {
236 Ok((expr!(0), rhs.clone()))
237 }
238 }
239 _ => {
240 if contains_symbol(rhs, dependent) {
241 if matches!(rhs, Expression::Symbol(s) if s == dependent) {
242 Ok((expr!(-1), expr!(0)))
243 } else {
244 Err(ODEError::NotLinearForm {
245 reason: "Cannot extract linear form".to_owned(),
246 })
247 }
248 } else {
249 Ok((expr!(0), rhs.clone()))
250 }
251 }
252 }
253}
254
255fn extract_y_coefficient(term: &Expression, y: &Symbol) -> Option<Expression> {
256 use crate::expr;
257
258 match term {
259 Expression::Symbol(s) if s == y => Some(expr!(1)),
260 Expression::Mul(factors) => {
261 let mut coeff_factors = Vec::new();
262 let mut found_y = false;
263
264 for factor in factors.iter() {
265 if matches!(factor, Expression::Symbol(s) if s == y) {
266 found_y = true;
267 } else {
268 coeff_factors.push(factor.clone());
269 }
270 }
271
272 if found_y {
273 Some(if coeff_factors.is_empty() {
274 expr!(1)
275 } else {
276 Expression::mul(coeff_factors)
277 })
278 } else {
279 None
280 }
281 }
282 _ => None,
283 }
284}