use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::{Normal, Rng, StdRng};
use scirs2_core::Distribution;
/// Physical and numerical parameters of the stochastic wave solver.
///
/// The defaults describe a unit-speed wave with weak noise and a small time step.
#[derive(Clone, Debug)]
pub struct StochasticWaveConfig {
    /// Propagation speed `c` multiplying the discrete Laplacian (as `c^2`).
    pub wave_speed: f64,
    /// Amplitude of the additive Gaussian noise forcing.
    pub sigma: f64,
    /// Nominal time-step length.
    pub dt: f64,
}

impl Default for StochasticWaveConfig {
    /// `wave_speed = 1.0`, `sigma = 0.01`, `dt = 1e-4`.
    fn default() -> Self {
        let wave_speed = 1.0;
        let sigma = 0.01;
        let dt = 1e-4;
        Self {
            wave_speed,
            sigma,
            dt,
        }
    }
}
/// State of the discretised wave field recorded at a single time instant.
#[derive(Clone, Debug)]
pub struct WaveSnapshot {
    /// Time at which the state was recorded.
    pub t: f64,
    /// Displacement at each interior grid node.
    pub u: Array1<f64>,
    /// Velocity (`du/dt`) at each interior grid node.
    pub v: Array1<f64>,
    /// Discrete energy of `(u, v)` as computed by the solver.
    pub energy: f64,
}
/// Result of a stochastic wave integration: the recorded snapshots plus the
/// interior grid they are defined on.
#[derive(Clone, Debug)]
pub struct StochasticWaveSolution {
    /// Recorded states; the solver pushes them in time order.
    pub snapshots: Vec<WaveSnapshot>,
    /// Interior node coordinates shared by every snapshot.
    pub grid: Array1<f64>,
}

impl StochasticWaveSolution {
    /// Number of recorded snapshots.
    pub fn len(&self) -> usize {
        let snapshots = &self.snapshots;
        snapshots.len()
    }

    /// `true` when no snapshot was recorded.
    pub fn is_empty(&self) -> bool {
        let snapshots = &self.snapshots;
        snapshots.is_empty()
    }

    /// Energies of all snapshots, in snapshot order.
    pub fn energy_series(&self) -> Vec<f64> {
        let mut series = Vec::with_capacity(self.snapshots.len());
        series.extend(self.snapshots.iter().map(|snap| snap.energy));
        series
    }

    /// Average rate of energy change between the first and last snapshot,
    /// `(E_last - E_first) / (t_last - t_first)`.
    ///
    /// Returns `0.0` when fewer than two snapshots exist or when the two end
    /// times coincide (to within `1e-15`).
    pub fn mean_energy_growth_rate(&self) -> f64 {
        // The `..` slice pattern only matches slices with at least two elements.
        let [first, .., last] = self.snapshots.as_slice() else {
            return 0.0;
        };
        let elapsed = last.t - first.t;
        if elapsed.abs() < 1e-15 {
            return 0.0;
        }
        (last.energy - first.energy) / elapsed
    }
}
/// Explicit finite-difference solver for the 1-D wave equation `u_tt = c^2 u_xx`
/// driven by additive Gaussian noise, with zero (Dirichlet) boundary values.
///
/// Derives `Debug` and `Clone` so the solver can be logged and duplicated like
/// the other public types in this module.
#[derive(Debug, Clone)]
pub struct StochasticWaveSolver {
    /// Physical and numerical parameters.
    cfg: StochasticWaveConfig,
    /// Number of interior grid nodes (at least 1).
    n_nodes: usize,
    /// Uniform grid spacing, `domain_length / (n_nodes + 1)`.
    dx: f64,
    /// A snapshot is recorded every `save_every` steps (always >= 1).
    save_every: usize,
    /// Coordinates of the interior nodes.
    x_coords: Array1<f64>,
}
impl StochasticWaveSolver {
pub fn new(
config: StochasticWaveConfig,
domain_length: f64,
n_nodes: usize,
save_every: usize,
) -> IntegrateResult<Self> {
if n_nodes == 0 {
return Err(IntegrateError::InvalidInput(
"n_nodes must be at least 1".to_string(),
));
}
if domain_length <= 0.0 {
return Err(IntegrateError::InvalidInput(
"domain_length must be positive".to_string(),
));
}
if config.dt <= 0.0 {
return Err(IntegrateError::InvalidInput("dt must be positive".to_string()));
}
if config.wave_speed <= 0.0 {
return Err(IntegrateError::InvalidInput(
"wave_speed must be positive".to_string(),
));
}
let dx = domain_length / (n_nodes + 1) as f64;
let courant = config.wave_speed * config.dt / dx;
if courant > 1.0 {
return Err(IntegrateError::InvalidInput(format!(
"Courant condition violated: c*dt/dx = {courant:.4} > 1. \
Reduce dt or increase n_nodes.",
)));
}
let x_coords = Array1::linspace(dx, domain_length - dx, n_nodes);
Ok(Self {
cfg: config,
n_nodes,
dx,
save_every: save_every.max(1),
x_coords,
})
}
pub fn solve(
&self,
u0: &Array1<f64>,
v0: &Array1<f64>,
t0: f64,
t_end: f64,
rng: &mut StdRng,
) -> IntegrateResult<StochasticWaveSolution> {
if u0.len() != self.n_nodes || v0.len() != self.n_nodes {
return Err(IntegrateError::DimensionMismatch(format!(
"u0/v0 length must equal n_nodes = {}",
self.n_nodes
)));
}
if t_end <= t0 {
return Err(IntegrateError::InvalidInput(
"t_end must be greater than t0".to_string(),
));
}
let normal = Normal::new(0.0_f64, 1.0).map_err(|e| {
IntegrateError::ComputationError(format!("Normal distribution: {e}"))
})?;
let dt = self.cfg.dt;
let c = self.cfg.wave_speed;
let sigma = self.cfg.sigma;
let c2_over_dx2 = c * c / (self.dx * self.dx);
let noise_scale = sigma / (self.dx * dt).sqrt();
let n_steps = ((t_end - t0) / dt).ceil() as usize;
let capacity = (n_steps / self.save_every + 2).max(2);
let mut snapshots = Vec::with_capacity(capacity);
let mut u = u0.clone();
let mut v = v0.clone();
let e0 = self.discrete_energy(&u, &v);
snapshots.push(WaveSnapshot {
t: t0,
u: u.clone(),
v: v.clone(),
energy: e0,
});
let mut t = t0;
let mut u_new = Array1::<f64>::zeros(self.n_nodes);
for step in 0..n_steps {
let actual_dt = dt.min(t_end - t);
if actual_dt <= 0.0 {
break;
}
let c2dx2 = c * c / (self.dx * self.dx);
let ns = sigma / (self.dx * actual_dt).sqrt();
for i in 0..self.n_nodes {
let u_left = if i == 0 { 0.0 } else { u[i - 1] };
let u_right = if i < self.n_nodes - 1 { u[i + 1] } else { 0.0 };
let laplacian = u_left - 2.0 * u[i] + u_right;
let xi = rng.sample(&normal);
v[i] += actual_dt * (c2dx2 * laplacian + ns * xi);
}
for i in 0..self.n_nodes {
u_new[i] = u[i] + actual_dt * v[i];
}
u.assign(&u_new);
t += actual_dt;
if (step + 1) % self.save_every == 0 || t >= t_end - 1e-14 {
let energy = self.discrete_energy(&u, &v);
snapshots.push(WaveSnapshot {
t,
u: u.clone(),
v: v.clone(),
energy,
});
}
}
Ok(StochasticWaveSolution {
snapshots,
grid: self.x_coords.clone(),
})
}
fn discrete_energy(&self, u: &Array1<f64>, v: &Array1<f64>) -> f64 {
let c = self.cfg.wave_speed;
let mut kinetic = 0.0_f64;
let mut potential = 0.0_f64;
for i in 0..self.n_nodes {
kinetic += v[i] * v[i];
let u_right = if i < self.n_nodes - 1 { u[i + 1] } else { 0.0 };
let grad = (u_right - u[i]) / self.dx;
potential += c * c * grad * grad;
}
0.5 * self.dx * (kinetic + potential)
}
pub fn grid(&self) -> &Array1<f64> {
&self.x_coords
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::prelude::*;

    /// Fixed seed so every test run sees the same noise realisation.
    fn make_rng() -> StdRng {
        seeded_rng(999)
    }

    #[test]
    fn test_deterministic_wave_energy_conserved() {
        // With sigma = 0 the scheme is a plain symplectic integrator, so the
        // discrete energy of the first sine mode must stay (nearly) constant.
        let config = StochasticWaveConfig {
            wave_speed: 1.0,
            sigma: 0.0,
            dt: 5e-5,
        };
        let solver = StochasticWaveSolver::new(config, 1.0, 20, 10)
            .expect("StochasticWaveSolver::new should succeed");
        let u0 = Array1::from_vec(
            (0..20)
                .map(|i| ((i as f64 + 1.0) * std::f64::consts::PI / 21.0).sin())
                .collect::<Vec<f64>>(),
        );
        let v0 = Array1::zeros(20);
        let mut rng = make_rng();
        let sol = solver
            .solve(&u0, &v0, 0.0, 0.05, &mut rng)
            .expect("solver.solve should succeed");
        let energies = sol.energy_series();
        assert!(energies.len() > 1, "expected more than the initial snapshot");
        let e0 = energies[0];
        for &e in &energies[1..] {
            let rel_err = ((e - e0) / e0).abs();
            assert!(
                rel_err < 0.02,
                "Energy not conserved: e0={e0:.6}, e={e:.6}, rel={rel_err:.4}"
            );
        }
    }

    #[test]
    fn test_stochastic_wave_energy_grows() {
        // Starting from rest, additive noise is the only energy source, so the
        // final energy must exceed the (zero) initial energy.
        let config = StochasticWaveConfig {
            wave_speed: 1.0,
            sigma: 0.5,
            dt: 5e-5,
        };
        let solver = StochasticWaveSolver::new(config, 1.0, 20, 10)
            .expect("StochasticWaveSolver::new should succeed");
        let u0 = Array1::zeros(20);
        let v0 = Array1::zeros(20);
        let mut rng = make_rng();
        let sol = solver
            .solve(&u0, &v0, 0.0, 0.02, &mut rng)
            .expect("solver.solve should succeed");
        for s in &sol.snapshots {
            assert!(s.energy.is_finite(), "Non-finite energy at t={}", s.t);
        }
        let e_initial = sol
            .snapshots
            .first()
            .map(|s| s.energy)
            .expect("initial snapshot is always recorded");
        let e_final = sol
            .snapshots
            .last()
            .map(|s| s.energy)
            .expect("at least one snapshot is recorded");
        assert_eq!(e_initial, 0.0, "zero initial data must have zero energy");
        assert!(
            e_final > e_initial,
            "noise should inject energy: e_initial={e_initial}, e_final={e_final}"
        );
    }

    #[test]
    fn test_courant_violation_error() {
        let config = StochasticWaveConfig {
            wave_speed: 1.0,
            sigma: 0.0,
            dt: 0.1,
        };
        let result = StochasticWaveSolver::new(config, 1.0, 20, 1);
        assert!(
            matches!(result, Err(IntegrateError::InvalidInput(_))),
            "Should fail Courant check"
        );
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let config = StochasticWaveConfig::default();
        let solver = StochasticWaveSolver::new(config, 1.0, 10, 1)
            .expect("StochasticWaveSolver::new should succeed");
        let u0 = Array1::zeros(5);
        let v0 = Array1::zeros(10);
        let mut rng = make_rng();
        let result = solver.solve(&u0, &v0, 0.0, 0.001, &mut rng);
        assert!(matches!(result, Err(IntegrateError::DimensionMismatch(_))));
    }

    #[test]
    fn test_all_finite_values() {
        let config = StochasticWaveConfig {
            wave_speed: 0.5,
            sigma: 0.1,
            dt: 1e-4,
        };
        let solver = StochasticWaveSolver::new(config, 1.0, 15, 20)
            .expect("StochasticWaveSolver::new should succeed");
        let n = solver.n_nodes;
        let u0 = Array1::zeros(n);
        let v0 = Array1::zeros(n);
        let mut rng = make_rng();
        let sol = solver
            .solve(&u0, &v0, 0.0, 0.01, &mut rng)
            .expect("solver.solve should succeed");
        for s in &sol.snapshots {
            assert!(s.u.iter().all(|x| x.is_finite()));
            assert!(s.v.iter().all(|x| x.is_finite()));
        }
    }
}