mathhook_core/calculus/pde/
registry.rs1use super::classification::classify_pde;
9use super::types::{PDESolution, Pde, PdeType};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13#[derive(Debug, Clone, PartialEq)]
15pub enum PDEError {
16 NoSolverAvailable { pde_type: PdeType },
18 ClassificationFailed { reason: String },
20 SolutionFailed { solver: String, reason: String },
22 InvalidBoundaryConditions { reason: String },
24 InvalidInitialConditions { reason: String },
26 NotSeparable { reason: String },
28 InvalidForm { reason: String },
30}
31
32pub type PDEResult = Result<PDESolution, PDEError>;
34
35pub trait PDESolver: Send + Sync {
37 fn solve(&self, pde: &Pde) -> PDEResult;
39
40 fn can_solve(&self, pde_type: PdeType) -> bool;
42
43 fn priority(&self) -> u8;
45
46 fn name(&self) -> &'static str;
48
49 fn description(&self) -> &'static str;
51}
52
53pub struct PDESolverRegistry {
55 solvers: HashMap<PdeType, Vec<Arc<dyn PDESolver>>>,
57 priority_order: Vec<PdeType>,
59}
60
61impl PDESolverRegistry {
62 pub fn new() -> Self {
64 let mut registry = Self {
65 solvers: HashMap::new(),
66 priority_order: Vec::new(),
67 };
68 registry.register_all_solvers();
69 registry
70 }
71
72 fn register_all_solvers(&mut self) {
74 use super::standard::heat::HeatEquationSolver;
75 use super::standard::laplace::LaplaceEquationSolver;
76 use super::standard::wave::WaveEquationSolver;
77
78 self.register(PdeType::Parabolic, Arc::new(HeatEquationSolver::new()));
79 self.register(PdeType::Hyperbolic, Arc::new(WaveEquationSolver::new()));
80 self.register(PdeType::Elliptic, Arc::new(LaplaceEquationSolver::new()));
81
82 self.priority_order = vec![PdeType::Parabolic, PdeType::Hyperbolic, PdeType::Elliptic];
83 }
84
85 pub fn register(&mut self, pde_type: PdeType, solver: Arc<dyn PDESolver>) {
87 self.solvers.entry(pde_type).or_default().push(solver);
88
89 if let Some(solvers) = self.solvers.get_mut(&pde_type) {
90 solvers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
91 }
92 }
93
94 pub fn get_solver(&self, pde_type: &PdeType) -> Option<&Arc<dyn PDESolver>> {
96 self.solvers
97 .get(pde_type)
98 .and_then(|solvers| solvers.first())
99 }
100
101 pub fn solve(&self, pde: &Pde) -> PDEResult {
103 let pde_type =
104 classify_pde(pde).map_err(|e| PDEError::ClassificationFailed { reason: e })?;
105
106 if let Some(solvers) = self.solvers.get(&pde_type) {
107 for solver in solvers {
108 if solver.can_solve(pde_type) {
109 match solver.solve(pde) {
110 Ok(solution) => return Ok(solution),
111 Err(_) => continue,
112 }
113 }
114 }
115 }
116
117 Err(PDEError::NoSolverAvailable { pde_type })
118 }
119
120 pub fn try_all_solvers(&self, pde: &Pde) -> PDEResult {
122 for pde_type in &self.priority_order {
123 if let Some(solvers) = self.solvers.get(pde_type) {
124 for solver in solvers {
125 match solver.solve(pde) {
126 Ok(solution) => return Ok(solution),
127 Err(_) => continue,
128 }
129 }
130 }
131 }
132
133 self.solve(pde)
134 }
135
136 pub fn registered_types(&self) -> Vec<PdeType> {
138 self.solvers.keys().copied().collect()
139 }
140
141 pub fn solver_count(&self) -> usize {
143 self.solvers.values().map(|v| v.len()).sum()
144 }
145}
146
147impl Default for PDESolverRegistry {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that never produces a solution. Its priority is configurable
    /// so that ordering inside the registry can be observed.
    struct AlwaysFailsSolver {
        priority: u8,
    }

    impl PDESolver for AlwaysFailsSolver {
        fn solve(&self, _pde: &Pde) -> PDEResult {
            Err(PDEError::SolutionFailed {
                solver: self.name().to_string(),
                reason: "test double".to_string(),
            })
        }

        fn can_solve(&self, _pde_type: PdeType) -> bool {
            true
        }

        fn priority(&self) -> u8 {
            self.priority
        }

        fn name(&self) -> &'static str {
            "always_fails"
        }

        fn description(&self) -> &'static str {
            "test solver that always fails"
        }
    }

    /// Priorities of the solvers registered for `pde_type`, in stored order.
    fn priorities_of(registry: &PDESolverRegistry, pde_type: PdeType) -> Vec<u8> {
        registry
            .solvers
            .get(&pde_type)
            .map(|solvers| solvers.iter().map(|s| s.priority()).collect())
            .unwrap_or_default()
    }

    #[test]
    fn test_registry_creation() {
        let registry = PDESolverRegistry::new();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_registry_registered_types() {
        let registry = PDESolverRegistry::new();
        let types = registry.registered_types();
        assert!(!types.is_empty());
        assert!(types.contains(&PdeType::Parabolic));
        assert!(types.contains(&PdeType::Hyperbolic));
        assert!(types.contains(&PdeType::Elliptic));
    }

    #[test]
    fn test_get_solver() {
        let registry = PDESolverRegistry::new();
        let solver = registry.get_solver(&PdeType::Parabolic);
        assert!(solver.is_some());
    }

    #[test]
    fn test_solver_priority() {
        let mut registry = PDESolverRegistry::new();
        // Bracket the built-in solver's priority so the list has more than one
        // entry and the sort order is actually exercised.
        registry.register(PdeType::Parabolic, Arc::new(AlwaysFailsSolver { priority: u8::MIN }));
        registry.register(PdeType::Parabolic, Arc::new(AlwaysFailsSolver { priority: u8::MAX }));

        let priorities = priorities_of(&registry, PdeType::Parabolic);
        assert_eq!(priorities.len(), 3);
        assert!(
            priorities.windows(2).all(|pair| pair[0] >= pair[1]),
            "Solvers should be sorted by priority: {:?}",
            priorities
        );
        assert_eq!(priorities.first(), Some(&u8::MAX));
        assert_eq!(priorities.last(), Some(&u8::MIN));
    }

    #[test]
    fn test_register_updates_solver_count() {
        let mut registry = PDESolverRegistry::new();
        let before = registry.solver_count();
        registry.register(PdeType::Elliptic, Arc::new(AlwaysFailsSolver { priority: 1 }));
        assert_eq!(registry.solver_count(), before + 1);
        // Adding a solver to an existing type must not duplicate the type.
        assert_eq!(registry.registered_types().len(), 3);
    }

    #[test]
    fn test_priority_order_covers_registered_types() {
        let registry = PDESolverRegistry::new();
        for pde_type in registry.registered_types() {
            assert!(
                registry.priority_order.contains(&pde_type),
                "{:?} is registered but never visited by try_all_solvers",
                pde_type
            );
        }
    }

    #[test]
    fn test_pde_error_variants() {
        let err1 = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert!(matches!(err1, PDEError::NoSolverAvailable { .. }));

        let err2 = PDEError::ClassificationFailed {
            reason: "test".to_string(),
        };
        assert!(matches!(err2, PDEError::ClassificationFailed { .. }));

        let err3 = PDEError::SolutionFailed {
            solver: "test".to_string(),
            reason: "test".to_string(),
        };
        assert!(matches!(err3, PDEError::SolutionFailed { .. }));
    }

    #[test]
    fn test_pde_error_clone() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert_eq!(err.clone(), err);
    }

    #[test]
    fn test_registry_default() {
        let registry = PDESolverRegistry::default();
        assert!(registry.solver_count() > 0);
    }
}