use crate::models::forex::dupire_local_vol::DupireLocalVol;
use crate::models::forex::sabr::SabrState;
use crate::models::forex::sabr_time_dependent::TimeDependentSabrParams;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use rand_distr::StandardNormal;
/// Monte-Carlo simulator that evolves a population of [`SabrState`] paths using
/// time-dependent SABR parameters, with the forward diffusion rescaled so that
/// it is tied to a Dupire local-variance surface.
///
/// The rescaling factor is recomputed at every time step from the current
/// population: paths are sorted by forward, grouped into bins, and the mean of
/// `vol²` inside each bin is used as the conditional expectation for every path
/// that falls into that bin.
pub struct TimeDependentSabrSlvSimulator {
/// Time-dependent SABR parameters. Supplies the initial forward (`forward_0`),
/// the initial vol (`alpha.at(0.0)`) and the per-step `rho`, `nu` and `beta`.
pub params: TimeDependentSabrParams,
/// Local-variance surface, queried as `local_variance(t, forward)` for every
/// path at every step.
pub dupire: DupireLocalVol,
/// Number of forward-sorted bins requested when estimating the per-bin mean of
/// `vol²`. `new` sets this to 50; `with_bins` overrides it (and requires `>= 2`).
pub n_bins: usize,
/// ChaCha20 generator seeded from the `seed` passed to `new`; source of the
/// standard-normal draws in `simulate`.
rng: ChaCha20Rng,
}
impl TimeDependentSabrSlvSimulator {
    /// Creates a simulator with the default of 50 bins and a ChaCha20 generator
    /// seeded from `seed`, so two simulators built with the same inputs produce
    /// the same sequence of draws.
    pub fn new(params: TimeDependentSabrParams, dupire: DupireLocalVol, seed: u64) -> Self {
        let rng = ChaCha20Rng::seed_from_u64(seed);
        Self {
            params,
            dupire,
            n_bins: 50,
            rng,
        }
    }

    /// Builder-style override of the number of bins used for the per-bin
    /// `vol²` averages.
    ///
    /// # Panics
    /// Panics if `n_bins < 2`.
    pub fn with_bins(self, n_bins: usize) -> Self {
        assert!(n_bins >= 2);
        Self { n_bins, ..self }
    }

    /// Simulates `n_paths` paths from time 0 to `t_end` in `n_steps` equal steps
    /// and returns the terminal state of every path.
    ///
    /// Every path starts at `params.forward_0` with vol `params.alpha.at(0.0)`.
    /// For each step the SABR parameters and the local variance are evaluated at
    /// the step midpoint, and the population is re-binned by forward *before*
    /// any path is advanced. Each path then
    ///
    /// * draws two independent standard normals and correlates the vol shock
    ///   with the forward shock through `rho`;
    /// * scales its forward diffusion by
    ///   `sqrt(local_variance / (E[vol² | bin] · F^{2(β−1)}))`, with the
    ///   denominator floored at `1e-12`;
    /// * moves the forward by an Euler step floored at zero, and the vol by a
    ///   log-normal step with drift `-½·ν²·dt`.
    ///
    /// # Panics
    /// Panics unless `n_steps > 0`, `n_paths > 0` and `t_end > 0.0`.
    pub fn simulate(&mut self, t_end: f64, n_steps: usize, n_paths: usize) -> Vec<SabrState> {
        assert!(n_steps > 0 && n_paths > 0 && t_end > 0.0);

        let dt = t_end / n_steps as f64;
        let root_dt = dt.sqrt();

        let initial_vol = self.params.alpha.at(0.0);
        let initial_forward = self.params.forward_0;
        let mut particles: Vec<SabrState> = std::iter::repeat_with(|| SabrState {
            forward: initial_forward,
            vol: initial_vol,
        })
        .take(n_paths)
        .collect();

        let mut t_start = 0.0_f64;
        for _ in 0..n_steps {
            // Parameters are frozen at the midpoint of the step.
            let t_mid = t_start + 0.5 * dt;
            let step_params = self.params.at(t_mid);
            let rho_complement = (1.0 - step_params.rho * step_params.rho).sqrt();
            // Deterministic part of the log-normal vol update; identical for all paths.
            let log_vol_drift = (-0.5 * step_params.nu * step_params.nu) * dt;

            // Bin statistics come from the population as it stands at the start
            // of the step, i.e. before any path below is modified.
            let (bin_upper, bin_mean) = bin_vol_sq_by_forward(&particles, self.n_bins);

            for particle in &mut particles {
                // Tuple elements are evaluated left to right, so the first draw
                // feeds the forward and the second completes the vol shock.
                let (z1, z2): (f64, f64) = (
                    self.rng.sample(StandardNormal),
                    self.rng.sample(StandardNormal),
                );
                let dw_forward = root_dt * z1;
                let dw_vol = root_dt * (step_params.rho * z1 + rho_complement * z2);

                // Keep the forward strictly positive for the power terms and the
                // vol non-negative.
                let fwd = particle.forward.max(1.0e-12);
                let vol = particle.vol.max(0.0);

                let local_var = self.dupire.local_variance(t_mid, fwd);
                let cond_vol_sq = lookup_bin(fwd, &bin_upper, &bin_mean);

                // F^{2(β-1)}, computed as the square of F^{β-1}.
                let fwd_pow = fwd.powf(step_params.beta - 1.0);
                let backbone_sq = fwd_pow * fwd_pow;
                let floored_denom = (cond_vol_sq * backbone_sq).max(1.0e-12);
                let leverage = (local_var / floored_denom).max(0.0).sqrt();

                let shock = leverage * vol * fwd.powf(step_params.beta) * dw_forward;
                particle.forward = (fwd + shock).max(0.0);
                particle.vol = vol * (log_vol_drift + step_params.nu * dw_vol).exp();
            }

            t_start += dt;
        }

        particles
    }
}
/// Groups `states` into forward-sorted bins and averages `vol²` inside each bin.
///
/// The states are sorted by forward (stable sort, so equal forwards keep their
/// input order) and cut into consecutive chunks of `ceil(len / n_bins)`
/// elements; the last chunk may be shorter, and fewer than `n_bins` bins are
/// produced when the chunk size does not divide the population evenly.
///
/// Returns `(upper, means)`, two vectors of equal length with one entry per bin:
/// `upper[k]` is the largest forward in bin `k` and `means[k]` is the arithmetic
/// mean of `vol * vol` over that bin. Both vectors are empty when `states` is
/// empty.
///
/// Forwards are ordered with [`f64::total_cmp`], which is a total order even in
/// the presence of NaN (a positive NaN sorts after every finite value, and
/// `-0.0` sorts before `+0.0`), so the sort is well defined for any input.
///
/// # Panics
/// Panics if `n_bins` is zero.
fn bin_vol_sq_by_forward(states: &[SabrState], n_bins: usize) -> (Vec<f64>, Vec<f64>) {
    assert!(
        n_bins > 0,
        "bin_vol_sq_by_forward requires at least one bin"
    );
    if states.is_empty() {
        // Also keeps `chunks` below away from a zero chunk size.
        return (Vec::new(), Vec::new());
    }

    // (forward, vol²) pairs, ordered by forward.
    let mut by_forward: Vec<(f64, f64)> = states
        .iter()
        .map(|s| (s.forward, s.vol * s.vol))
        .collect();
    by_forward.sort_by(|a, b| a.0.total_cmp(&b.0));

    let bin_size = by_forward.len().div_ceil(n_bins);
    let bin_count = by_forward.len().div_ceil(bin_size);
    let mut upper: Vec<f64> = Vec::with_capacity(bin_count);
    let mut means: Vec<f64> = Vec::with_capacity(bin_count);

    for bin in by_forward.chunks(bin_size) {
        let sum = bin.iter().fold(0.0_f64, |acc, &(_, vol_sq)| acc + vol_sq);
        means.push(sum / bin.len() as f64);
        // `chunks` never yields an empty slice, so the last element exists.
        upper.push(bin[bin.len() - 1].0);
    }

    (upper, means)
}
/// Returns the mean stored for the first bin whose upper bound is at least `f`.
///
/// `bin_upper[k]` and `bin_mean[k]` describe the same bin. When `f` lies above
/// every upper bound (or compares false against all of them, as NaN does), the
/// mean of the last bin is returned instead.
///
/// # Panics
/// Panics if `bin_mean` is empty and no bound covers `f`, or if the matching
/// index is out of range for `bin_mean`.
fn lookup_bin(f: f64, bin_upper: &[f64], bin_mean: &[f64]) -> f64 {
    match bin_upper.iter().position(|&upper| f <= upper) {
        Some(idx) => bin_mean[idx],
        None => *bin_mean.last().unwrap(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::forex::dupire_local_vol::build as dupire_build;
    use crate::models::forex::sabr::hagan_implied_vol;
    use crate::models::forex::sabr_effective::PiecewiseConstant;

    /// Time-dependent SABR parameters that are constant on `[0, t_end]`, with
    /// `beta` fixed at 0.5.
    fn toy_params(f0: f64, t_end: f64, alpha: f64, rho: f64, nu: f64) -> TimeDependentSabrParams {
        TimeDependentSabrParams::new(
            PiecewiseConstant::constant(t_end, alpha),
            PiecewiseConstant::constant(t_end, rho),
            PiecewiseConstant::constant(t_end, nu),
            0.5,
            f0,
        )
    }

    /// Dupire surface built from a flat implied-vol grid at level `sigma`, with
    /// strikes spanning 70 %–130 % of `f0`.
    fn flat_dupire(sigma: f64, f0: f64) -> DupireLocalVol {
        let expiries = vec![0.25_f64, 0.5, 1.0, 1.5, 2.0];
        let strikes: Vec<f64> = (0..7).map(|i| f0 * (0.7 + 0.1 * i as f64)).collect();
        let vols = vec![vec![sigma; strikes.len()]; expiries.len()];
        dupire_build(&expiries, &strikes, &vols, f0, 0.0, 0.0)
    }

    /// Sample mean of the call payoff `max(F - strike, 0)` over `forwards`.
    ///
    /// The divisor is the number of samples actually seen, so it cannot drift
    /// away from the path count used in the simulation.
    fn mean_call_payoff(forwards: impl Iterator<Item = f64>, strike: f64) -> f64 {
        let (sum, count) = forwards.fold((0.0_f64, 0_usize), |(sum, count), f| {
            (sum + (f - strike).max(0.0), count + 1)
        });
        sum / count as f64
    }

    /// Inverts an at-the-money (strike = `f0`) call price into an implied vol.
    ///
    /// NOTE(review): the two `1.0` arguments are carried over verbatim from the
    /// original call sites (presumably the 1-year expiry and a unit discount
    /// factor) — confirm against `bs_implied_vol`'s signature.
    fn atm_implied_vol_1y(price: f64, f0: f64, what: &str) -> f64 {
        crate::models::common::black_scholes::bs_implied_vol(price, f0, f0, 1.0, 1.0, true)
            .expect(what)
    }

    #[test]
    fn flat_market_preserves_forward_martingale() {
        let f0 = 1.30_f64;
        let params = toy_params(f0, 1.0, 0.20, 0.0, 0.0);
        let dupire = flat_dupire(0.20, f0);
        let mut sim = TimeDependentSabrSlvSimulator::new(params, dupire, 42).with_bins(25);
        let terms = sim.simulate(1.0, 200, 5_000);
        let mean: f64 = terms.iter().map(|s| s.forward).sum::<f64>() / terms.len() as f64;
        let rel = (mean - f0).abs() / f0;
        assert!(rel < 0.02, "E[F] = {}, F₀ = {}, rel {:.4}", mean, f0, rel);
    }

    #[test]
    fn binning_produces_monotone_boundaries_and_correct_means() {
        // 100 strictly increasing forwards; vol is constant within each block of
        // ten consecutive states (0.1, 0.2, …, 1.0).
        let states: Vec<SabrState> = (0..100)
            .map(|i| SabrState {
                forward: 1.0 + 0.01 * i as f64,
                vol: 0.1 * (1 + i / 10) as f64,
            })
            .collect();
        let (upper, means) = bin_vol_sq_by_forward(&states, 10);

        // 100 states over 10 bins divides evenly: exactly ten bins of ten.
        assert_eq!(upper.len(), 10);
        assert_eq!(means.len(), 10);
        for w in upper.windows(2) {
            assert!(w[1] > w[0]);
        }
        for (k, (&u, &m)) in upper.iter().zip(&means).enumerate() {
            // Bin k holds states 10k..=10k+9, all sharing vol 0.1·(k+1).
            let expected_upper = 1.0 + 0.01 * (10 * k + 9) as f64;
            let vol = 0.1 * (1 + k) as f64;
            assert!(
                (u - expected_upper).abs() < 1.0e-12,
                "bin {} upper bound {} != {}",
                k,
                u,
                expected_upper
            );
            assert!(
                (m - vol * vol).abs() < 1.0e-12,
                "bin {} mean {} != {}",
                k,
                m,
                vol * vol
            );
        }
        assert!(means[0] < *means.last().unwrap());
    }

    #[test]
    fn lookup_bin_uses_first_covering_bin_and_falls_back_to_last() {
        let upper = [1.0, 2.0, 3.0];
        let mean = [10.0, 20.0, 30.0];
        assert_eq!(lookup_bin(0.5, &upper, &mean), 10.0);
        assert_eq!(lookup_bin(2.0, &upper, &mean), 20.0);
        assert_eq!(lookup_bin(2.5, &upper, &mean), 30.0);
        // Above every bound: the last bin is used.
        assert_eq!(lookup_bin(10.0, &upper, &mean), 30.0);
    }

    #[test]
    fn slv_compensator_reduces_atm_calibration_residual() {
        use crate::models::forex::sabr::SabrParams;
        use crate::models::forex::sabr_time_dependent::TimeDependentSabrSimulator;

        const N_STEPS: usize = 200;
        const N_PATHS: usize = 10_000;
        const SEED: u64 = 777;

        // "Market": Hagan implied vols generated from a known SABR parameter set.
        let expiries = vec![0.25_f64, 0.5, 1.0, 1.5, 2.0];
        let f0 = 1.30_f64;
        let strikes: Vec<f64> = (0..9).map(|i| f0 * (0.70 + 0.075 * i as f64)).collect();
        let truth = SabrParams::new(0.15, 0.5, -0.30, 0.50);
        let market_vols: Vec<Vec<f64>> = expiries
            .iter()
            .map(|&t| {
                strikes
                    .iter()
                    .map(|&k| hagan_implied_vol(&truth, f0, k, t))
                    .collect()
            })
            .collect();
        let dupire = dupire_build(&expiries, &strikes, &market_vols, f0, 0.0, 0.0);

        // Model parameters deliberately different from `truth`.
        let misfit_alpha = 0.13_f64;
        let misfit_rho = -0.15_f64;
        let misfit_nu = 0.30_f64;
        let params = toy_params(f0, 2.0, misfit_alpha, misfit_rho, misfit_nu);
        let market_atm_1y = hagan_implied_vol(&truth, f0, f0, 1.0);

        // Baseline: the same misfit parameters without the local-vol rescaling.
        let mut plain = TimeDependentSabrSimulator::new(params.clone(), SEED);
        let terms_plain = plain.simulate(1.0, N_STEPS, N_PATHS);
        let mc_plain = mean_call_payoff(terms_plain.iter().map(|s| s.forward), f0);
        let plain_iv = atm_implied_vol_1y(mc_plain, f0, "plain MC price should invert");
        let plain_err = (plain_iv - market_atm_1y).abs();

        let mut slv = TimeDependentSabrSlvSimulator::new(params, dupire, SEED).with_bins(40);
        let terms_slv = slv.simulate(1.0, N_STEPS, N_PATHS);
        let mc_slv = mean_call_payoff(terms_slv.iter().map(|s| s.forward), f0);
        let slv_iv = atm_implied_vol_1y(mc_slv, f0, "SLV MC price should invert");
        let slv_err = (slv_iv - market_atm_1y).abs();

        assert!(
            slv_err < plain_err * 0.6,
            "SLV residual {} bp should be well below plain {} bp (tolerance 60 %)",
            slv_err * 10_000.0,
            plain_err * 10_000.0,
        );
    }
}