use numra_core::Scalar;
use numra_pde::Grid1D;
/// Spatial correlation structure of the noise driving an SPDE.
///
/// The variants only carry parameters. How each one is turned into noise
/// increments is decided by whatever solver consumes this value, which is not
/// visible here. Confirm the exact kernel/spectrum against the integrator
/// before relying on a particular interpretation.
///
/// `PartialEq` is derived (conditionally on `S: PartialEq`) so configurations
/// can be compared directly with `==`.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum NoiseCorrelation<S: Scalar> {
    /// Uncorrelated noise (presumably space-time white noise).
    /// This is the `Default` variant.
    #[default]
    White,
    /// Spatially correlated ("colored") noise.
    Colored {
        /// Characteristic correlation length. Presumably in the same units as
        /// the grid coordinates; TODO confirm against the solver.
        correlation_length: S,
    },
    /// Noise with a trace-class covariance, described by a truncated set of
    /// modes.
    TraceClass {
        /// Number of modes retained.
        n_modes: usize,
        /// Rate at which the mode amplitudes decay. The exact decay law
        /// (algebraic vs. exponential) is defined by the solver; TODO confirm.
        decay_rate: S,
    },
}
/// A stochastic PDE whose right-hand side is evaluated on a [`Grid1D`].
///
/// Implementors supply a deterministic part ([`drift`](Self::drift)) and a
/// stochastic part ([`diffusion`](Self::diffusion)). Both write into
/// caller-provided output slices. How a solver combines the two (presumably
/// `du = drift dt + diffusion dW`) is not defined here; confirm against the
/// integrator.
pub trait SpdeSystem<S: Scalar> {
    /// Evaluates the drift at time `t` and state `u` on `grid`, writing the
    /// result into `du`.
    fn drift(&self, t: S, u: &[S], du: &mut [S], grid: &Grid1D<S>);

    /// Evaluates the diffusion coefficient at time `t` and state `u` on
    /// `grid`, writing the result into `sigma`.
    fn diffusion(&self, t: S, u: &[S], sigma: &mut [S], grid: &Grid1D<S>);

    /// Dimension reported by the system. Presumably the number of solution
    /// components per grid point; TODO confirm. Defaults to `1`.
    fn dim(&self) -> usize {
        1
    }

    /// Correlation structure of the driving noise.
    /// Defaults to [`NoiseCorrelation::White`].
    fn noise_correlation(&self) -> NoiseCorrelation<S> {
        NoiseCorrelation::White
    }

    /// Whether the noise is additive, conventionally meaning that `diffusion`
    /// does not depend on `u`.
    ///
    /// Defaults to `true`. Implementors with state-dependent diffusion should
    /// override this.
    fn is_additive(&self) -> bool {
        true
    }
}
/// An [`SpdeSystem`] backed by closures.
///
/// Drift and diffusion are plain functions/closures. The noise correlation and
/// the additivity flag are stored alongside them and are adjusted through the
/// builder methods on this type.
// NOTE(review): the reason for this `dead_code` allowance is not visible here
// (possibly the type is not yet used elsewhere in the crate); confirm and
// remove it if it is no longer needed.
#[allow(dead_code)]
pub struct Spde1D<
    S: Scalar,
    F: Fn(S, &[S], &mut [S], &Grid1D<S>),
    G: Fn(S, &[S], &mut [S], &Grid1D<S>),
> {
    /// Closure forwarded to by `SpdeSystem::drift`.
    drift_fn: F,
    /// Closure forwarded to by `SpdeSystem::diffusion`.
    diffusion_fn: G,
    /// Value handed out by `SpdeSystem::noise_correlation`.
    correlation: NoiseCorrelation<S>,
    /// Value handed out by `SpdeSystem::is_additive`.
    additive: bool,
}
// NOTE(review): the reason for this `dead_code` allowance is not visible here;
// confirm and remove it if it is no longer needed.
#[allow(dead_code)]
impl<S: Scalar, F, G> Spde1D<S, F, G>
where
    F: Fn(S, &[S], &mut [S], &Grid1D<S>),
    G: Fn(S, &[S], &mut [S], &Grid1D<S>),
{
    /// Creates a system from a drift closure and a diffusion closure.
    ///
    /// The noise correlation starts as [`NoiseCorrelation::White`] and the
    /// system is flagged as additive. Nothing here inspects `diffusion` to
    /// detect a dependence on `u`, so call
    /// [`with_additive(false)`](Self::with_additive) for multiplicative noise.
    #[must_use]
    pub fn new(drift: F, diffusion: G) -> Self {
        Self {
            drift_fn: drift,
            diffusion_fn: diffusion,
            correlation: NoiseCorrelation::White,
            additive: true,
        }
    }

    /// Replaces the noise correlation and returns the updated system.
    ///
    /// Consumes `self`; discarding the return value drops the system
    /// (hence `#[must_use]`).
    #[must_use]
    pub fn with_correlation(mut self, correlation: NoiseCorrelation<S>) -> Self {
        self.correlation = correlation;
        self
    }

    /// Sets whether the noise is reported as additive and returns the updated
    /// system.
    ///
    /// Consumes `self`; discarding the return value drops the system
    /// (hence `#[must_use]`).
    #[must_use]
    pub fn with_additive(mut self, additive: bool) -> Self {
        self.additive = additive;
        self
    }
}
impl<S: Scalar, F, G> SpdeSystem<S> for Spde1D<S, F, G>
where
F: Fn(S, &[S], &mut [S], &Grid1D<S>),
G: Fn(S, &[S], &mut [S], &Grid1D<S>),
{
fn drift(&self, t: S, u: &[S], du: &mut [S], grid: &Grid1D<S>) {
(self.drift_fn)(t, u, du, grid)
}
fn diffusion(&self, t: S, u: &[S], sigma: &mut [S], grid: &Grid1D<S>) {
(self.diffusion_fn)(t, u, sigma, grid)
}
fn noise_correlation(&self) -> NoiseCorrelation<S> {
self.correlation.clone()
}
fn is_additive(&self) -> bool {
self.additive
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a finite-difference heat-equation system and checks the
    /// defaults established by `Spde1D::new`.
    #[test]
    fn test_spde1d_construction() {
        let spde: Spde1D<f64, _, _> = Spde1D::new(
            |_t, u, du, grid| {
                let dx = grid.dx_uniform();
                let n = u.len();
                for i in 0..n {
                    let u_left = if i == 0 { 0.0 } else { u[i - 1] };
                    let u_right = if i == n - 1 { 0.0 } else { u[i + 1] };
                    du[i] = (u_left - 2.0 * u[i] + u_right) / (dx * dx);
                }
            },
            |_t, _u, sigma, _grid| {
                for s in sigma.iter_mut() {
                    *s = 0.1;
                }
            },
        );
        assert_eq!(spde.dim(), 1);
        assert!(spde.is_additive());
        assert!(matches!(spde.noise_correlation(), NoiseCorrelation::White));
    }

    /// The builder methods must override the defaults from `new`.
    #[test]
    fn test_spde1d_builder_overrides_defaults() {
        // Annotate the type on `new` first: the expected type does not flow
        // back through the method chain, and the closures need `S = f64`.
        let spde: Spde1D<f64, _, _> = Spde1D::new(
            |_t, _u, du, _grid| du.fill(0.0),
            |_t, u, sigma, _grid| {
                for (s, &ui) in sigma.iter_mut().zip(u) {
                    *s = 0.1 * ui;
                }
            },
        );
        let spde = spde
            .with_correlation(NoiseCorrelation::TraceClass {
                n_modes: 8,
                decay_rate: 2.0,
            })
            .with_additive(false);

        assert!(!spde.is_additive());
        match spde.noise_correlation() {
            NoiseCorrelation::TraceClass {
                n_modes,
                decay_rate,
            } => {
                assert_eq!(n_modes, 8);
                assert!((decay_rate - 2.0).abs() < 1e-12);
            }
            other => panic!("expected TraceClass, got {other:?}"),
        }
    }

    #[test]
    fn test_noise_correlation_default() {
        let correlation: NoiseCorrelation<f64> = NoiseCorrelation::default();
        assert!(matches!(correlation, NoiseCorrelation::White));
    }

    #[test]
    fn test_colored_noise() {
        let correlation: NoiseCorrelation<f64> = NoiseCorrelation::Colored {
            correlation_length: 0.1,
        };
        // An `if let` without `else` would silently pass on a wrong variant;
        // the panicking fallback arm makes that a test failure.
        match correlation {
            NoiseCorrelation::Colored { correlation_length } => {
                assert!((correlation_length - 0.1).abs() < 1e-10);
            }
            other => panic!("expected Colored, got {other:?}"),
        }
    }
}