use super::network::PINNNetwork;
use super::types::{
Boundary, BoundaryCondition, BoundarySide, CollocationStrategy, PDEProblem, PINNConfig,
PINNResult,
};
use crate::error::IntegrateResult;
use scirs2_core::ndarray::{Array1, Array2};
/// Physics-informed neural network (PINN) solver: wraps a trainable network
/// together with the training hyperparameters used to fit it to a PDE.
pub struct PINNSolver {
    // The neural network approximating the PDE solution u(x[, t]).
    network: PINNNetwork,
    // Training configuration: loss weights, epoch budget, point counts, etc.
    config: PINNConfig,
}
/// Minimal Adam optimizer state (Kingma & Ba) over a flat parameter vector.
struct AdamOptimizer {
    // Learning rate (step size).
    lr: f64,
    // Exponential decay rate for the first-moment (mean) estimate.
    beta1: f64,
    // Exponential decay rate for the second-moment (uncentered variance) estimate.
    beta2: f64,
    // Small constant added to the denominator for numerical stability.
    epsilon: f64,
    // First-moment running estimate, one entry per parameter.
    m: Array1<f64>,
    // Second-moment running estimate, one entry per parameter.
    v: Array1<f64>,
    // Step counter used for bias correction (starts at 0, incremented per step).
    t: usize,
}
impl AdamOptimizer {
    /// Fresh optimizer state for `n_params` parameters with standard Adam
    /// defaults (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
    fn new(n_params: usize, lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            m: Array1::<f64>::zeros(n_params),
            v: Array1::<f64>::zeros(n_params),
            t: 0,
        }
    }

    /// One Adam update: returns the new parameter vector, leaving `params`
    /// untouched. Updates the internal moment estimates and step counter.
    fn step(&mut self, params: &Array1<f64>, grad: &Array1<f64>) -> Array1<f64> {
        self.t += 1;
        let step_count = self.t as f64;
        // Exponential moving averages of the gradient and its square.
        self.m = &self.m * self.beta1 + grad * (1.0 - self.beta1);
        self.v = &self.v * self.beta2 + &(grad * grad) * (1.0 - self.beta2);
        // Bias-corrected moment estimates (correct the zero-initialization bias).
        let m_hat = &self.m / (1.0 - self.beta1.powf(step_count));
        let v_hat = &self.v / (1.0 - self.beta2.powf(step_count));
        // Per-parameter adaptive step: lr * m_hat / (sqrt(v_hat) + eps).
        let denom = v_hat.mapv(f64::sqrt) + self.epsilon;
        let update = &m_hat / &denom * self.lr;
        params - &update
    }
}
/// Advance a xorshift64 PRNG state in place and map the new state into [0, 1].
///
/// NOTE(review): a zero state is a fixed point (always yields 0.0); callers in
/// this file seed with `seed | 1` so the state is never zero.
fn xorshift64(state: &mut u64) -> f64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x as f64 / u64::MAX as f64
}
impl PINNSolver {
pub fn new(problem: &PDEProblem, config: PINNConfig) -> IntegrateResult<Self> {
let input_dim = if problem.has_time {
problem.spatial_dim + 1
} else {
problem.spatial_dim
};
let network = PINNNetwork::new(input_dim, &config.hidden_layers, 1)?;
Ok(Self { network, config })
}
pub fn train<F>(
&mut self,
pde_residual: &F,
problem: &PDEProblem,
data_points: Option<(&Array2<f64>, &Array1<f64>)>,
) -> IntegrateResult<PINNResult>
where
F: Fn(&PINNNetwork, &Array1<f64>) -> IntegrateResult<f64>,
{
let mut rng_state: u64 = 12345_u64 | 1;
let collocation_pts = generate_collocation_points(
problem,
self.config.n_collocation,
&self.config.collocation,
&mut rng_state,
);
let boundary_data =
generate_boundary_points(problem, self.config.n_boundary, &mut rng_state);
let n_params = self.network.n_parameters();
let mut optimizer = AdamOptimizer::new(n_params, self.config.learning_rate);
let mut loss_history = Vec::with_capacity(self.config.max_epochs);
let mut physics_loss = 0.0;
let mut boundary_loss = 0.0;
let mut data_loss = 0.0;
let mut converged = false;
let mut epochs_trained = 0;
let fd_step = 1e-5;
for epoch in 0..self.config.max_epochs {
physics_loss = 0.0;
let n_coll = collocation_pts.nrows();
for i in 0..n_coll {
let pt = collocation_pts.row(i).to_owned();
let residual = pde_residual(&self.network, &pt)?;
physics_loss += residual * residual;
}
if n_coll > 0 {
physics_loss /= n_coll as f64;
}
boundary_loss = 0.0;
let mut n_bc_total = 0;
for (bc_points, bc_condition) in &boundary_data {
let n_bc = bc_points.nrows();
for j in 0..n_bc {
let pt = bc_points.row(j).to_owned();
let bc_err = compute_bc_error(&self.network, &pt, bc_condition, fd_step)?;
boundary_loss += bc_err * bc_err;
n_bc_total += 1;
}
}
if n_bc_total > 0 {
boundary_loss /= n_bc_total as f64;
}
data_loss = 0.0;
if let Some((x_data, y_data)) = data_points {
let n_data = x_data.nrows();
for i in 0..n_data {
let pt = x_data.row(i).to_owned();
let predicted = self.network.forward(&pt)?;
let err = predicted - y_data[i];
data_loss += err * err;
}
if n_data > 0 {
data_loss /= n_data as f64;
}
}
let total_loss = self.config.physics_weight * physics_loss
+ self.config.boundary_weight * boundary_loss
+ self.config.data_weight * data_loss;
loss_history.push(total_loss);
epochs_trained = epoch + 1;
if total_loss < self.config.convergence_tol {
converged = true;
break;
}
let current_params = self.network.parameters();
let mut grad = Array1::<f64>::zeros(n_params);
for p in 0..n_params {
let mut params_plus = current_params.clone();
params_plus[p] += fd_step;
self.network.set_parameters(¶ms_plus)?;
let loss_plus = self.compute_total_loss(
pde_residual,
&collocation_pts,
&boundary_data,
data_points,
fd_step,
)?;
grad[p] = (loss_plus - total_loss) / fd_step;
}
let new_params = optimizer.step(¤t_params, &grad);
self.network.set_parameters(&new_params)?;
}
Ok(PINNResult {
final_loss: loss_history.last().copied().unwrap_or(f64::INFINITY),
physics_loss,
boundary_loss,
data_loss,
epochs_trained,
converged,
loss_history,
})
}
fn compute_total_loss<F>(
&self,
pde_residual: &F,
collocation_pts: &Array2<f64>,
boundary_data: &[(Array2<f64>, BoundaryCondition)],
data_points: Option<(&Array2<f64>, &Array1<f64>)>,
fd_step: f64,
) -> IntegrateResult<f64>
where
F: Fn(&PINNNetwork, &Array1<f64>) -> IntegrateResult<f64>,
{
let n_coll = collocation_pts.nrows();
let mut physics_loss = 0.0;
for i in 0..n_coll {
let pt = collocation_pts.row(i).to_owned();
let residual = pde_residual(&self.network, &pt)?;
physics_loss += residual * residual;
}
if n_coll > 0 {
physics_loss /= n_coll as f64;
}
let mut boundary_loss = 0.0;
let mut n_bc_total = 0;
for (bc_points, bc_condition) in boundary_data {
let n_bc = bc_points.nrows();
for j in 0..n_bc {
let pt = bc_points.row(j).to_owned();
let bc_err = compute_bc_error(&self.network, &pt, bc_condition, fd_step)?;
boundary_loss += bc_err * bc_err;
n_bc_total += 1;
}
}
if n_bc_total > 0 {
boundary_loss /= n_bc_total as f64;
}
let mut data_loss = 0.0;
if let Some((x_data, y_data)) = data_points {
let n_data = x_data.nrows();
for i in 0..n_data {
let pt = x_data.row(i).to_owned();
let predicted = self.network.forward(&pt)?;
let err = predicted - y_data[i];
data_loss += err * err;
}
if n_data > 0 {
data_loss /= n_data as f64;
}
}
Ok(self.config.physics_weight * physics_loss
+ self.config.boundary_weight * boundary_loss
+ self.config.data_weight * data_loss)
}
pub fn predict(&self, points: &Array2<f64>) -> IntegrateResult<Array1<f64>> {
self.network.forward_batch(points)
}
pub fn network(&self) -> &PINNNetwork {
&self.network
}
}
/// Signed violation of a boundary condition at point `x`, using the network's
/// finite-difference gradient with stencil width `h` for derivative terms.
///
/// NOTE(review): the "normal derivative" is taken as `grad[0]`, i.e. the
/// first input coordinate's partial derivative — confirm this matches the
/// boundary orientation for multi-dimensional problems.
fn compute_bc_error(
    network: &PINNNetwork,
    x: &Array1<f64>,
    condition: &BoundaryCondition,
    h: f64,
) -> IntegrateResult<f64> {
    let err = match condition {
        // u = value on the boundary.
        BoundaryCondition::Dirichlet { value } => network.forward(x)? - value,
        // du/dn = flux on the boundary.
        BoundaryCondition::Neumann { flux } => network.gradient(x, h)?[0] - flux,
        // alpha * u + beta * du/dn = value on the boundary.
        BoundaryCondition::Robin { alpha, beta, value } => {
            let u = network.forward(x)?;
            let du_dn = network.gradient(x, h)?[0];
            alpha * u + beta * du_dn - value
        }
        // Periodic conditions are not enforced pointwise here; treated as
        // satisfied (zero error), matching the original behavior.
        BoundaryCondition::Periodic => 0.0,
    };
    Ok(err)
}
/// Generate `n_points` interior collocation points for `problem` using the
/// requested sampling `strategy`. Each row is one point; when the problem is
/// time-dependent the last column is the time coordinate.
///
/// All randomness is drawn from the caller-supplied xorshift `rng_state`, so
/// results are reproducible for a given seed.
pub fn generate_collocation_points(
    problem: &PDEProblem,
    n_points: usize,
    strategy: &CollocationStrategy,
    rng_state: &mut u64,
) -> Array2<f64> {
    let input_dim = if problem.has_time {
        problem.spatial_dim + 1
    } else {
        problem.spatial_dim
    };
    // Per-input-dimension (lo, hi) bounds: the spatial domain with, when
    // present, the time interval appended as the final coordinate.
    let mut all_bounds = problem.domain.clone();
    if let Some(t_bounds) = problem.time_domain {
        all_bounds.push(t_bounds);
    }
    let bound_for = |d: usize| all_bounds.get(d).copied().unwrap_or((0.0, 1.0));
    let mut points = Array2::<f64>::zeros((n_points, input_dim));
    match strategy {
        CollocationStrategy::UniformGrid => {
            // Regular tensor-product grid with cell-centered nodes.
            let n_per_dim = (n_points as f64).powf(1.0 / input_dim as f64).ceil() as usize;
            let grid_total = n_per_dim.pow(input_dim as u32).min(n_points);
            for point_idx in 0..grid_total {
                // Decode the flat index into per-dimension grid coordinates.
                let mut rem = point_idx;
                for dim in 0..input_dim {
                    let coord = rem % n_per_dim;
                    rem /= n_per_dim;
                    let (lo, hi) = bound_for(dim);
                    let frac = if n_per_dim > 1 {
                        (coord as f64 + 0.5) / n_per_dim as f64
                    } else {
                        0.5
                    };
                    points[[point_idx, dim]] = lo + frac * (hi - lo);
                }
            }
            // Fill any rows beyond the grid capacity with uniform samples.
            for point_idx in grid_total..n_points {
                for dim in 0..input_dim {
                    let (lo, hi) = if dim < problem.domain.len() {
                        problem.domain[dim]
                    } else {
                        problem.time_domain.unwrap_or((0.0, 1.0))
                    };
                    points[[point_idx, dim]] = lo + xorshift64(rng_state) * (hi - lo);
                }
            }
        }
        CollocationStrategy::LatinHypercube => {
            // One stratified, independently shuffled permutation per dimension.
            for dim in 0..input_dim {
                let mut perm: Vec<usize> = (0..n_points).collect();
                // Fisher-Yates shuffle driven by the shared RNG state.
                for i in (1..n_points).rev() {
                    let j = (xorshift64(rng_state) * (i + 1) as f64) as usize % (i + 1);
                    perm.swap(i, j);
                }
                let (lo, hi) = bound_for(dim);
                // Jitter each point inside its assigned stratum.
                for (row, &cell) in perm.iter().enumerate() {
                    let frac = (cell as f64 + xorshift64(rng_state)) / n_points as f64;
                    points[[row, dim]] = lo + frac * (hi - lo);
                }
            }
        }
        CollocationStrategy::Random | CollocationStrategy::AdaptiveResidual => {
            // Plain uniform sampling over the full (space x time) box.
            // AdaptiveResidual currently falls back to uniform random sampling.
            for row in 0..n_points {
                for dim in 0..input_dim {
                    let (lo, hi) = bound_for(dim);
                    points[[row, dim]] = lo + xorshift64(rng_state) * (hi - lo);
                }
            }
        }
    }
    points
}
/// Sample `n_per_boundary` points on each boundary of `problem`, pairing each
/// point set with its boundary condition. The coordinate along the boundary's
/// own dimension is pinned to the face value; all other coordinates (including
/// time, when present) are sampled uniformly from their intervals.
pub fn generate_boundary_points(
    problem: &PDEProblem,
    n_per_boundary: usize,
    rng_state: &mut u64,
) -> Vec<(Array2<f64>, BoundaryCondition)> {
    let input_dim = if problem.has_time {
        problem.spatial_dim + 1
    } else {
        problem.spatial_dim
    };
    let mut result = Vec::with_capacity(problem.boundaries.len());
    for boundary in &problem.boundaries {
        let mut pts = Array2::<f64>::zeros((n_per_boundary, input_dim));
        for row in 0..n_per_boundary {
            for dim in 0..input_dim {
                if dim == boundary.dim && dim < problem.domain.len() {
                    // Pin the boundary coordinate to the fixed face value.
                    let (lo, hi) = problem.domain[dim];
                    pts[[row, dim]] = if matches!(boundary.side, BoundarySide::High) {
                        hi
                    } else {
                        lo
                    };
                } else if dim < problem.domain.len() {
                    // Free spatial coordinate: uniform over its interval.
                    let (lo, hi) = problem.domain[dim];
                    pts[[row, dim]] = lo + xorshift64(rng_state) * (hi - lo);
                } else if let Some((t_min, t_max)) = problem.time_domain {
                    // Time coordinate: uniform over the time interval.
                    pts[[row, dim]] = t_min + xorshift64(rng_state) * (t_max - t_min);
                }
            }
        }
        result.push((pts, boundary.condition.clone()));
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pinn::problems;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_solver_creation() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let config = PINNConfig::default();
        let solver = PINNSolver::new(&problem, config);
        assert!(solver.is_ok());
    }

    #[test]
    fn test_collocation_points_count_and_range() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let mut rng = 42u64 | 1;
        let pts =
            generate_collocation_points(&problem, 100, &CollocationStrategy::Random, &mut rng);
        assert_eq!(pts.nrows(), 100);
        assert_eq!(pts.ncols(), 2);
        for i in 0..100 {
            assert!(pts[[i, 0]] >= 0.0 && pts[[i, 0]] <= 1.0);
            assert!(pts[[i, 1]] >= 0.0 && pts[[i, 1]] <= 1.0);
        }
    }

    #[test]
    fn test_collocation_uniform_grid() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let mut rng = 42u64 | 1;
        let pts =
            generate_collocation_points(&problem, 25, &CollocationStrategy::UniformGrid, &mut rng);
        assert_eq!(pts.nrows(), 25);
        for i in 0..25 {
            assert!(pts[[i, 0]] >= 0.0 && pts[[i, 0]] <= 1.0);
            assert!(pts[[i, 1]] >= 0.0 && pts[[i, 1]] <= 1.0);
        }
    }

    #[test]
    fn test_collocation_latin_hypercube() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let mut rng = 42u64 | 1;
        let pts = generate_collocation_points(
            &problem,
            50,
            &CollocationStrategy::LatinHypercube,
            &mut rng,
        );
        assert_eq!(pts.nrows(), 50);
        assert_eq!(pts.ncols(), 2);
        for i in 0..50 {
            assert!(pts[[i, 0]] >= 0.0 && pts[[i, 0]] <= 1.0);
            assert!(pts[[i, 1]] >= 0.0 && pts[[i, 1]] <= 1.0);
        }
    }

    #[test]
    fn test_boundary_points_on_boundary() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let mut rng = 42u64 | 1;
        let bnd_data = generate_boundary_points(&problem, 20, &mut rng);
        assert_eq!(bnd_data.len(), 4);
        for (pts, _cond) in &bnd_data {
            assert_eq!(pts.nrows(), 20);
            assert_eq!(pts.ncols(), 2);
            for i in 0..20 {
                let x = pts[[i, 0]];
                let y = pts[[i, 1]];
                let on_boundary = (x - 0.0).abs() < 1e-15
                    || (x - 1.0).abs() < 1e-15
                    || (y - 0.0).abs() < 1e-15
                    || (y - 1.0).abs() < 1e-15;
                assert!(on_boundary, "point ({}, {}) not on boundary", x, y);
            }
        }
    }

    #[test]
    fn test_train_loss_decreases() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let config = PINNConfig {
            hidden_layers: vec![8, 8],
            max_epochs: 20,
            n_collocation: 10,
            n_boundary: 5,
            learning_rate: 1e-3,
            ..PINNConfig::default()
        };
        let mut solver = PINNSolver::new(&problem, config).expect("solver creation");
        let result = solver
            .train(&problems::laplace_residual, &problem, None)
            .expect("training");
        assert!(result.epochs_trained > 0);
        assert!(!result.loss_history.is_empty());
        let first = result.loss_history[0];
        let last = result.loss_history[result.loss_history.len() - 1];
        // Loose bound: training should at least not blow the loss up.
        assert!(
            last <= first * 10.0,
            "loss did not decrease: first={}, last={}",
            first,
            last
        );
    }

    #[test]
    fn test_predict_after_training() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let config = PINNConfig {
            hidden_layers: vec![8],
            max_epochs: 5,
            n_collocation: 5,
            n_boundary: 3,
            ..PINNConfig::default()
        };
        let mut solver = PINNSolver::new(&problem, config).expect("solver creation");
        let _ = solver.train(&problems::laplace_residual, &problem, None);
        let test_pts = Array2::from_shape_vec((3, 2), vec![0.2, 0.3, 0.5, 0.5, 0.8, 0.9])
            .expect("test points");
        let predictions = solver.predict(&test_pts);
        assert!(predictions.is_ok());
        let vals = predictions.expect("predictions");
        assert_eq!(vals.len(), 3);
        for &v in vals.iter() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_train_with_data() {
        let problem = problems::laplace_problem_2d((0.0, 1.0, 0.0, 1.0));
        let config = PINNConfig {
            hidden_layers: vec![8],
            max_epochs: 5,
            n_collocation: 5,
            n_boundary: 3,
            ..PINNConfig::default()
        };
        let mut solver = PINNSolver::new(&problem, config).expect("solver creation");
        let x_data = Array2::from_shape_vec((2, 2), vec![0.5, 0.5, 0.25, 0.75]).expect("data x");
        let y_data = array![0.0, 0.0];
        let result = solver
            .train(
                &problems::laplace_residual,
                &problem,
                Some((&x_data, &y_data)),
            )
            .expect("training with data");
        assert!(result.epochs_trained > 0);
        assert!(result.data_loss.is_finite());
    }

    #[test]
    fn test_adam_step_reduces_quadratic() {
        let mut optimizer = AdamOptimizer::new(1, 0.1);
        let params = array![5.0];
        let grad = array![10.0];
        // Fixed: original had mojibake `¶ms` for `&params`.
        let new_params = optimizer.step(&params, &grad);
        assert!(
            new_params[0].abs() < params[0].abs(),
            "Adam step should reduce magnitude: {} -> {}",
            params[0],
            new_params[0]
        );
    }

    #[test]
    fn test_adam_momentum_nonzero() {
        let mut optimizer = AdamOptimizer::new(2, 0.01);
        let params = array![1.0, 2.0];
        let grad = array![0.5, -0.3];
        // Fixed: original had mojibake `¶ms` for `&params`.
        let _ = optimizer.step(&params, &grad);
        assert!(optimizer.m[0].abs() > 1e-15);
        assert!(optimizer.m[1].abs() > 1e-15);
        assert!(optimizer.v[0] > 0.0);
        assert!(optimizer.v[1] > 0.0);
        assert_eq!(optimizer.t, 1);
    }

    #[test]
    fn test_collocation_time_dependent() {
        let problem = problems::heat_problem_1d((0.0, 1.0), (0.0, 0.5), 0.01);
        let mut rng = 99u64 | 1;
        let pts = generate_collocation_points(&problem, 50, &CollocationStrategy::Random, &mut rng);
        assert_eq!(pts.ncols(), 2);
        assert_eq!(pts.nrows(), 50);
        for i in 0..50 {
            assert!(pts[[i, 0]] >= 0.0 && pts[[i, 0]] <= 1.0);
            assert!(pts[[i, 1]] >= 0.0 && pts[[i, 1]] <= 0.5);
        }
    }
}