use crate::error::{IntegrateError, IntegrateResult};
use crate::sde::{compute_n_steps, SdeOptions, SdeProblem, SdeSolution};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::{Normal, StdRng};
use scirs2_core::Distribution;
/// Which stochastic Runge–Kutta routine [`SRKSolver::solve`] dispatches to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SRKVariant {
    /// Dispatches to [`platen15_with_options`].
    Platen15,
    /// Dispatches to [`sri2_with_options`].
    SRI2,
    /// Dispatches to [`sra3_with_options`].
    SRA3,
}
/// Stochastic Runge–Kutta solver: a fixed [`SRKVariant`] together with the
/// [`SdeOptions`] that are forwarded to it on every [`SRKSolver::solve`] call.
pub struct SRKSolver {
    /// Options handed unchanged to the selected scheme.
    opts: SdeOptions,
    /// Scheme selected by `solve`.
    variant: SRKVariant,
}
impl SRKSolver {
    /// Builds a solver for `variant` that uses `SdeOptions::default()`.
    pub fn new(variant: SRKVariant) -> Self {
        Self::with_options(variant, SdeOptions::default())
    }

    /// Builds a solver for `variant` that forwards `opts` to every `solve` call.
    pub fn with_options(variant: SRKVariant, opts: SdeOptions) -> Self {
        Self { opts, variant }
    }

    /// Integrates `prob` with nominal step `dt` using the configured scheme and options.
    ///
    /// # Errors
    /// Whatever the selected `*_with_options` function returns is propagated as-is.
    pub fn solve<F, G>(
        &self,
        prob: &SdeProblem<F, G>,
        dt: f64,
        rng: &mut StdRng,
    ) -> IntegrateResult<SdeSolution>
    where
        F: Fn(f64, &Array1<f64>) -> Array1<f64>,
        G: Fn(f64, &Array1<f64>) -> Array2<f64>,
    {
        let opts = &self.opts;
        match self.variant {
            SRKVariant::Platen15 => platen15_with_options(prob, dt, rng, opts),
            SRKVariant::SRI2 => sri2_with_options(prob, dt, rng, opts),
            SRKVariant::SRA3 => sra3_with_options(prob, dt, rng, opts),
        }
    }
}
/// Convenience wrapper: runs [`platen15_with_options`] with `SdeOptions::default()`.
///
/// # Errors
/// Same as [`platen15_with_options`].
pub fn platen15<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    let defaults = SdeOptions::default();
    platen15_with_options(prob, dt, rng, &defaults)
}
/// Kloeden–Platen explicit (derivative-free) order-1.5 strong scheme for Itô SDEs
/// `dX = a(t, X) dt + b(t, X) dW`, with caller-supplied options.
///
/// For every noise channel `j` (column `b_j` of the diffusion matrix) the step of
/// size `h` evaluates the coefficients at the support points
///
/// ```text
/// Υ±_j = x + a·h/m ± b_j·√h          Φ±_j = Υ+_j ± b_j(Υ+_j)·√h
/// ```
///
/// all at time `t + h/m`. Treating time as an extra state component with drift 1
/// and zero diffusion puts the support values there. These evaluations are combined
/// with the stochastic integrals
/// `ΔW_j`, `ΔZ_j = I_(j,0)`, `I_(0,j) = ΔW_j·h − ΔZ_j`,
/// `I_(j,j) = (ΔW_j² − h)/2` and `I_(j,j,j) = (ΔW_j²/3 − h)·ΔW_j/2`
/// (Kloeden & Platen 1992, §11.2).
///
/// * With one Brownian motion (`m == 1`) this is the complete scheme.
/// * For `m > 1`, mixed double integrals `I_(j1,j2)` (`j1 != j2`) are approximated by
///   `ΔW_j1·ΔW_j2/2`, which is exact in combination for commutative noise. Mixed triple
///   integrals are dropped. The full order is therefore kept for scalar, diagonal, or
///   additive noise. For general multi-channel noise this is a lower-order approximation.
///
/// Every step has length `dt` except the last, which is clipped so the run ends
/// exactly at `t_span[1]`.
///
/// # Errors
/// * Errors from `prob.validate()` and `compute_n_steps` are propagated. A non-positive
///   `dt` is rejected, per the tests in this file.
/// * `IntegrateError::DimensionMismatch` if any drift/diffusion evaluation does not
///   have shape `n_state` / `n_state × m`.
/// * `IntegrateError::ComputationError` if the standard normal cannot be built.
pub fn platen15_with_options<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
    opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    prob.validate()?;
    let t0 = prob.t_span[0];
    let t1 = prob.t_span[1];
    let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
    let n_state = prob.dim();
    let m = prob.n_brownian;
    let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
    let mut sol = SdeSolution::with_capacity(capacity);
    sol.push(t0, prob.x0.clone());
    let normal = Normal::new(0.0_f64, 1.0_f64)
        .map_err(|e| IntegrateError::ComputationError(format!("Normal dist error: {}", e)))?;
    let inv_sqrt3 = 1.0_f64 / 3.0_f64.sqrt();
    // Each support point advances the drift by h/m, so that summed over the m
    // channels the drift (and time) contribution is counted exactly once.
    // `max(1)` only guards the degenerate m == 0 case, where no support point is used.
    let m_f = m.max(1) as f64;
    let mut x = prob.x0.clone();
    let mut t = t0;
    for step in 0..n_steps {
        let h = if step == n_steps - 1 {
            t1 - t
        } else {
            dt.min(t1 - t)
        };
        if h <= 0.0 {
            break;
        }
        let sqrt_h = h.sqrt();
        let inv_2sqrt_h = 0.5 / sqrt_h;
        let inv_2h = 0.5 / h;

        // ΔW_j = U1·√h and ΔZ_j = ½·h^{3/2}·(U1 + U2/√3) with independent U1, U2 ~ N(0, 1).
        let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_h);
        let dv: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_h);
        let dz: Array1<f64> =
            Array1::from_shape_fn(m, |j| 0.5 * h * (dw[j] + dv[j] * inv_sqrt3));

        let a0 = (prob.f_drift)(t, &x);
        let b0 = (prob.g_diffusion)(t, &x);
        validate_dimensions(&a0, &b0, n_state, m)?;

        // First-level support points Υ±_j and the coefficients evaluated there.
        let t_sup = t + h / m_f;
        let base = &x + &(&a0 * (h / m_f));
        let mut ups_plus = Vec::with_capacity(m);
        let mut a_plus = Vec::with_capacity(m);
        let mut a_minus = Vec::with_capacity(m);
        let mut b_plus = Vec::with_capacity(m);
        let mut b_minus = Vec::with_capacity(m);
        for j in 0..m {
            let shift = b0.column(j).mapv(|v| v * sqrt_h);
            let up = &base + &shift;
            let um = &base - &shift;
            let ap = (prob.f_drift)(t_sup, &up);
            let bp = (prob.g_diffusion)(t_sup, &up);
            validate_dimensions(&ap, &bp, n_state, m)?;
            let am = (prob.f_drift)(t_sup, &um);
            let bm = (prob.g_diffusion)(t_sup, &um);
            validate_dimensions(&am, &bm, n_state, m)?;
            ups_plus.push(up);
            a_plus.push(ap);
            a_minus.push(am);
            b_plus.push(bp);
            b_minus.push(bm);
        }

        // Euler–Maruyama part; all higher-order corrections are added below.
        let mut dx = &a0 * h + b0.dot(&dw);
        for j1 in 0..m {
            let (ap, am) = (&a_plus[j1], &a_minus[j1]);
            let (bp, bm) = (&b_plus[j1], &b_minus[j1]);
            let dw1 = dw[j1];

            // Second-level support Φ±_j1 for the L^j L^j b term.
            let shift = bp.column(j1).mapv(|v| v * sqrt_h);
            let phi_p = &ups_plus[j1] + &shift;
            let phi_m = &ups_plus[j1] - &shift;
            let b_phi_p = (prob.g_diffusion)(t_sup, &phi_p);
            let b_phi_m = (prob.g_diffusion)(t_sup, &phi_m);
            // `a0` was already validated; these calls only check the new diffusion matrices.
            validate_dimensions(&a0, &b_phi_p, n_state, m)?;
            validate_dimensions(&a0, &b_phi_m, n_state, m)?;

            let i_jjj = 0.5 * (dw1 * dw1 / 3.0 - h) * dw1;
            for k in 0..n_state {
                // L^j a·I_(j,0), the (L^0 a)·h²/2 part, and L^j L^j b·I_(j,j,j).
                dx[k] += (ap[k] - am[k]) * dz[j1] * inv_2sqrt_h
                    + 0.25 * (ap[k] - 2.0 * a0[k] + am[k]) * h
                    + (b_phi_p[[k, j1]] - b_phi_m[[k, j1]] - bp[[k, j1]] + bm[[k, j1]])
                        * i_jjj
                        * inv_2h;
            }

            for j2 in 0..m {
                let i_j1j2 = if j1 == j2 {
                    0.5 * (dw1 * dw1 - h)
                } else {
                    // Commutative-noise approximation of the mixed double integral.
                    0.5 * dw1 * dw[j2]
                };
                let i_0j2 = dw[j2] * h - dz[j2];
                for k in 0..n_state {
                    // L^j1 b^j2·I_(j1,j2) and (this channel's share of) L^0 b^j2·I_(0,j2).
                    dx[k] += (bp[[k, j2]] - bm[[k, j2]]) * i_j1j2 * inv_2sqrt_h
                        + (bp[[k, j2]] - 2.0 * b0[[k, j2]] + bm[[k, j2]]) * i_0j2 * inv_2h;
                }
            }
        }

        x = x + dx;
        t += h;
        if opts.save_all_steps {
            sol.push(t, x.clone());
        }
    }
    if !opts.save_all_steps {
        sol.push(t, x);
    }
    Ok(sol)
}
/// Convenience wrapper: runs [`sri2_with_options`] with `SdeOptions::default()`.
///
/// # Errors
/// Same as [`sri2_with_options`].
pub fn sri2<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    let defaults = SdeOptions::default();
    sri2_with_options(prob, dt, rng, &defaults)
}
/// Two-stage stochastic Runge–Kutta scheme for Itô SDEs: Heun (trapezoidal) drift
/// plus Platen's derivative-free Milstein correction, with caller-supplied options.
///
/// One step of size `h`:
///
/// ```text
/// x̄  = x + a(t,x)·h + b(t,x)·ΔW              (Euler–Maruyama predictor)
/// Υ  = x + a(t,x)·h + b(t,x)·(√h, …, √h)     (Milstein support point)
/// x' = x + ½·(a(t,x) + a(t+h,x̄))·h + b(t,x)·ΔW
///        + Σ_j (b_j(t,Υ) − b_j(t,x))·(ΔW_j² − h)/(2√h)
/// ```
///
/// The diffusion multiplying `ΔW` is taken at the left end point only, as the Itô
/// interpretation requires. Averaging it with the predictor would add a spurious
/// `½·b·b'·h` drift, i.e. it would approach the Stratonovich solution. The predictor
/// takes a full drift step, so the drift part is second order in the deterministic
/// limit. The Milstein term is exact to leading order for scalar and diagonal noise.
///
/// Every step has length `dt` except the last, which is clipped so the run ends
/// exactly at `t_span[1]`.
///
/// # Errors
/// * Errors from `prob.validate()` and `compute_n_steps` are propagated. A non-positive
///   `dt` is rejected, per the tests in this file.
/// * `IntegrateError::DimensionMismatch` if a drift/diffusion evaluation has the wrong shape.
/// * `IntegrateError::ComputationError` if the standard normal cannot be built.
pub fn sri2_with_options<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
    opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    prob.validate()?;
    let t0 = prob.t_span[0];
    let t1 = prob.t_span[1];
    let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
    let n_state = prob.dim();
    let m = prob.n_brownian;
    let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
    let mut sol = SdeSolution::with_capacity(capacity);
    sol.push(t0, prob.x0.clone());
    let normal = Normal::new(0.0_f64, 1.0_f64)
        .map_err(|e| IntegrateError::ComputationError(format!("Normal dist error: {}", e)))?;
    let mut x = prob.x0.clone();
    let mut t = t0;
    for step in 0..n_steps {
        let h = if step == n_steps - 1 {
            t1 - t
        } else {
            dt.min(t1 - t)
        };
        if h <= 0.0 {
            break;
        }
        let sqrt_h = h.sqrt();
        let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_h);

        let f1 = (prob.f_drift)(t, &x);
        let g1 = (prob.g_diffusion)(t, &x);
        validate_dimensions(&f1, &g1, n_state, m)?;
        let diffusion = g1.dot(&dw);

        // Full Euler–Maruyama predictor for the trapezoidal drift average.
        let predictor = &x + &(&f1 * h) + &diffusion;
        let f2 = (prob.f_drift)(t + h, &predictor);

        // Support point Υ = x + a·h + b·√h used to approximate b·b' without derivatives.
        let support_shift: Array1<f64> = Array1::from_elem(m, sqrt_h);
        let support = &x + &(&f1 * h) + &g1.dot(&support_shift);
        let g3 = (prob.g_diffusion)(t, &support);
        validate_dimensions(&f2, &g3, n_state, m)?;

        let drift = (&f1 + &f2) * (0.5 * h);
        let mut milstein = Array1::<f64>::zeros(n_state);
        for j in 0..m {
            let dw_j = dw[j];
            let iterated = (dw_j * dw_j - h) / (2.0 * sqrt_h);
            for i in 0..n_state {
                milstein[i] += (g3[[i, j]] - g1[[i, j]]) * iterated;
            }
        }
        x = x + drift + diffusion + milstein;
        t += h;
        if opts.save_all_steps {
            sol.push(t, x.clone());
        }
    }
    if !opts.save_all_steps {
        sol.push(t, x);
    }
    Ok(sol)
}
/// Convenience wrapper: runs [`sra3_with_options`] with `SdeOptions::default()`.
///
/// # Errors
/// Same as [`sra3_with_options`].
pub fn sra3<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    let defaults = SdeOptions::default();
    sra3_with_options(prob, dt, rng, &defaults)
}
/// Three-stage stochastic Runge–Kutta scheme of Rößler's SRA class for Itô SDEs with
/// **additive** noise (diffusion independent of the state), with caller-supplied options.
///
/// Written as an SRA tableau (stage `i`, drift `a`, diffusion `b`):
///
/// ```text
/// H_1 = x
/// H_2 = x + a(t, H_1)·h
/// H_3 = x + ¼·(a(t, H_1) + a(t+h, H_2))·h + (b(t+h) + ½·b(t))·I_(·,0)/h
/// x'  = x + (⅙·a(t,H_1) + ⅙·a(t+h,H_2) + ⅔·a(t+h/2,H_3))·h
///         + b(t+h)·ΔW + (b(t) − b(t+h))·I_(·,0)/h
/// ```
///
/// The drift part is the third-order SSP Runge–Kutta method: c0 = (0, 1, ½),
/// α = (⅙, ⅙, ⅔). The noise part uses c1 = (1, 0, 0), β1 = (1, 0, 0), β2 = (−1, 1, 0)
/// and B0 row sums (0, 0, 3/2). These satisfy the SRA strong-order-1.5 conditions
/// `Σα = 1, αᵀA0e = ½, αᵀB0e = 1, αᵀ(B0e)² = 3/2, Σβ1 = 1, Σβ2 = 0, β1ᵀc1 = 1,
/// β2ᵀc1 = −1`. The `B0` stage carries the `b·a'·I_(·,0)` coupling, and the β2 terms
/// give `∂_t b·I_(0,·)`.
///
/// The diffusion is evaluated at the current state (at times `t` and `t+h`) only.
/// If it actually depends on the state, the result is a lower-order approximation.
///
/// # Errors
/// * Errors from `prob.validate()` and `compute_n_steps` are propagated. A non-positive
///   `dt` is rejected, per the tests in this file.
/// * `IntegrateError::DimensionMismatch` if a drift/diffusion evaluation has the wrong shape.
/// * `IntegrateError::ComputationError` if the standard normal cannot be built.
pub fn sra3_with_options<F, G>(
    prob: &SdeProblem<F, G>,
    dt: f64,
    rng: &mut StdRng,
    opts: &SdeOptions,
) -> IntegrateResult<SdeSolution>
where
    F: Fn(f64, &Array1<f64>) -> Array1<f64>,
    G: Fn(f64, &Array1<f64>) -> Array2<f64>,
{
    prob.validate()?;
    let t0 = prob.t_span[0];
    let t1 = prob.t_span[1];
    let n_steps = compute_n_steps(t0, t1, dt, opts.max_steps)?;
    let n_state = prob.dim();
    let m = prob.n_brownian;
    let capacity = if opts.save_all_steps { n_steps + 1 } else { 2 };
    let mut sol = SdeSolution::with_capacity(capacity);
    sol.push(t0, prob.x0.clone());
    let normal = Normal::new(0.0_f64, 1.0_f64)
        .map_err(|e| IntegrateError::ComputationError(format!("Normal dist error: {}", e)))?;

    // Drift tableau (third-order SSP Runge–Kutta).
    let c2 = 1.0_f64;
    let c3 = 0.5_f64;
    let a21 = 1.0_f64;
    let a31 = 0.25_f64;
    let a32 = 0.25_f64;
    let alpha1 = 1.0_f64 / 6.0;
    let alpha2 = 1.0_f64 / 6.0;
    let alpha3 = 2.0_f64 / 3.0;
    // Noise coupling into stage 3: B0_31 multiplies b(t+h), B0_32 multiplies b(t).
    let b31 = 1.0_f64;
    let b32 = 0.5_f64;
    let inv_sqrt3 = 1.0_f64 / 3.0_f64.sqrt();

    let mut x = prob.x0.clone();
    let mut t = t0;
    for step in 0..n_steps {
        let h = if step == n_steps - 1 {
            t1 - t
        } else {
            dt.min(t1 - t)
        };
        if h <= 0.0 {
            break;
        }
        let sqrt_h = h.sqrt();
        let dw: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_h);
        let dv: Array1<f64> = Array1::from_shape_fn(m, |_| normal.sample(rng) * sqrt_h);
        // I_(j,0)/h = ½·(ΔW_j + ΔV_j/√3), with ΔV_j independent of ΔW_j.
        let i10_over_h: Array1<f64> =
            Array1::from_shape_fn(m, |j| 0.5 * (dw[j] + dv[j] * inv_sqrt3));

        let f1 = (prob.f_drift)(t, &x);
        let g_now = (prob.g_diffusion)(t, &x);
        validate_dimensions(&f1, &g_now, n_state, m)?;
        // Additive noise is assumed, so the state argument only satisfies the signature.
        let g_next = (prob.g_diffusion)(t + h, &x);
        validate_dimensions(&f1, &g_next, n_state, m)?;

        let z_now = g_now.dot(&i10_over_h);
        let z_next = g_next.dot(&i10_over_h);

        let stage2 = &x + &(&f1 * (a21 * h));
        let f2 = (prob.f_drift)(t + c2 * h, &stage2);
        validate_dimensions(&f2, &g_now, n_state, m)?;

        let stage3 = &x
            + &(&f1 * (a31 * h))
            + &(&f2 * (a32 * h))
            + &(&z_next * b31)
            + &(&z_now * b32);
        let f3 = (prob.f_drift)(t + c3 * h, &stage3);
        validate_dimensions(&f3, &g_now, n_state, m)?;

        let drift = (&f1 * alpha1 + &f2 * alpha2 + &f3 * alpha3) * h;
        // β1 = (1, 0, 0) at c1 = 1 and β2 = (−1, 1, 0):  b(t+h)·ΔW + (b(t) − b(t+h))·I10/h.
        let noise = g_next.dot(&dw) + &z_now - &z_next;

        x = x + drift + noise;
        t += h;
        if opts.save_all_steps {
            sol.push(t, x.clone());
        }
    }
    if !opts.save_all_steps {
        sol.push(t, x);
    }
    Ok(sol)
}
/// Checks that a drift vector and a diffusion matrix have the shapes required for an
/// `n_state`-dimensional state driven by `m` Brownian motions.
///
/// # Errors
/// `IntegrateError::DimensionMismatch` if `f.len() != n_state`, or if `g` is not
/// `n_state × m`. The drift is checked first.
fn validate_dimensions(
    f: &Array1<f64>,
    g: &Array2<f64>,
    n_state: usize,
    m: usize,
) -> IntegrateResult<()> {
    let drift_len = f.len();
    if drift_len != n_state {
        return Err(IntegrateError::DimensionMismatch(format!(
            "Drift output dimension {} != state dimension {}",
            drift_len, n_state
        )));
    }
    let (rows, cols) = g.dim();
    if (rows, cols) == (n_state, m) {
        Ok(())
    } else {
        Err(IntegrateError::DimensionMismatch(format!(
            "Diffusion matrix shape ({},{}) != expected ({},{})",
            rows, cols, n_state, m
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sde::SdeProblem;
    use scirs2_core::ndarray::{array, Array2};
    use scirs2_core::random::prelude::seeded_rng;

    /// Scalar geometric Brownian motion `dS = mu·S dt + sigma·S dW` on `[0, t1]`, `S(0) = s0`.
    fn make_gbm(
        mu: f64,
        sigma: f64,
        s0: f64,
        t1: f64,
    ) -> SdeProblem<
        impl Fn(f64, &Array1<f64>) -> Array1<f64>,
        impl Fn(f64, &Array1<f64>) -> Array2<f64>,
    > {
        SdeProblem::new(
            array![s0],
            [0.0, t1],
            1,
            move |_t, x| array![mu * x[0]],
            move |_t, x| {
                let mut g = Array2::zeros((1, 1));
                g[[0, 0]] = sigma * x[0];
                g
            },
        )
    }

    /// Scalar Ornstein–Uhlenbeck process `dX = theta·(mu_ou − X) dt + sigma dW` (additive noise).
    fn make_ou_additive(
        theta: f64,
        mu_ou: f64,
        sigma: f64,
        x0: f64,
        t1: f64,
    ) -> SdeProblem<
        impl Fn(f64, &Array1<f64>) -> Array1<f64>,
        impl Fn(f64, &Array1<f64>) -> Array2<f64>,
    > {
        SdeProblem::new(
            array![x0],
            [0.0, t1],
            1,
            move |_t, x| array![theta * (mu_ou - x[0])],
            move |_t, _x| {
                let mut g = Array2::zeros((1, 1));
                g[[0, 0]] = sigma;
                g
            },
        )
    }

    /// Monte-Carlo mean of the first state component at the final time.
    /// Path `k` (for `k` in `0..n_paths`) is driven by `seeded_rng(seed_base + k)`.
    fn final_mean<S>(n_paths: u64, seed_base: u64, mut run: S) -> f64
    where
        S: FnMut(&mut StdRng) -> SdeSolution,
    {
        let mut total = 0.0;
        for k in 0..n_paths {
            let mut rng = seeded_rng(seed_base + k);
            let sol = run(&mut rng);
            total += sol.x_final().expect("solution has state")[0];
        }
        total / n_paths as f64
    }

    #[test]
    fn test_platen15_solution_length() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let mut rng = seeded_rng(0);
        let sol = platen15(&prob, 0.1, &mut rng).expect("platen15 should succeed");
        assert_eq!(sol.len(), 11);
        assert!((sol.t[0] - 0.0).abs() < 1e-12);
        let t_end = sol.t_final().expect("solution has time steps");
        assert!((t_end - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_platen15_gbm_positive() {
        let prob = make_gbm(0.05, 0.2, 100.0, 1.0);
        let mut rng = seeded_rng(42);
        let sol = platen15(&prob, 0.01, &mut rng).expect("platen15 should succeed");
        for state in &sol.x {
            assert!(
                state[0] > 0.0,
                "Platen15 GBM should stay positive, got {}",
                state[0]
            );
        }
    }

    #[test]
    fn test_platen15_gbm_weak_mean() {
        let (mu, sigma, s0, t1) = (0.1_f64, 0.2_f64, 1.0_f64, 1.0_f64);
        let analytic = s0 * (mu * t1).exp();
        let dt = 0.01;
        let prob = make_gbm(mu, sigma, s0, t1);
        let mean = final_mean(400, 9000, |rng| {
            platen15(&prob, dt, rng).expect("platen15 should succeed")
        });
        let rel_err = (mean - analytic).abs() / analytic;
        assert!(
            rel_err < 0.05,
            "Platen15 GBM mean {:.4} vs analytic {:.4}, rel_err {:.4}",
            mean,
            analytic,
            rel_err
        );
    }

    #[test]
    fn test_sri2_solution_length() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let mut rng = seeded_rng(1);
        let sol = sri2(&prob, 0.1, &mut rng).expect("sri2 should succeed");
        assert_eq!(sol.len(), 11);
    }

    #[test]
    fn test_sri2_gbm_weak_mean() {
        let (mu, sigma, s0, t1) = (0.05_f64, 0.2_f64, 1.0_f64, 1.0_f64);
        let analytic = s0 * (mu * t1).exp();
        let dt = 0.01;
        let prob = make_gbm(mu, sigma, s0, t1);
        let mean = final_mean(400, 11000, |rng| {
            sri2(&prob, dt, rng).expect("sri2 should succeed")
        });
        let rel_err = (mean - analytic).abs() / analytic;
        assert!(
            rel_err < 0.05,
            "SRI2 GBM mean {:.4} vs analytic {:.4}, rel_err {:.4}",
            mean,
            analytic,
            rel_err
        );
    }

    #[test]
    fn test_sra3_solution_length() {
        let prob = make_ou_additive(1.0, 0.0, 0.5, 1.0, 1.0);
        let mut rng = seeded_rng(2);
        let sol = sra3(&prob, 0.1, &mut rng).expect("sra3 should succeed");
        assert_eq!(sol.len(), 11);
    }

    #[test]
    fn test_sra3_ou_weak_mean() {
        let (theta, mu_ou, sigma, x0, t1) = (1.0_f64, 0.0_f64, 0.3_f64, 2.0_f64, 1.0_f64);
        let analytic = x0 * (-theta * t1).exp();
        let dt = 0.01;
        let prob = make_ou_additive(theta, mu_ou, sigma, x0, t1);
        let mean = final_mean(400, 7000, |rng| {
            sra3(&prob, dt, rng).expect("sra3 should succeed")
        });
        let abs_err = (mean - analytic).abs();
        assert!(
            abs_err < 0.15,
            "SRA3 OU mean {:.4} vs analytic {:.4}, abs_err {:.4}",
            mean,
            analytic,
            abs_err
        );
    }

    #[test]
    fn test_srk_solver_variants() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let variants = [SRKVariant::Platen15, SRKVariant::SRI2, SRKVariant::SRA3];
        for variant in variants {
            let mut rng = seeded_rng(0);
            let sol = SRKSolver::new(variant)
                .solve(&prob, 0.1, &mut rng)
                .expect("solver.solve should succeed");
            assert_eq!(
                sol.len(),
                11,
                "Variant {:?} should produce 11 steps",
                variant
            );
        }
    }

    #[test]
    fn test_invalid_dt_platen15() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let mut rng = seeded_rng(0);
        assert!(platen15(&prob, 0.0, &mut rng).is_err());
        assert!(platen15(&prob, -0.1, &mut rng).is_err());
    }

    #[test]
    fn test_invalid_dt_sri2() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let mut rng = seeded_rng(0);
        assert!(sri2(&prob, 0.0, &mut rng).is_err());
    }

    #[test]
    fn test_invalid_dt_sra3() {
        let prob = make_ou_additive(1.0, 0.0, 0.5, 1.0, 1.0);
        let mut rng = seeded_rng(0);
        assert!(sra3(&prob, -0.1, &mut rng).is_err());
    }

    #[test]
    fn test_save_only_last() {
        let prob = make_gbm(0.05, 0.2, 1.0, 1.0);
        let opts = SdeOptions {
            save_all_steps: false,
            ..Default::default()
        };
        let mut rng = seeded_rng(0);
        let sol = platen15_with_options(&prob, 0.01, &mut rng, &opts)
            .expect("platen15_with_options should succeed");
        assert_eq!(sol.len(), 2, "Should have only initial + final states");
    }

    #[test]
    fn test_multivariate_platen15() {
        let prob = SdeProblem::new(
            array![1.0_f64, 2.0_f64],
            [0.0, 1.0],
            2,
            |_t, x| array![0.05 * x[0], 0.03 * x[1]],
            |_t, x| {
                let mut g = Array2::zeros((2, 2));
                g[[0, 0]] = 0.2 * x[0];
                g[[1, 1]] = 0.15 * x[1];
                g
            },
        );
        let mut rng = seeded_rng(42);
        let sol = platen15(&prob, 0.01, &mut rng).expect("platen15 should succeed");
        assert_eq!(sol.x[0].len(), 2);
        for state in &sol.x {
            assert!(state[0] > 0.0);
            assert!(state[1] > 0.0);
        }
    }
}