/// Condition attached to one face of a problem domain.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions in other crates
/// must include a wildcard arm. `PartialEq` is derived so that conditions can
/// be compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum BoundaryCondition {
    /// Dirichlet condition carrying a single prescribed `value`.
    Dirichlet {
        /// Prescribed boundary value.
        value: f64,
    },
    /// Neumann condition carrying a single prescribed `flux`.
    Neumann {
        /// Prescribed boundary flux.
        flux: f64,
    },
    /// Robin (mixed) condition described by two coefficients and a value.
    ///
    /// NOTE(review): presumably encodes `alpha * u + beta * du/dn = value`;
    /// the defining equation is not visible in this file, so confirm against
    /// the code that consumes this variant.
    Robin {
        /// First coefficient of the condition.
        alpha: f64,
        /// Second coefficient of the condition.
        beta: f64,
        /// Right-hand-side value of the condition.
        value: f64,
    },
    /// Periodic condition; carries no parameters.
    Periodic,
}
/// Selects how collocation points are chosen.
///
/// The sampling code itself is not part of this file; the variants only name
/// the strategy. The enum is `#[non_exhaustive]`, so `match` expressions in
/// other crates must include a wildcard arm.
///
/// `Copy`, `PartialEq` and `Eq` are derived to match the other fieldless enum
/// in this file, [`BoundarySide`], so a strategy can be passed by value and
/// compared with `==`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum CollocationStrategy {
    /// Random selection; this is the value returned by `Default::default()`.
    #[default]
    Random,
    /// Latin-hypercube selection.
    LatinHypercube,
    /// Selection on a uniform grid.
    UniformGrid,
    /// Residual-driven adaptive selection.
    AdaptiveResidual,
}
/// Training configuration for a PINN run.
///
/// All fields are public; `PINNConfig::default()` supplies a baseline that can
/// be overridden with struct-update syntax
/// (`PINNConfig { max_epochs: 500, ..Default::default() }`).
#[derive(Clone, Debug)]
pub struct PINNConfig {
    /// Hidden-layer sizes.
    /// NOTE(review): presumably one width per hidden layer — confirm against
    /// the network builder.
    pub hidden_layers: Vec<usize>,
    /// Learning rate handed to the optimiser.
    pub learning_rate: f64,
    /// Upper bound on the number of training epochs.
    pub max_epochs: usize,
    /// Number of collocation points.
    pub n_collocation: usize,
    /// Number of boundary points.
    pub n_boundary: usize,
    /// Weight associated with the physics loss term.
    pub physics_weight: f64,
    /// Weight associated with the boundary loss term.
    pub boundary_weight: f64,
    /// Weight associated with the data loss term.
    pub data_weight: f64,
    /// Strategy used to choose collocation points.
    pub collocation: CollocationStrategy,
    /// Convergence tolerance.
    /// NOTE(review): what quantity is compared against this tolerance is not
    /// visible in this file — confirm in the training loop.
    pub convergence_tol: f64,
}
impl Default for PINNConfig {
fn default() -> Self {
Self {
hidden_layers: vec![64, 64, 64],
learning_rate: 1e-3,
max_epochs: 10000,
n_collocation: 1000,
n_boundary: 100,
physics_weight: 1.0,
boundary_weight: 10.0,
data_weight: 1.0,
collocation: CollocationStrategy::default(),
convergence_tol: 1e-6,
}
}
}
/// Summary of a finished training run.
///
/// `PartialEq` is derived so that two results can be compared directly.
/// Comparison of the `f64` fields follows IEEE-754 rules, so a result that
/// contains `NaN` is not equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct PINNResult {
    /// Loss value at the end of training.
    /// NOTE(review): presumably the weighted combination of the three
    /// component losses below — confirm in the training loop.
    pub final_loss: f64,
    /// Physics component of the loss.
    pub physics_loss: f64,
    /// Boundary component of the loss.
    pub boundary_loss: f64,
    /// Data component of the loss.
    pub data_loss: f64,
    /// Number of epochs that were run.
    pub epochs_trained: usize,
    /// Whether the run was reported as converged.
    pub converged: bool,
    /// Recorded loss values.
    /// NOTE(review): presumably one entry per epoch — confirm in the training
    /// loop.
    pub loss_history: Vec<f64>,
}
/// Description of a PDE problem: its domain, its boundary conditions and an
/// optional time interval.
///
/// NOTE(review): `has_time` and `time_domain` are independent public fields;
/// nothing in this type keeps them in agreement, so callers are responsible
/// for setting them consistently.
#[derive(Clone, Debug)]
pub struct PDEProblem {
    /// Number of spatial dimensions.
    pub spatial_dim: usize,
    /// Pairs of bounds describing the domain.
    /// NOTE(review): presumably one `(min, max)` pair per spatial dimension,
    /// so `domain.len() == spatial_dim` — not enforced here; confirm with
    /// callers.
    pub domain: Vec<(f64, f64)>,
    /// Boundary conditions attached to the domain.
    pub boundaries: Vec<Boundary>,
    /// Flag marking the problem as time-dependent.
    pub has_time: bool,
    /// Time interval as `(t_min, t_max)`, or `None` when absent.
    pub time_domain: Option<(f64, f64)>,
}
/// A boundary condition bound to one side of one dimension.
#[derive(Clone, Debug)]
pub struct Boundary {
    /// Index of the dimension this boundary belongs to.
    /// NOTE(review): presumably an index into `PDEProblem::domain` — not
    /// validated here; confirm with callers.
    pub dim: usize,
    /// Which side of that dimension the condition applies to.
    pub side: BoundarySide,
    /// The condition imposed on that side.
    pub condition: BoundaryCondition,
}
/// Identifies one of the two sides of a dimension.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions in other crates
/// must include a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum BoundarySide {
    /// The low side.
    /// NOTE(review): presumably the lower bound of the dimension's interval —
    /// confirm with the code that consumes this value.
    Low,
    /// The high side.
    /// NOTE(review): presumably the upper bound of the dimension's interval —
    /// confirm with the code that consumes this value.
    High,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floats that are expected to equal a literal exactly.
    const TOL: f64 = 1e-15;

    /// Returns `true` when `actual` lies within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Every field of the default configuration has its documented value.
    #[test]
    fn test_default_config() {
        let config = PINNConfig::default();
        assert_eq!(config.hidden_layers, vec![64, 64, 64]);
        assert!(close(config.learning_rate, 1e-3));
        assert_eq!(config.max_epochs, 10000);
        assert_eq!(config.n_collocation, 1000);
        assert_eq!(config.n_boundary, 100);
        assert!(close(config.physics_weight, 1.0));
        assert!(close(config.boundary_weight, 10.0));
        assert!(close(config.data_weight, 1.0));
        // Previously unchecked: the strategy default is part of the config too.
        assert!(matches!(config.collocation, CollocationStrategy::Random));
        assert!(close(config.convergence_tol, 1e-6));
    }

    /// Each variant keeps the payload it was built with and has a `Debug`
    /// rendering that starts with its own name.
    #[test]
    fn test_boundary_condition_variants() {
        let dirichlet = BoundaryCondition::Dirichlet { value: 1.0 };
        let neumann = BoundaryCondition::Neumann { flux: 0.5 };
        let robin = BoundaryCondition::Robin {
            alpha: 1.0,
            beta: 2.0,
            value: 3.0,
        };
        let periodic = BoundaryCondition::Periodic;

        assert!(matches!(
            dirichlet,
            BoundaryCondition::Dirichlet { value } if close(value, 1.0)
        ));
        assert!(matches!(
            neumann,
            BoundaryCondition::Neumann { flux } if close(flux, 0.5)
        ));
        assert!(matches!(
            robin,
            BoundaryCondition::Robin { alpha, beta, value }
                if close(alpha, 1.0) && close(beta, 2.0) && close(value, 3.0)
        ));
        assert!(matches!(periodic, BoundaryCondition::Periodic));

        // Assert on the Debug output instead of formatting and discarding it.
        assert!(format!("{:?}", dirichlet).starts_with("Dirichlet"));
        assert!(format!("{:?}", neumann).starts_with("Neumann"));
        assert!(format!("{:?}", robin).starts_with("Robin"));
        assert_eq!(format!("{:?}", periodic), "Periodic");
    }

    /// A steady-state 2-D problem keeps the values it was constructed with.
    #[test]
    fn test_pde_problem_construction() {
        let problem = PDEProblem {
            spatial_dim: 2,
            domain: vec![(0.0, 1.0), (0.0, 1.0)],
            boundaries: vec![
                Boundary {
                    dim: 0,
                    side: BoundarySide::Low,
                    condition: BoundaryCondition::Dirichlet { value: 0.0 },
                },
                Boundary {
                    dim: 0,
                    side: BoundarySide::High,
                    condition: BoundaryCondition::Dirichlet { value: 1.0 },
                },
            ],
            has_time: false,
            time_domain: None,
        };
        assert_eq!(problem.spatial_dim, 2);
        assert_eq!(problem.domain.len(), 2);
        assert_eq!(problem.boundaries.len(), 2);
        assert_eq!(problem.boundaries[0].side, BoundarySide::Low);
        assert_eq!(problem.boundaries[1].side, BoundarySide::High);
        assert!(!problem.has_time);
        assert!(problem.time_domain.is_none());
    }

    /// `BoundarySide` equality distinguishes the two variants.
    #[test]
    fn test_boundary_side_equality() {
        assert_eq!(BoundarySide::Low, BoundarySide::Low);
        assert_eq!(BoundarySide::High, BoundarySide::High);
        assert_ne!(BoundarySide::Low, BoundarySide::High);
    }

    /// A hand-built result exposes the values it was constructed with.
    #[test]
    fn test_pinn_result_fields() {
        let result = PINNResult {
            final_loss: 0.001,
            physics_loss: 0.0005,
            boundary_loss: 0.0003,
            data_loss: 0.0002,
            epochs_trained: 5000,
            converged: true,
            loss_history: vec![1.0, 0.5, 0.1, 0.01, 0.001],
        };
        assert!(result.converged);
        assert_eq!(result.epochs_trained, 5000);
        assert_eq!(result.loss_history.len(), 5);
        assert!(result.final_loss < 0.01);
    }

    /// The default collocation strategy is `Random`.
    #[test]
    fn test_collocation_strategy_default() {
        let strategy = CollocationStrategy::default();
        assert!(matches!(strategy, CollocationStrategy::Random));
    }

    /// A time-dependent problem carries its time interval.
    #[test]
    fn test_time_dependent_problem() {
        let problem = PDEProblem {
            spatial_dim: 1,
            domain: vec![(0.0, 1.0)],
            boundaries: vec![],
            has_time: true,
            time_domain: Some((0.0, 1.0)),
        };
        assert!(problem.has_time);
        // `expect` fails loudly on `None` instead of substituting a fallback
        // interval that would only fail a later assertion indirectly.
        let (t_min, t_max) = problem
            .time_domain
            .expect("time_domain was constructed as Some");
        assert!(close(t_min, 0.0));
        assert!(close(t_max, 1.0));
    }
}