mathhook_core/calculus/pde/
registry.rs1use super::classification::classify_pde;
9use super::types::{PDESolution, Pde, PdeType};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::Arc;
13
14#[derive(Debug, Clone, PartialEq)]
16pub enum PDEError {
17 NoSolverAvailable { pde_type: PdeType },
19 ClassificationFailed { reason: String },
21 SolutionFailed { solver: String, reason: String },
23 InvalidBoundaryConditions { reason: String },
25 InvalidInitialConditions { reason: String },
27 NotSeparable { reason: String },
29 InvalidForm { reason: String },
31}
32
33impl fmt::Display for PDEError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 PDEError::NoSolverAvailable { pde_type } => {
37 write!(f, "No solver available for PDE type: {:?}", pde_type)
38 }
39 PDEError::ClassificationFailed { reason } => {
40 write!(f, "PDE classification failed: {}", reason)
41 }
42 PDEError::SolutionFailed { solver, reason } => {
43 write!(f, "Solver '{}' failed: {}", solver, reason)
44 }
45 PDEError::InvalidBoundaryConditions { reason } => {
46 write!(f, "Invalid boundary conditions: {}", reason)
47 }
48 PDEError::InvalidInitialConditions { reason } => {
49 write!(f, "Invalid initial conditions: {}", reason)
50 }
51 PDEError::NotSeparable { reason } => {
52 write!(f, "PDE is not separable: {}", reason)
53 }
54 PDEError::InvalidForm { reason } => {
55 write!(f, "Invalid PDE form: {}", reason)
56 }
57 }
58 }
59}
60
61impl std::error::Error for PDEError {}
62
63pub type PDEResult = Result<PDESolution, PDEError>;
65
66pub trait PDESolver: Send + Sync {
68 fn solve(&self, pde: &Pde) -> PDEResult;
70
71 fn can_solve(&self, pde_type: PdeType) -> bool;
73
74 fn priority(&self) -> u8;
76
77 fn name(&self) -> &'static str;
79
80 fn description(&self) -> &'static str;
82}
83
84pub struct PDESolverRegistry {
86 solvers: HashMap<PdeType, Vec<Arc<dyn PDESolver>>>,
88 priority_order: Vec<PdeType>,
90}
91
92impl PDESolverRegistry {
93 pub fn new() -> Self {
95 let mut registry = Self {
96 solvers: HashMap::new(),
97 priority_order: Vec::new(),
98 };
99 registry.register_all_solvers();
100 registry
101 }
102
103 fn register_all_solvers(&mut self) {
105 use super::standard::heat::HeatEquationSolver;
106 use super::standard::laplace::LaplaceEquationSolver;
107 use super::standard::wave::WaveEquationSolver;
108
109 self.register(PdeType::Parabolic, Arc::new(HeatEquationSolver::new()));
110 self.register(PdeType::Hyperbolic, Arc::new(WaveEquationSolver::new()));
111 self.register(PdeType::Elliptic, Arc::new(LaplaceEquationSolver::new()));
112
113 self.priority_order = vec![PdeType::Parabolic, PdeType::Hyperbolic, PdeType::Elliptic];
114 }
115
116 pub fn register(&mut self, pde_type: PdeType, solver: Arc<dyn PDESolver>) {
118 self.solvers.entry(pde_type).or_default().push(solver);
119
120 if let Some(solvers) = self.solvers.get_mut(&pde_type) {
121 solvers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
122 }
123 }
124
125 pub fn get_solver(&self, pde_type: &PdeType) -> Option<&Arc<dyn PDESolver>> {
127 self.solvers
128 .get(pde_type)
129 .and_then(|solvers| solvers.first())
130 }
131
132 pub fn solve(&self, pde: &Pde) -> PDEResult {
134 let pde_type =
135 classify_pde(pde).map_err(|e| PDEError::ClassificationFailed { reason: e })?;
136
137 if let Some(solvers) = self.solvers.get(&pde_type) {
138 for solver in solvers {
139 if solver.can_solve(pde_type) {
140 match solver.solve(pde) {
141 Ok(solution) => return Ok(solution),
142 Err(_) => continue,
143 }
144 }
145 }
146 }
147
148 Err(PDEError::NoSolverAvailable { pde_type })
149 }
150
151 pub fn try_all_solvers(&self, pde: &Pde) -> PDEResult {
153 for pde_type in &self.priority_order {
154 if let Some(solvers) = self.solvers.get(pde_type) {
155 for solver in solvers {
156 match solver.solve(pde) {
157 Ok(solution) => return Ok(solution),
158 Err(_) => continue,
159 }
160 }
161 }
162 }
163
164 self.solve(pde)
165 }
166
167 pub fn registered_types(&self) -> Vec<PdeType> {
169 self.solvers.keys().copied().collect()
170 }
171
172 pub fn solver_count(&self) -> usize {
174 self.solvers.values().map(|v| v.len()).sum()
175 }
176}
177
178impl Default for PDESolverRegistry {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only solver with a configurable priority. It never succeeds, so it
    /// can be registered purely to observe ordering.
    struct StubSolver {
        priority: u8,
    }

    impl PDESolver for StubSolver {
        fn solve(&self, _pde: &Pde) -> PDEResult {
            Err(PDEError::InvalidForm {
                reason: "stub solver never solves".to_string(),
            })
        }

        fn can_solve(&self, _pde_type: PdeType) -> bool {
            true
        }

        fn priority(&self) -> u8 {
            self.priority
        }

        fn name(&self) -> &'static str {
            "stub"
        }

        fn description(&self) -> &'static str {
            "test-only solver that always fails"
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = PDESolverRegistry::new();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_registry_registered_types() {
        let registry = PDESolverRegistry::new();
        let types = registry.registered_types();
        assert!(!types.is_empty());
        assert!(types.contains(&PdeType::Parabolic));
        assert!(types.contains(&PdeType::Hyperbolic));
        assert!(types.contains(&PdeType::Elliptic));
    }

    #[test]
    fn test_get_solver() {
        let registry = PDESolverRegistry::new();
        let solver = registry.get_solver(&PdeType::Parabolic);
        assert!(solver.is_some());
    }

    #[test]
    fn test_solver_priority() {
        let mut registry = PDESolverRegistry::new();
        let before = registry
            .solvers
            .get(&PdeType::Parabolic)
            .map_or(0, Vec::len);
        let count_before = registry.solver_count();

        // Bracket the built-in solver's priority so the list holds several
        // entries and the ordering check is not vacuous.
        registry.register(PdeType::Parabolic, Arc::new(StubSolver { priority: u8::MIN }));
        registry.register(PdeType::Parabolic, Arc::new(StubSolver { priority: u8::MAX }));

        assert_eq!(registry.solver_count(), count_before + 2);
        let solvers = registry
            .solvers
            .get(&PdeType::Parabolic)
            .expect("parabolic solvers are registered by default");
        assert_eq!(solvers.len(), before + 2);

        let priorities: Vec<u8> = solvers.iter().map(|s| s.priority()).collect();
        let mut sorted = priorities.clone();
        sorted.sort_by(|a, b| b.cmp(a));
        assert_eq!(priorities, sorted, "Solvers should be sorted by priority");

        // The first solver must carry the maximum priority.
        assert_eq!(
            registry
                .get_solver(&PdeType::Parabolic)
                .map(|s| s.priority()),
            Some(u8::MAX)
        );
    }

    #[test]
    fn test_pde_error_variants() {
        let err1 = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert!(matches!(err1, PDEError::NoSolverAvailable { .. }));

        let err2 = PDEError::ClassificationFailed {
            reason: "test".to_string(),
        };
        assert!(matches!(err2, PDEError::ClassificationFailed { .. }));

        let err3 = PDEError::SolutionFailed {
            solver: "test".to_string(),
            reason: "test".to_string(),
        };
        assert!(matches!(err3, PDEError::SolutionFailed { .. }));
    }

    #[test]
    fn test_pde_error_clone() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        let cloned = err.clone();
        assert_eq!(cloned, err);
    }

    #[test]
    fn test_registry_default() {
        let registry = PDESolverRegistry::default();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_pde_error_display() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        let s = format!("{}", err);
        assert!(s.contains("No solver available"));

        let reason = || "why".to_string();
        let cases = [
            (
                PDEError::ClassificationFailed { reason: reason() },
                "PDE classification failed: why",
            ),
            (
                PDEError::SolutionFailed {
                    solver: "heat".to_string(),
                    reason: reason(),
                },
                "Solver 'heat' failed: why",
            ),
            (
                PDEError::InvalidBoundaryConditions { reason: reason() },
                "Invalid boundary conditions: why",
            ),
            (
                PDEError::InvalidInitialConditions { reason: reason() },
                "Invalid initial conditions: why",
            ),
            (
                PDEError::NotSeparable { reason: reason() },
                "PDE is not separable: why",
            ),
            (
                PDEError::InvalidForm { reason: reason() },
                "Invalid PDE form: why",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_pde_error_is_std_error() {
        let err = PDEError::InvalidForm {
            reason: "x".to_string(),
        };
        let as_dyn: &dyn std::error::Error = &err;
        assert!(as_dyn.source().is_none());
    }
}