mathhook-core 0.2.0

Core mathematical engine for MathHook - expressions, algebra, and solving
Documentation
//! PDE Solver Registry
//!
//! Registry-based dispatch for PDE solvers following the architecture pattern
//! established in the ODE module (scored 9/10 for registry quality).
//!
//! This eliminates hardcoded match patterns and provides O(1) lookup for solvers.

use super::classification::classify_pde;
use super::types::{PDESolution, Pde, PdeType};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

/// Error type for PDE solving operations
#[derive(Debug, Clone, PartialEq)]
pub enum PDEError {
    /// No solver available for this PDE type
    NoSolverAvailable { pde_type: PdeType },
    /// Classification failed
    ClassificationFailed { reason: String },
    /// Solver failed to find solution
    SolutionFailed { solver: String, reason: String },
    /// Invalid boundary conditions
    InvalidBoundaryConditions { reason: String },
    /// Invalid initial conditions
    InvalidInitialConditions { reason: String },
    /// Not separable
    NotSeparable { reason: String },
    /// Invalid PDE form
    InvalidForm { reason: String },
}

impl fmt::Display for PDEError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PDEError::NoSolverAvailable { pde_type } => {
                write!(f, "No solver available for PDE type: {:?}", pde_type)
            }
            PDEError::ClassificationFailed { reason } => {
                write!(f, "PDE classification failed: {}", reason)
            }
            PDEError::SolutionFailed { solver, reason } => {
                write!(f, "Solver '{}' failed: {}", solver, reason)
            }
            PDEError::InvalidBoundaryConditions { reason } => {
                write!(f, "Invalid boundary conditions: {}", reason)
            }
            PDEError::InvalidInitialConditions { reason } => {
                write!(f, "Invalid initial conditions: {}", reason)
            }
            PDEError::NotSeparable { reason } => {
                write!(f, "PDE is not separable: {}", reason)
            }
            PDEError::InvalidForm { reason } => {
                write!(f, "Invalid PDE form: {}", reason)
            }
        }
    }
}

impl std::error::Error for PDEError {}

/// Result type for PDE operations
pub type PDEResult = Result<PDESolution, PDEError>;

/// Trait for PDE solvers that can be registered
pub trait PDESolver: Send + Sync {
    /// Attempts to solve the given PDE
    fn solve(&self, pde: &Pde) -> PDEResult;

    /// Returns true if this solver can handle the given PDE type
    fn can_solve(&self, pde_type: PdeType) -> bool;

    /// Priority for this solver (higher = try first)
    fn priority(&self) -> u8;

    /// Solver name for diagnostics
    fn name(&self) -> &'static str;

    /// Solver description
    fn description(&self) -> &'static str;
}

/// Registry for PDE solvers with O(1) lookup by type
pub struct PDESolverRegistry {
    /// Solvers organized by PDE type
    solvers: HashMap<PdeType, Vec<Arc<dyn PDESolver>>>,
    /// Priority order for trying solvers
    priority_order: Vec<PdeType>,
}

impl PDESolverRegistry {
    /// Creates a new registry with all standard solvers registered
    pub fn new() -> Self {
        let mut registry = Self {
            solvers: HashMap::new(),
            priority_order: Vec::new(),
        };
        registry.register_all_solvers();
        registry
    }

    /// Register all standard PDE solvers
    fn register_all_solvers(&mut self) {
        use super::standard::heat::HeatEquationSolver;
        use super::standard::laplace::LaplaceEquationSolver;
        use super::standard::wave::WaveEquationSolver;

        self.register(PdeType::Parabolic, Arc::new(HeatEquationSolver::new()));
        self.register(PdeType::Hyperbolic, Arc::new(WaveEquationSolver::new()));
        self.register(PdeType::Elliptic, Arc::new(LaplaceEquationSolver::new()));

        self.priority_order = vec![PdeType::Parabolic, PdeType::Hyperbolic, PdeType::Elliptic];
    }

    /// Register a solver for a specific PDE type
    pub fn register(&mut self, pde_type: PdeType, solver: Arc<dyn PDESolver>) {
        self.solvers.entry(pde_type).or_default().push(solver);

        if let Some(solvers) = self.solvers.get_mut(&pde_type) {
            solvers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
        }
    }

    /// Get solver for specific PDE type
    pub fn get_solver(&self, pde_type: &PdeType) -> Option<&Arc<dyn PDESolver>> {
        self.solvers
            .get(pde_type)
            .and_then(|solvers| solvers.first())
    }

    /// Try to solve PDE using registered solvers
    pub fn solve(&self, pde: &Pde) -> PDEResult {
        let pde_type =
            classify_pde(pde).map_err(|e| PDEError::ClassificationFailed { reason: e })?;

        if let Some(solvers) = self.solvers.get(&pde_type) {
            for solver in solvers {
                if solver.can_solve(pde_type) {
                    match solver.solve(pde) {
                        Ok(solution) => return Ok(solution),
                        Err(_) => continue,
                    }
                }
            }
        }

        Err(PDEError::NoSolverAvailable { pde_type })
    }

    /// Try all solvers in priority order
    pub fn try_all_solvers(&self, pde: &Pde) -> PDEResult {
        for pde_type in &self.priority_order {
            if let Some(solvers) = self.solvers.get(pde_type) {
                for solver in solvers {
                    match solver.solve(pde) {
                        Ok(solution) => return Ok(solution),
                        Err(_) => continue,
                    }
                }
            }
        }

        self.solve(pde)
    }

    /// Get all registered solver types
    pub fn registered_types(&self) -> Vec<PdeType> {
        self.solvers.keys().copied().collect()
    }

    /// Get solver count
    pub fn solver_count(&self) -> usize {
        self.solvers.values().map(|v| v.len()).sum()
    }
}

impl Default for PDESolverRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_creation() {
        let registry = PDESolverRegistry::new();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_registry_registered_types() {
        let registry = PDESolverRegistry::new();
        let types = registry.registered_types();
        assert!(!types.is_empty());
        assert!(types.contains(&PdeType::Parabolic));
        assert!(types.contains(&PdeType::Hyperbolic));
        assert!(types.contains(&PdeType::Elliptic));
    }

    #[test]
    fn test_get_solver() {
        let registry = PDESolverRegistry::new();
        let solver = registry.get_solver(&PdeType::Parabolic);
        assert!(solver.is_some());
    }

    #[test]
    fn test_solver_priority() {
        let registry = PDESolverRegistry::new();
        if let Some(solvers) = registry.solvers.get(&PdeType::Parabolic) {
            if solvers.len() > 1 {
                let priorities: Vec<_> = solvers.iter().map(|s| s.priority()).collect();
                let mut sorted = priorities.clone();
                sorted.sort_by(|a, b| b.cmp(a));
                assert_eq!(priorities, sorted, "Solvers should be sorted by priority");
            }
        }
    }

    #[test]
    fn test_pde_error_variants() {
        let err1 = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        assert!(matches!(err1, PDEError::NoSolverAvailable { .. }));

        let err2 = PDEError::ClassificationFailed {
            reason: "test".to_string(),
        };
        assert!(matches!(err2, PDEError::ClassificationFailed { .. }));

        let err3 = PDEError::SolutionFailed {
            solver: "test".to_string(),
            reason: "test".to_string(),
        };
        assert!(matches!(err3, PDEError::SolutionFailed { .. }));
    }

    #[test]
    fn test_pde_error_clone() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        let _cloned = err.clone();
    }

    #[test]
    fn test_registry_default() {
        let registry = PDESolverRegistry::default();
        assert!(registry.solver_count() > 0);
    }

    #[test]
    fn test_pde_error_display() {
        let err = PDEError::NoSolverAvailable {
            pde_type: PdeType::Parabolic,
        };
        let s = format!("{}", err);
        assert!(s.contains("No solver available"));
    }
}