mathhook_core/calculus/ode/
solver.rs1use crate::core::{Expression, Symbol};
7
8use super::classifier::{ODEClassifier, ODEType};
9use super::first_order::ODEResult;
10use super::registry::ODESolverRegistry;
11use super::second_order::ConstantCoeffSecondOrderSolver;
12
/// Settings that control how [`ODESolver`] post-processes and reports solutions.
#[derive(Debug, Clone, PartialEq)]
pub struct SolverConfig {
    /// Tolerance setting (default `1e-10`).
    ///
    /// NOTE(review): nothing in this file reads this value; confirm which solvers consume it.
    pub tolerance: f64,
    /// Iteration limit (default `1000`).
    ///
    /// NOTE(review): nothing in this file reads this value; confirm which solvers consume it.
    pub max_iterations: usize,
    /// When `true` (the default), `ODESolver` simplifies every solution before returning it.
    pub simplify: bool,
    /// Educational-mode flag (default `false`).
    pub educational_mode: bool,
}

impl Default for SolverConfig {
    /// Tolerance `1e-10`, at most `1000` iterations, simplification on, educational mode off.
    fn default() -> Self {
        SolverConfig {
            educational_mode: false,
            simplify: true,
            max_iterations: 1_000,
            tolerance: 1e-10,
        }
    }
}
32
33#[derive(Debug, Clone, PartialEq)]
35pub struct SolutionMetadata {
36 pub ode_type: ODEType,
37 pub method: String,
38 pub fallback_used: bool,
39}
40
41#[derive(Debug, Clone, PartialEq)]
43pub struct ODESolution {
44 pub solution: Expression,
45 pub metadata: SolutionMetadata,
46}
47
48pub struct ODESolver {
50 registry: ODESolverRegistry,
51 config: SolverConfig,
52}
53
54impl ODESolver {
55 pub fn new() -> Self {
57 Self::with_config(SolverConfig::default())
58 }
59
60 pub fn with_config(config: SolverConfig) -> Self {
62 Self {
63 registry: ODESolverRegistry::new(),
64 config,
65 }
66 }
67
68 #[inline]
79 pub fn tolerance(mut self, tol: f64) -> Self {
80 self.config.tolerance = tol;
81 self
82 }
83
84 #[inline]
95 pub fn max_iterations(mut self, max: usize) -> Self {
96 self.config.max_iterations = max;
97 self
98 }
99
100 #[inline]
111 pub fn simplify(mut self, enable: bool) -> Self {
112 self.config.simplify = enable;
113 self
114 }
115
116 #[inline]
129 pub fn educational(mut self, enable: bool) -> Self {
130 self.config.educational_mode = enable;
131 self
132 }
133
134 #[inline]
136 pub fn config(&self) -> &SolverConfig {
137 &self.config
138 }
139
140 pub fn solve_first_order(
170 &self,
171 rhs: &Expression,
172 dependent: &Symbol,
173 independent: &Symbol,
174 ) -> ODEResult {
175 let ode_type = ODEClassifier::classify_first_order(rhs, dependent, independent);
176
177 let solution = if let Some(solver) = self.registry.get_solver(&ode_type) {
178 solver.solve(rhs, dependent, independent)
179 } else {
180 self.registry.try_all_solvers(rhs, dependent, independent)
181 }?;
182
183 if self.config.simplify {
184 use crate::simplify::Simplify;
185 Ok(solution.simplify())
186 } else {
187 Ok(solution)
188 }
189 }
190
191 pub fn solve_ivp(
218 &self,
219 rhs: &Expression,
220 dependent: &Symbol,
221 independent: &Symbol,
222 x0: Expression,
223 y0: Expression,
224 ) -> ODEResult {
225 let _ = (x0, y0);
226 self.solve_first_order(rhs, dependent, independent)
227 }
228
229 pub fn solve_second_order(
264 &self,
265 a: &Expression,
266 b: &Expression,
267 c: &Expression,
268 r: &Expression,
269 dependent: &Symbol,
270 independent: &Symbol,
271 ) -> ODEResult {
272 let solver = ConstantCoeffSecondOrderSolver::new();
273 let solution = solver.solve(a, b, c, r, dependent, independent, None)?;
274
275 if self.config.simplify {
276 use crate::simplify::Simplify;
277 Ok(solution.simplify())
278 } else {
279 Ok(solution)
280 }
281 }
282}
283
284impl Default for ODESolver {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290impl ODEType {
291 pub fn to_string(&self) -> &str {
292 match self {
293 ODEType::Separable => "Separable",
294 ODEType::LinearFirstOrder => "Linear First-Order",
295 ODEType::Exact => "Exact",
296 ODEType::Bernoulli => "Bernoulli",
297 ODEType::Homogeneous => "Homogeneous",
298 ODEType::ConstantCoefficients => "Constant Coefficients",
299 ODEType::VariableCoefficients => "Variable Coefficients",
300 ODEType::Unknown => "Unknown",
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    // `expect(..)` is used instead of `assert!(result.is_ok())` so a failing
    // test reports the solver's error rather than just "assertion failed".

    #[test]
    fn test_solve_separable_automatic() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x * y);

        let solver = ODESolver::new();
        let sol = solver
            .solve_first_order(&rhs, &y, &x)
            .expect("separable first-order ODE should be solvable");

        let rendered = sol.to_string();
        assert!(
            rendered.contains("exp") || rendered.contains("C"),
            "unexpected solution form: {}",
            rendered
        );
    }

    #[test]
    fn test_solve_second_order_automatic() {
        let x = symbol!(x);
        let y = symbol!(y);

        let solver = ODESolver::new();
        let _solution = solver
            .solve_second_order(&expr!(1), &expr!(0), &expr!(-1), &expr!(0), &y, &x)
            .expect("constant-coefficient second-order ODE should be solvable");
    }

    #[test]
    fn test_fallback_to_separable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x / y);

        let solver = ODESolver::new();
        let _solution = solver
            .solve_first_order(&rhs, &y, &x)
            .expect("first-order ODE with rhs x/y should be solvable");
    }

    #[test]
    fn test_solve_without_simplification() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x * y);

        let solver = ODESolver::new().simplify(false);
        assert!(!solver.config().simplify);
        let _solution = solver
            .solve_first_order(&rhs, &y, &x)
            .expect("disabling simplification must not make a solvable ODE fail");
    }

    #[test]
    fn test_ode_type_to_string() {
        // Check the label of every variant.
        assert_eq!(ODEType::Separable.to_string(), "Separable");
        assert_eq!(ODEType::LinearFirstOrder.to_string(), "Linear First-Order");
        assert_eq!(ODEType::Exact.to_string(), "Exact");
        assert_eq!(ODEType::Bernoulli.to_string(), "Bernoulli");
        assert_eq!(ODEType::Homogeneous.to_string(), "Homogeneous");
        assert_eq!(
            ODEType::ConstantCoefficients.to_string(),
            "Constant Coefficients"
        );
        assert_eq!(
            ODEType::VariableCoefficients.to_string(),
            "Variable Coefficients"
        );
        assert_eq!(ODEType::Unknown.to_string(), "Unknown");
    }

    #[test]
    fn test_routing_prioritizes_separable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let rhs = expr!(x * y);

        let ode_type = ODEClassifier::classify_first_order(&rhs, &y, &x);
        assert_eq!(ode_type, ODEType::Separable);
    }

    #[test]
    fn test_registry_based_dispatch() {
        let x = symbol!(x);
        let y = symbol!(y);

        let solver = ODESolver::new();
        let rhs_separable = expr!(x * y);
        let _solution = solver
            .solve_first_order(&rhs_separable, &y, &x)
            .expect("registry should dispatch the separable ODE to a working solver");
    }

    #[test]
    fn test_builder_pattern() {
        let solver = ODESolver::new()
            .tolerance(1e-12)
            .max_iterations(5000)
            .simplify(false)
            .educational(true);

        assert_eq!(solver.config().tolerance, 1e-12);
        assert_eq!(solver.config().max_iterations, 5000);
        assert!(!solver.config().simplify);
        assert!(solver.config().educational_mode);
    }

    #[test]
    fn test_default_config() {
        let solver = ODESolver::new();
        let config = solver.config();

        assert_eq!(config.tolerance, 1e-10);
        assert_eq!(config.max_iterations, 1000);
        assert!(config.simplify);
        assert!(!config.educational_mode);
    }

    #[test]
    fn test_custom_config() {
        let config = SolverConfig {
            tolerance: 1e-15,
            max_iterations: 10000,
            simplify: false,
            educational_mode: true,
        };

        let solver = ODESolver::with_config(config.clone());
        assert_eq!(solver.config(), &config);
    }
}