use crate::error::{IntegrateError, IntegrateResult};
use crate::sde::{compute_n_steps, SdeOptions, SdeProblem, SdeSolution};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::{Normal, Rng, StdRng};
use scirs2_core::Distribution;
pub fn srk_strong<F, G>(
prob: &SdeProblem<F, G>,
dt: f64,
rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
F: Fn(f64, &Array1<f64>) -> Array1<f64>,
G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
srk_strong_with_options(prob, dt, rng, &SdeOptions::default())
}
pub fn srk_strong_with_options<F, G>(
prob: &SdeProblem<F, G>,
dt: f64,
rng: &mut StdRng,
opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
F: Fn(f64, &Array1<f64>) -> Array1<f64>,
G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
prob.validate()?;
let t0 = prob.t_span[0];
let t1 = prob.t_span[1];
let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
let n_state = prob.dim();
let m = prob.n_brownian;
let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
let mut sol = SdeSolution::with_capacity(capacity);
sol.push(t0, prob.x0.clone());
let normal = Normal::new(0.0_f64, 1.0_f64)
.map_err(|e| IntegrateError::ComputationError(format!("Normal dist error: {}", e)))?;
let mut x = prob.x0.clone();
let mut t = t0;
for step in 0..n_steps {
let dt_actual = if step == n_steps - 1 {
t1 - t
} else {
dt.min(t1 - t)
};
if dt_actual <= 0.0 {
break;
}
let sqrt_dt = dt_actual.sqrt();
let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_dt);
let f0 = (prob.f_drift)(t, &x);
let g0 = (prob.g_diffusion)(t, &x);
if f0.len() != n_state || g0.nrows() != n_state || g0.ncols() != m {
return Err(IntegrateError::DimensionMismatch(
"Dimension mismatch in drift or diffusion".to_string(),
));
}
let f0_dt = f0.clone() * dt_actual;
let mut srk_correction = Array1::<f64>::zeros(n_state);
for j in 0..m {
let g_col_j = g0.column(j).to_owned();
let h_plus = &x + &f0_dt + &g_col_j * sqrt_dt;
let h_minus = &x + &f0_dt - &g_col_j * sqrt_dt;
let g_h_plus = (prob.g_diffusion)(t, &h_plus);
let g_h_minus = (prob.g_diffusion)(t, &h_minus);
let dw_j = dw[j];
let factor = (dw_j * dw_j * 0.5 - dt_actual * 0.5) / sqrt_dt;
for i in 0..n_state {
srk_correction[i] += (g_h_plus[[i, j]] - g_h_minus[[i, j]]) * 0.5 * factor;
}
}
x = x + f0 * dt_actual + g0.dot(&dw) + srk_correction;
t += dt_actual;
if opts.save_all_steps {
sol.push(t, x.clone());
}
}
if !opts.save_all_steps {
sol.push(t, x);
}
Ok(sol)
}
pub fn platen_scheme<F, G>(
prob: &SdeProblem<F, G>,
dt: f64,
rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
F: Fn(f64, &Array1<f64>) -> Array1<f64>,
G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
platen_scheme_with_options(prob, dt, rng, &SdeOptions::default())
}
pub fn platen_scheme_with_options<F, G>(
prob: &SdeProblem<F, G>,
dt: f64,
rng: &mut StdRng,
opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
F: Fn(f64, &Array1<f64>) -> Array1<f64>,
G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
prob.validate()?;
let t0 = prob.t_span[0];
let t1 = prob.t_span[1];
let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
let n_state = prob.dim();
let m = prob.n_brownian;
let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
let mut sol = SdeSolution::with_capacity(capacity);
sol.push(t0, prob.x0.clone());
let normal = Normal::new(0.0_f64, 1.0_f64)
.map_err(|e| IntegrateError::ComputationError(format!("Normal dist error: {}", e)))?;
let mut x = prob.x0.clone();
let mut t = t0;
for step in 0..n_steps {
let dt_actual = if step == n_steps - 1 {
t1 - t
} else {
dt.min(t1 - t)
};
if dt_actual <= 0.0 {
break;
}
let sqrt_dt = dt_actual.sqrt();
let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_dt);
let f0 = (prob.f_drift)(t, &x);
let g0 = (prob.g_diffusion)(t, &x);
if f0.len() != n_state || g0.nrows() != n_state || g0.ncols() != m {
return Err(IntegrateError::DimensionMismatch(
"Dimension mismatch in drift or diffusion".to_string(),
));
}
let ones_m = Array1::from_elem(m, 1.0_f64);
let x_hat = &x + &(f0.clone() * dt_actual) + &g0.dot(&ones_m) * sqrt_dt;
let f1 = (prob.f_drift)(t + dt_actual, &x_hat);
let g1 = (prob.g_diffusion)(t, &x_hat);
let drift_update = (&f0 + &f1) * (0.5 * dt_actual);
let diff_update = g0.dot(&dw);
let mut platen_corr = Array1::<f64>::zeros(n_state);
for j in 0..m {
let dw_j = dw[j];
let factor = (dw_j * dw_j - dt_actual) / (2.0 * sqrt_dt);
for i in 0..n_state {
platen_corr[i] += (g1[[i, j]] - g0[[i, j]]) * factor;
}
}
x = x + drift_update + diff_update + platen_corr;
t += dt_actual;
if opts.save_all_steps {
sol.push(t, x.clone());
}
}
if !opts.save_all_steps {
sol.push(t, x);
}
Ok(sol)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sde::SdeProblem;
use scirs2_core::ndarray::{array, Array2};
use scirs2_core::random::prelude::seeded_rng;
fn make_ou_prob(theta: f64, mu: f64, sigma: f64, x0: f64, t1: f64) -> SdeProblem<
impl Fn(f64, &Array1<f64>) -> Array1<f64>,
impl Fn(f64, &Array1<f64>) -> Array2<f64>,
> {
SdeProblem::new(
array![x0],
[0.0, t1],
1,
move |_t, x| array![theta * (mu - x[0])],
move |_t, _x| {
let mut g = Array2::zeros((1, 1));
g[[0, 0]] = sigma;
g
},
)
}
fn make_gbm_prob(mu: f64, sigma: f64, s0: f64) -> SdeProblem<
impl Fn(f64, &Array1<f64>) -> Array1<f64>,
impl Fn(f64, &Array1<f64>) -> Array2<f64>,
> {
SdeProblem::new(
array![s0],
[0.0, 1.0],
1,
move |_t, x| array![mu * x[0]],
move |_t, x| {
let mut g = Array2::zeros((1, 1));
g[[0, 0]] = sigma * x[0];
g
},
)
}
#[test]
fn test_srk_strong_solution_length() {
let prob = make_gbm_prob(0.05, 0.2, 1.0);
let mut rng = seeded_rng(0);
let sol = srk_strong(&prob, 0.1, &mut rng).expect("srk_strong should succeed");
assert_eq!(sol.len(), 11);
}
#[test]
fn test_platen_solution_length() {
let prob = make_ou_prob(1.0, 0.0, 0.5, 1.0, 1.0);
let mut rng = seeded_rng(42);
let sol = platen_scheme(&prob, 0.1, &mut rng).expect("platen_scheme should succeed");
assert_eq!(sol.len(), 11);
}
#[test]
fn test_srk_gbm_weak_mean() {
let mu = 0.1_f64;
let sigma = 0.2_f64;
let s0 = 1.0_f64;
let t1 = 1.0_f64;
let dt = 0.01;
let analytic = s0 * (mu * t1).exp();
let n_paths = 300;
let mut sum = 0.0;
for seed in 0..n_paths as u64 {
let prob = make_gbm_prob(mu, sigma, s0);
let mut rng = seeded_rng(seed + 5000);
let sol = srk_strong(&prob, dt, &mut rng).expect("srk_strong should succeed");
sum += sol.x_final().expect("solution has state")[0];
}
let mean = sum / n_paths as f64;
let rel_err = (mean - analytic).abs() / analytic;
assert!(
rel_err < 0.05,
"SRK GBM mean {:.4} vs analytic {:.4}, rel_err {:.4}",
mean,
analytic,
rel_err
);
}
#[test]
fn test_platen_ou_stationary_mean() {
let theta = 2.0_f64;
let mu_ou = 0.5_f64;
let sigma = 0.3_f64;
let dt = 0.005;
let t1 = 5.0;
let n_paths = 200;
let mut sum_final = 0.0;
for seed in 0..n_paths as u64 {
let prob = make_ou_prob(theta, mu_ou, sigma, 0.0, t1);
let mut rng = seeded_rng(seed + 2000);
let sol = platen_scheme(&prob, dt, &mut rng).expect("platen_scheme should succeed");
sum_final += sol.x_final().expect("solution has state")[0];
}
let mean_final = sum_final / n_paths as f64;
assert!(
(mean_final - mu_ou).abs() < 0.1,
"OU stationary mean {:.4} vs expected {:.4}",
mean_final,
mu_ou
);
}
#[test]
fn test_platen_additive_noise_accuracy() {
let theta = 1.0_f64;
let mu_ou = 0.0_f64;
let sigma = 0.5_f64;
let x0 = 2.0_f64;
let t1 = 1.0_f64;
let analytic_mean = x0 * (-theta * t1).exp();
let dt = 0.005;
let n_paths = 400;
let mut sum = 0.0;
for seed in 0..n_paths as u64 {
let prob = make_ou_prob(theta, mu_ou, sigma, x0, t1);
let mut rng = seeded_rng(seed + 3000);
let sol = platen_scheme(&prob, dt, &mut rng).expect("platen_scheme should succeed");
sum += sol.x_final().expect("solution has state")[0];
}
let sample_mean = sum / n_paths as f64;
let abs_err = (sample_mean - analytic_mean).abs();
assert!(
abs_err < 0.1,
"Platen OU mean {:.4} vs analytic {:.4}, abs_err {:.4}",
sample_mean,
analytic_mean,
abs_err
);
}
#[test]
fn test_srk_invalid_dt() {
let prob = make_gbm_prob(0.05, 0.2, 1.0);
let mut rng = seeded_rng(0);
assert!(srk_strong(&prob, 0.0, &mut rng).is_err());
}
}