use crate::error::{IntegrateError, IntegrateResult};
use crate::sde::{compute_n_steps, SdeOptions, SdeProblem, SdeSolution};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::{Normal, Rng, StdRng};
use scirs2_core::Distribution;
/// Integrates an SDE with the explicit Euler–Maruyama scheme using the
/// default [`SdeOptions`].
///
/// This is a convenience wrapper: it builds `SdeOptions::default()` and
/// forwards everything to [`euler_maruyama_with_options`], which documents
/// the stepping rule and the output layout.
///
/// # Errors
///
/// Returns whatever error [`euler_maruyama_with_options`] returns.
pub fn euler_maruyama<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    let default_opts = SdeOptions::default();
    euler_maruyama_with_options(prob, dt, rng, &default_opts)
}
pub fn euler_maruyama_with_options<F, G>(
prob: &SdeProblem<F, G>,
dt: f64,
rng: &mut StdRng,
opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
F: Fn(f64, &Array1<f64>) -> Array1<f64>,
G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
prob.validate()?;
let t0 = prob.t_span[0];
let t1 = prob.t_span[1];
let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
let n_state = prob.dim();
let m = prob.n_brownian;
let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
let mut sol = SdeSolution::with_capacity(capacity);
sol.push(t0, prob.x0.clone());
let normal = Normal::new(0.0_f64, 1.0_f64).map_err(|e| {
IntegrateError::ComputationError(format!("Normal distribution error: {}", e))
})?;
let mut x = prob.x0.clone();
let mut t = t0;
for step in 0..n_steps {
let dt_actual = if step == n_steps - 1 {
t1 - t
} else {
dt.min(t1 - t)
};
if dt_actual <= 0.0 {
break;
}
let sqrt_dt = dt_actual.sqrt();
let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_dt);
let drift = (prob.f_drift)(t, &x);
let diff_matrix = (prob.g_diffusion)(t, &x);
if drift.len() != n_state {
return Err(IntegrateError::DimensionMismatch(format!(
"Drift output dimension {} != state dimension {}",
drift.len(),
n_state
)));
}
if diff_matrix.nrows() != n_state || diff_matrix.ncols() != m {
return Err(IntegrateError::DimensionMismatch(format!(
"Diffusion matrix shape ({},{}) != expected ({},{})",
diff_matrix.nrows(),
diff_matrix.ncols(),
n_state,
m
)));
}
let stochastic_increment = diff_matrix.dot(&dw);
x = x + drift * dt_actual + stochastic_increment;
t += dt_actual;
if opts.save_all_steps {
sol.push(t, x.clone());
}
}
if !opts.save_all_steps {
sol.push(t, x);
}
Ok(sol)
}
/// Integrates an SDE with the weak (two-point noise) Euler–Maruyama scheme
/// using the default [`SdeOptions`].
///
/// This is a convenience wrapper: it builds `SdeOptions::default()` and
/// forwards everything to [`weak_euler_maruyama_with_options`].
///
/// # Errors
///
/// Returns whatever error [`weak_euler_maruyama_with_options`] returns.
pub fn weak_euler_maruyama<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    let default_opts = SdeOptions::default();
    weak_euler_maruyama_with_options(prob, dt, rng, &default_opts)
}
/// Integrates `dX = f(t, X) dt + G(t, X) dW` with the weak Euler–Maruyama
/// scheme.
///
/// This is identical to the Gaussian Euler–Maruyama stepper, except that each
/// Brownian increment component is `+sqrt(h)` or `-sqrt(h)`. The sign comes
/// from one `rng.random::<bool>()` draw per component. Such increments have
/// mean 0 and variance `h`, like a Gaussian increment, so they suit
/// expectation-type (weak) estimates rather than pathwise accuracy.
///
/// Step sizes follow the same rule as the strong solver:
/// - every step uses `min(dt, t1 - t)`;
/// - the last step uses exactly `t1 - t`;
/// - a non-positive remaining interval stops integration early.
///
/// The `opts.save_all_steps` flag selects the output layout:
/// - set: every state is recorded;
/// - unset: only the initial and final states are kept.
///
/// # Errors
///
/// * Errors from `prob.validate()` and `compute_n_steps`.
/// * `IntegrateError::DimensionMismatch` in either of these cases:
///   - the drift output length differs from the state dimension;
///   - the diffusion matrix is not `(state_dim, n_brownian)`.
pub fn weak_euler_maruyama_with_options<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
    opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    prob.validate()?;

    let (t_start, t_end) = (prob.t_span[0], prob.t_span[1]);
    let n_steps = compute_n_steps(t_start, t_end, dt, opts.max_steps)?;
    let state_dim = prob.dim();
    let noise_dim = prob.n_brownian;

    let mut sol = SdeSolution::with_capacity(if opts.save_all_steps {
        n_steps + 1
    } else {
        2
    });
    sol.push(t_start, prob.x0.clone());

    let mut state = prob.x0.clone();
    let mut time = t_start;

    for step in 0..n_steps {
        let remaining = t_end - time;
        // The final step absorbs any rounding so the path lands on `t_end`.
        let h = if step + 1 == n_steps {
            remaining
        } else {
            dt.min(remaining)
        };
        if h <= 0.0 {
            break;
        }

        // Two-point increments: each component is +sqrt(h) or -sqrt(h).
        let sqrt_h = h.sqrt();
        let dw: Array1<f64> = Array1::from_shape_fn(noise_dim, |_| {
            let sign = if rng.random::<bool>() { 1.0 } else { -1.0 };
            sign * sqrt_h
        });

        let drift = (prob.f_drift)(time, &state);
        let diffusion = (prob.g_diffusion)(time, &state);

        // The user closures are re-checked every step because nothing forces
        // them to return consistently shaped arrays.
        if drift.len() != state_dim {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Drift output dimension {} != state dimension {}",
                drift.len(),
                state_dim
            )));
        }
        let (rows, cols) = diffusion.dim();
        if rows != state_dim || cols != noise_dim {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Diffusion matrix shape ({},{}) != expected ({},{})",
                rows, cols, state_dim, noise_dim
            )));
        }

        let noise_term = diffusion.dot(&dw);
        // In-place form of `state + drift * h + noise_term` (same evaluation order).
        state.scaled_add(h, &drift);
        state += &noise_term;
        time += h;

        if opts.save_all_steps {
            sol.push(time, state.clone());
        }
    }

    if !opts.save_all_steps {
        sol.push(time, state);
    }
    Ok(sol)
}
use scirs2_core::ndarray::Array2;
#[cfg(test)]
mod tests {
    // Unit tests for the Gaussian and two-point (weak) Euler–Maruyama solvers.
    use super::*;
    use crate::sde::SdeProblem;
    use scirs2_core::ndarray::{array, Array2};
    use scirs2_core::random::prelude::seeded_rng;

    /// Builds a scalar geometric Brownian motion problem on `[0, 1]`:
    /// `dX = mu X dt + sigma X dW`, `X(0) = x0`.
    fn make_gbm_prob(
        mu: f64,
        sigma: f64,
        x0: f64,
    ) -> SdeProblem<
        impl Fn(f64, &Array1<f64>) -> Array1<f64>,
        impl Fn(f64, &Array1<f64>) -> Array2<f64>,
    > {
        SdeProblem::new(
            array![x0],
            [0.0, 1.0],
            1,
            move |_t, x| array![mu * x[0]],
            move |_t, x| {
                let mut g = Array2::zeros((1, 1));
                g[[0, 0]] = sigma * x[0];
                g
            },
        )
    }

    #[test]
    fn test_em_gbm_weak_convergence() {
        let mu = 0.1_f64;
        let sigma = 0.2_f64;
        let x0 = 1.0_f64;
        let t1 = 1.0_f64;
        let dt = 0.001;
        let n_paths: u64 = 500;
        // E[X(t)] = x0 * exp(mu * t) for GBM.
        let analytic_mean = x0 * (mu * t1).exp();
        let mut sum = 0.0;
        for seed in 0..n_paths {
            let prob = make_gbm_prob(mu, sigma, x0);
            let mut rng = seeded_rng(seed);
            let sol = euler_maruyama(&prob, dt, &mut rng).expect("euler_maruyama should succeed");
            sum += sol.x_final().expect("solution has state")[0];
        }
        let sample_mean = sum / n_paths as f64;
        let rel_error = (sample_mean - analytic_mean).abs() / analytic_mean;
        assert!(
            rel_error < 0.05,
            "GBM mean {:.4} vs analytic {:.4}, rel error {:.4}",
            sample_mean,
            analytic_mean,
            rel_error
        );
    }

    #[test]
    fn test_em_solution_length() {
        let prob = make_gbm_prob(0.05, 0.2, 1.0);
        let mut rng = seeded_rng(0);
        let sol = euler_maruyama(&prob, 0.1, &mut rng).expect("euler_maruyama should succeed");
        assert_eq!(sol.len(), 11);
        assert!((sol.t[0] - 0.0).abs() < 1e-12);
        assert!((sol.t_final().expect("solution has time steps") - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_weak_em_solution_length() {
        let prob = make_gbm_prob(0.05, 0.2, 1.0);
        let mut rng = seeded_rng(1);
        let sol =
            weak_euler_maruyama(&prob, 0.1, &mut rng).expect("weak_euler_maruyama should succeed");
        assert_eq!(sol.len(), 11);
    }

    #[test]
    fn test_weak_em_two_point_increments() {
        // With zero drift and unit additive noise, every weak-EM increment is
        // exactly +/- sqrt(dt); dt = 0.25 keeps all arithmetic exact.
        let prob = SdeProblem::new(
            array![0.0_f64],
            [0.0, 1.0],
            1,
            |_t, _x| array![0.0_f64],
            |_t, _x| Array2::<f64>::from_elem((1, 1), 1.0),
        );
        let mut rng = seeded_rng(7);
        let sol = weak_euler_maruyama(&prob, 0.25, &mut rng)
            .expect("weak_euler_maruyama should succeed");
        assert!(sol.x.len() >= 2);
        for k in 1..sol.x.len() {
            let increment = sol.x[k][0] - sol.x[k - 1][0];
            assert!(
                (increment.abs() - 0.5).abs() < 1e-12,
                "increment {} is not +/-0.5",
                increment
            );
        }
    }

    #[test]
    fn test_em_invalid_dt() {
        let prob = make_gbm_prob(0.05, 0.2, 1.0);
        let mut rng = seeded_rng(0);
        assert!(euler_maruyama(&prob, -0.1, &mut rng).is_err());
    }

    #[test]
    fn test_em_drift_dimension_mismatch() {
        // Drift returns two components for a one-dimensional state.
        let prob = SdeProblem::new(
            array![1.0_f64],
            [0.0, 1.0],
            1,
            |_t, _x| array![0.0_f64, 0.0_f64],
            |_t, _x| Array2::<f64>::from_elem((1, 1), 0.1),
        );
        let mut rng = seeded_rng(0);
        assert!(euler_maruyama(&prob, 0.1, &mut rng).is_err());
    }

    #[test]
    fn test_em_diffusion_shape_mismatch() {
        // Diffusion has two columns but the problem declares one Brownian motion.
        let prob = SdeProblem::new(
            array![1.0_f64],
            [0.0, 1.0],
            1,
            |_t, _x| array![0.0_f64],
            |_t, _x| Array2::<f64>::zeros((1, 2)),
        );
        let mut rng = seeded_rng(0);
        assert!(euler_maruyama(&prob, 0.1, &mut rng).is_err());
    }

    #[test]
    fn test_em_multivariate() {
        let x0 = array![0.0_f64, 0.0_f64];
        let prob = SdeProblem::new(
            x0,
            [0.0, 1.0],
            2,
            |_t, _x| array![1.0_f64, 1.0_f64],
            |_t, _x| {
                let mut g = Array2::zeros((2, 2));
                g[[0, 0]] = 1.0;
                g[[1, 1]] = 1.0;
                g
            },
        );
        let mut rng = seeded_rng(42);
        let sol = euler_maruyama(&prob, 0.01, &mut rng).expect("euler_maruyama should succeed");
        assert_eq!(sol.x[0].len(), 2);
    }

    #[test]
    fn test_em_save_only_last() {
        let prob = make_gbm_prob(0.05, 0.2, 1.0);
        let mut rng = seeded_rng(0);
        let opts = SdeOptions {
            save_all_steps: false,
            ..Default::default()
        };
        let sol = euler_maruyama_with_options(&prob, 0.01, &mut rng, &opts)
            .expect("euler_maruyama_with_options should succeed");
        assert_eq!(sol.len(), 2);
    }
}