use super::classification::classify_pde;
use super::types::{PDESolution, Pde, PdeType};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
/// Errors reported by PDE solvers and by the solver registry.
///
/// Each variant maps to one message in the `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum PDEError {
/// No solver is available for a PDE of type `pde_type`.
NoSolverAvailable { pde_type: PdeType },
/// The PDE could not be classified; `reason` carries the classifier's message.
ClassificationFailed { reason: String },
/// The solver called `solver` failed; `reason` says why.
SolutionFailed { solver: String, reason: String },
/// The boundary conditions are invalid; `reason` says why.
InvalidBoundaryConditions { reason: String },
/// The initial conditions are invalid; `reason` says why.
InvalidInitialConditions { reason: String },
/// The PDE is not separable; `reason` says why.
NotSeparable { reason: String },
/// The PDE has an invalid form; `reason` says why.
InvalidForm { reason: String },
}
impl fmt::Display for PDEError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PDEError::NoSolverAvailable { pde_type } => {
write!(f, "No solver available for PDE type: {:?}", pde_type)
}
PDEError::ClassificationFailed { reason } => {
write!(f, "PDE classification failed: {}", reason)
}
PDEError::SolutionFailed { solver, reason } => {
write!(f, "Solver '{}' failed: {}", solver, reason)
}
PDEError::InvalidBoundaryConditions { reason } => {
write!(f, "Invalid boundary conditions: {}", reason)
}
PDEError::InvalidInitialConditions { reason } => {
write!(f, "Invalid initial conditions: {}", reason)
}
PDEError::NotSeparable { reason } => {
write!(f, "PDE is not separable: {}", reason)
}
PDEError::InvalidForm { reason } => {
write!(f, "Invalid PDE form: {}", reason)
}
}
}
}
// Marker impl: relies on the trait's default methods, so `PDEError` can be
// used wherever a `std::error::Error` is expected (e.g. `Box<dyn Error>`).
impl std::error::Error for PDEError {}
/// Result type returned by PDE solvers and by the registry: a `PDESolution`
/// on success, a `PDEError` on failure.
pub type PDEResult = Result<PDESolution, PDEError>;
/// Interface implemented by every solver stored in a `PDESolverRegistry`.
///
/// Implementors must be `Send + Sync`; the registry holds them as
/// `Arc<dyn PDESolver>`.
pub trait PDESolver: Send + Sync {
/// Attempts to solve `pde`, returning the solution or a `PDEError`.
fn solve(&self, pde: &Pde) -> PDEResult;
/// Reports whether this solver handles PDEs of `pde_type`.
/// `PDESolverRegistry::solve` skips solvers that return `false`.
fn can_solve(&self, pde_type: PdeType) -> bool;
/// Ranking used by `PDESolverRegistry::register`: solvers with a higher
/// value are placed earlier in the list for their PDE type.
fn priority(&self) -> u8;
/// Static name of the solver.
fn name(&self) -> &'static str;
/// Static description of the solver.
fn description(&self) -> &'static str;
}
/// Registry mapping each PDE type to the solvers registered for it.
pub struct PDESolverRegistry {
// Solvers grouped by PDE type; `register` keeps each list sorted by
// descending `priority()`.
solvers: HashMap<PdeType, Vec<Arc<dyn PDESolver>>>,
// Order in which `try_all_solvers` walks the PDE types.
priority_order: Vec<PdeType>,
}
impl PDESolverRegistry {
pub fn new() -> Self {
let mut registry = Self {
solvers: HashMap::new(),
priority_order: Vec::new(),
};
registry.register_all_solvers();
registry
}
fn register_all_solvers(&mut self) {
use super::standard::heat::HeatEquationSolver;
use super::standard::laplace::LaplaceEquationSolver;
use super::standard::wave::WaveEquationSolver;
self.register(PdeType::Parabolic, Arc::new(HeatEquationSolver::new()));
self.register(PdeType::Hyperbolic, Arc::new(WaveEquationSolver::new()));
self.register(PdeType::Elliptic, Arc::new(LaplaceEquationSolver::new()));
self.priority_order = vec![PdeType::Parabolic, PdeType::Hyperbolic, PdeType::Elliptic];
}
pub fn register(&mut self, pde_type: PdeType, solver: Arc<dyn PDESolver>) {
self.solvers.entry(pde_type).or_default().push(solver);
if let Some(solvers) = self.solvers.get_mut(&pde_type) {
solvers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
}
}
pub fn get_solver(&self, pde_type: &PdeType) -> Option<&Arc<dyn PDESolver>> {
self.solvers
.get(pde_type)
.and_then(|solvers| solvers.first())
}
pub fn solve(&self, pde: &Pde) -> PDEResult {
let pde_type =
classify_pde(pde).map_err(|e| PDEError::ClassificationFailed { reason: e })?;
if let Some(solvers) = self.solvers.get(&pde_type) {
for solver in solvers {
if solver.can_solve(pde_type) {
match solver.solve(pde) {
Ok(solution) => return Ok(solution),
Err(_) => continue,
}
}
}
}
Err(PDEError::NoSolverAvailable { pde_type })
}
pub fn try_all_solvers(&self, pde: &Pde) -> PDEResult {
for pde_type in &self.priority_order {
if let Some(solvers) = self.solvers.get(pde_type) {
for solver in solvers {
match solver.solve(pde) {
Ok(solution) => return Ok(solution),
Err(_) => continue,
}
}
}
}
self.solve(pde)
}
pub fn registered_types(&self) -> Vec<PdeType> {
self.solvers.keys().copied().collect()
}
pub fn solver_count(&self) -> usize {
self.solvers.values().map(|v| v.len()).sum()
}
}
impl Default for PDESolverRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal solver with a configurable priority and name.
    ///
    /// It never produces a solution, so tests can exercise registry
    /// bookkeeping without constructing a real PDE.
    struct StubSolver {
        priority: u8,
        name: &'static str,
    }

    impl PDESolver for StubSolver {
        fn solve(&self, _pde: &Pde) -> PDEResult {
            Err(PDEError::SolutionFailed {
                solver: self.name.to_string(),
                reason: "stub solver never succeeds".to_string(),
            })
        }

        fn can_solve(&self, _pde_type: PdeType) -> bool {
            true
        }

        fn priority(&self) -> u8 {
            self.priority
        }

        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            "test stub"
        }
    }

    /// Registry without the built-in solvers, so priorities are fully
    /// controlled by the test.
    fn empty_registry() -> PDESolverRegistry {
        PDESolverRegistry {
            solvers: HashMap::new(),
            priority_order: Vec::new(),
        }
    }

    /// Names of the solvers stored for `pde_type`, in stored order.
    fn solver_names(registry: &PDESolverRegistry, pde_type: PdeType) -> Vec<&'static str> {
        registry.solvers[&pde_type]
            .iter()
            .map(|s| s.name())
            .collect()
    }

    #[test]
    fn test_registry_creation() {
        let registry = PDESolverRegistry::new();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_registry_registered_types() {
        let types = PDESolverRegistry::new().registered_types();
        assert!(!types.is_empty());
        for expected in [PdeType::Parabolic, PdeType::Hyperbolic, PdeType::Elliptic] {
            assert!(types.contains(&expected));
        }
    }

    #[test]
    fn test_get_solver() {
        let registry = PDESolverRegistry::new();
        assert!(registry.get_solver(&PdeType::Parabolic).is_some());

        // Nothing registered means nothing to return.
        assert!(empty_registry().get_solver(&PdeType::Parabolic).is_none());
    }

    #[test]
    fn test_solver_priority() {
        // The default registry must keep every list sorted by descending priority.
        let registry = PDESolverRegistry::new();
        for solvers in registry.solvers.values() {
            assert!(
                solvers
                    .windows(2)
                    .all(|pair| pair[0].priority() >= pair[1].priority()),
                "Solvers should be sorted by priority"
            );
        }

        // The default registry holds one solver per type, so exercise the
        // ordering explicitly with several stub solvers.
        let mut registry = empty_registry();
        for (priority, name) in [(1, "low"), (9, "high"), (5, "mid")] {
            registry.register(PdeType::Parabolic, Arc::new(StubSolver { priority, name }));
        }
        assert_eq!(
            solver_names(&registry, PdeType::Parabolic),
            ["high", "mid", "low"]
        );
        assert_eq!(
            registry.get_solver(&PdeType::Parabolic).map(|s| s.name()),
            Some("high")
        );
        assert_eq!(registry.solver_count(), 3);
    }

    #[test]
    fn test_register_keeps_insertion_order_on_ties() {
        let mut registry = empty_registry();
        for name in ["first", "second"] {
            registry.register(
                PdeType::Elliptic,
                Arc::new(StubSolver { priority: 5, name }),
            );
        }
        assert_eq!(
            solver_names(&registry, PdeType::Elliptic),
            ["first", "second"]
        );
    }

    #[test]
    fn test_pde_error_variants() {
        let err1 = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert!(matches!(err1, PDEError::NoSolverAvailable { .. }));

        let err2 = PDEError::ClassificationFailed {
            reason: "test".to_string(),
        };
        assert!(matches!(err2, PDEError::ClassificationFailed { .. }));

        let err3 = PDEError::SolutionFailed {
            solver: "test".to_string(),
            reason: "test".to_string(),
        };
        assert!(matches!(err3, PDEError::SolutionFailed { .. }));
    }

    #[test]
    fn test_pde_error_clone() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        // A clone must compare equal to its source.
        assert_eq!(err.clone(), err);
    }

    #[test]
    fn test_registry_default() {
        let registry = PDESolverRegistry::default();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_pde_error_display() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert!(err.to_string().contains("No solver available"));

        let reason = || "why".to_string();
        assert_eq!(
            PDEError::ClassificationFailed { reason: reason() }.to_string(),
            "PDE classification failed: why"
        );
        assert_eq!(
            PDEError::SolutionFailed {
                solver: "heat".to_string(),
                reason: reason(),
            }
            .to_string(),
            "Solver 'heat' failed: why"
        );
        assert_eq!(
            PDEError::InvalidBoundaryConditions { reason: reason() }.to_string(),
            "Invalid boundary conditions: why"
        );
        assert_eq!(
            PDEError::InvalidInitialConditions { reason: reason() }.to_string(),
            "Invalid initial conditions: why"
        );
        assert_eq!(
            PDEError::NotSeparable { reason: reason() }.to_string(),
            "PDE is not separable: why"
        );
        assert_eq!(
            PDEError::InvalidForm { reason: reason() }.to_string(),
            "Invalid PDE form: why"
        );
    }
}