use crate::models::common::simulation::SimulationModel;
use crate::models::forex::sabr::{SabrParams, SabrSimulator, SabrState};
use crate::models::forex::sabr_effective::PiecewiseConstant;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use rand_distr::StandardNormal;
/// SABR parameters whose α, ρ and ν are piecewise-constant functions of time,
/// with a constant β and a scalar initial forward F₀.
///
/// `TimeDependentSabrParams::new` requires the three schedules to share
/// identical knots and validates every segment value.
#[derive(Clone, Debug, PartialEq)]
pub struct TimeDependentSabrParams {
    /// α schedule; every segment must be positive. The simulator in this module
    /// reads only α(0), as the initial volatility of each path.
    pub alpha: PiecewiseConstant,
    /// Correlation schedule ρ(t) between the forward and volatility Brownian
    /// drivers; every segment must lie strictly inside (−1, 1).
    pub rho: PiecewiseConstant,
    /// Vol-of-vol schedule ν(t); every segment must be non-negative.
    pub nu: PiecewiseConstant,
    /// CEV exponent β ∈ [0, 1], constant over time.
    pub beta: f64,
    /// Initial forward F₀; must be strictly positive.
    pub forward_0: f64,
}
impl TimeDependentSabrParams {
    /// Builds a validated time-dependent SABR parameter set.
    ///
    /// # Panics
    /// Panics if the α, ρ and ν schedules do not have identical knots, if β is
    /// outside [0, 1], if F₀ is not strictly positive, or if any segment
    /// violates α > 0, ν ≥ 0 or ρ ∈ (−1, 1). A NaN scalar or segment value
    /// fails its check. Every value-check message reports the offending value.
    pub fn new(
        alpha: PiecewiseConstant,
        rho: PiecewiseConstant,
        nu: PiecewiseConstant,
        beta: f64,
        forward_0: f64,
    ) -> Self {
        assert_eq!(alpha.knots, rho.knots, "α and ρ schedules must share knots");
        assert_eq!(alpha.knots, nu.knots, "α and ν schedules must share knots");
        assert!(
            (0.0..=1.0).contains(&beta),
            "β must be in [0, 1] (got {})",
            beta
        );
        assert!(forward_0 > 0.0, "F₀ must be positive (got {})", forward_0);
        for &a in &alpha.values {
            assert!(a > 0.0, "α must be positive on every segment (got {})", a);
        }
        for &n in &nu.values {
            assert!(
                n >= 0.0,
                "ν must be non-negative on every segment (got {})",
                n
            );
        }
        for &r in &rho.values {
            assert!(
                r > -1.0 && r < 1.0,
                "ρ must be in (−1, 1) on every segment (got {})",
                r
            );
        }
        Self {
            alpha,
            rho,
            nu,
            beta,
            forward_0,
        }
    }

    /// Returns the constant-parameter SABR snapshot in force at time `t`.
    ///
    /// α, ρ and ν are read from their schedules at `t`; β is time-independent.
    pub fn at(&self, t: f64) -> SabrParams {
        SabrParams {
            alpha: self.alpha.at(t),
            beta: self.beta,
            rho: self.rho.at(t),
            nu: self.nu.at(t),
        }
    }

    /// The α schedule's `final_time()`. All schedules share knots, so this is
    /// presumably the common last knot (the schedules' horizon).
    pub fn horizon(&self) -> f64 {
        self.alpha.final_time()
    }
}
/// Monte Carlo simulator for SABR with piecewise-constant parameters.
///
/// Normal draws come from a `ChaCha20Rng` seeded at construction, so runs are
/// reproducible for a given seed.
pub struct TimeDependentSabrSimulator {
    /// Parameter schedules used by every simulation step.
    pub params: TimeDependentSabrParams,
    /// Seeded generator supplying the standard-normal draws.
    rng: ChaCha20Rng,
}
impl TimeDependentSabrSimulator {
    /// Creates a simulator whose normal draws come from a `ChaCha20Rng`
    /// seeded with `seed`, so identical seeds reproduce identical paths.
    pub fn new(params: TimeDependentSabrParams, seed: u64) -> Self {
        Self {
            params,
            rng: ChaCha20Rng::seed_from_u64(seed),
        }
    }

    /// Advances `state` by one step of length `dt`.
    ///
    /// Returns the new state together with the Brownian increments
    /// `[dW_F, dW_α]` that drove it.
    ///
    /// β, ρ and ν are frozen at their values at `t_mid` for the whole step:
    /// * the forward takes an Euler step of `dF = σ F^β dW_F` (σ = `state.vol`),
    ///   floored at zero;
    /// * the volatility takes the exact log-normal step
    ///   `σ' = σ · exp(−½ν² dt + ν dW_α)`, where `dW_α` is correlated with
    ///   `dW_F` through ρ.
    ///
    /// The volatility comes from `state.vol`, so the α schedule does not enter
    /// this update.
    /// NOTE(review): with a non-flat α schedule only α(0) (via the initial
    /// state) reaches the paths — confirm that is the intended model.
    ///
    /// # Panics
    /// Panics if `dt` is not strictly positive.
    pub fn step_with_noise(
        &mut self,
        state: &SabrState,
        t_mid: f64,
        dt: f64,
    ) -> (SabrState, [f64; 2]) {
        assert!(dt > 0.0);
        let p = self.params.at(t_mid);
        let sqrt_one_minus_rho_sq = (1.0 - p.rho * p.rho).sqrt();
        // The draw order (z1, then z2) defines the RNG stream consumed per step.
        // The flat-schedule test relies on it matching the constant simulator.
        let z1: f64 = self.rng.sample(StandardNormal);
        let z2: f64 = self.rng.sample(StandardNormal);
        let sqrt_dt = dt.sqrt();
        let dw_f = sqrt_dt * z1;
        let dw_a = sqrt_dt * (p.rho * z1 + sqrt_one_minus_rho_sq * z2);
        let f = state.forward.max(0.0);
        let alpha = state.vol.max(0.0);
        let diffusion_f = alpha * f.powf(p.beta) * dw_f;
        // Absorb at zero: the CEV term F^β is undefined for negative F.
        let new_forward = (f + diffusion_f).max(0.0);
        let log_drift = -0.5 * p.nu * p.nu * dt;
        let log_diffusion = p.nu * dw_a;
        let new_vol = alpha * (log_drift + log_diffusion).exp();
        (
            SabrState {
                forward: new_forward,
                vol: new_vol,
            },
            [dw_f, dw_a],
        )
    }

    /// Simulates `n_paths` paths from `(F₀, α(0))` to `t_end` in `n_steps`
    /// equal steps, and returns each path's terminal state.
    ///
    /// Paths are generated one after another from the simulator's RNG. The
    /// parameters for step `k` are sampled at its midpoint `(k + ½)·dt`.
    ///
    /// # Panics
    /// Panics if `n_steps`, `n_paths` or `t_end` is not strictly positive.
    pub fn simulate(&mut self, t_end: f64, n_steps: usize, n_paths: usize) -> Vec<SabrState> {
        assert!(n_steps > 0 && n_paths > 0 && t_end > 0.0);
        let dt = t_end / n_steps as f64;
        let alpha0 = self.params.alpha.at(0.0);
        let mut out = Vec::with_capacity(n_paths);
        for _ in 0..n_paths {
            let mut state = SabrState {
                forward: self.params.forward_0,
                vol: alpha0,
            };
            for step in 0..n_steps {
                // Derive the midpoint from the step index rather than
                // accumulating `t += dt`, so rounding error cannot build up.
                let t_mid = (step as f64 + 0.5) * dt;
                state = self.step_with_noise(&state, t_mid, dt).0;
            }
            out.push(state);
        }
        out
    }

    /// If every schedule is flat, returns the equivalent constant-parameter
    /// simulator seeded with 0; otherwise returns `None`.
    ///
    /// Equivalent to `reduces_to_constant_with_seed(0)`.
    pub fn reduces_to_constant(&self) -> Option<SabrSimulator> {
        self.reduces_to_constant_with_seed(0)
    }

    /// Like `reduces_to_constant`, but seeds the returned simulator with `seed`.
    ///
    /// A schedule counts as flat when every segment value lies within 1e-12 of
    /// its first value. Comparing against the first value, rather than
    /// neighbour to neighbour, keeps small drifts from chaining past the
    /// tolerance. An empty schedule is treated as not flat.
    pub fn reduces_to_constant_with_seed(&self, seed: u64) -> Option<SabrSimulator> {
        /// Returns the common value of `values` if all lie within tolerance of
        /// the first one.
        fn flat_value(values: &[f64]) -> Option<f64> {
            const TOL: f64 = 1e-12;
            let (&first, rest) = values.split_first()?;
            rest.iter()
                .all(|&v| (v - first).abs() < TOL)
                .then_some(first)
        }
        let alpha = flat_value(&self.params.alpha.values)?;
        let rho = flat_value(&self.params.rho.values)?;
        let nu = flat_value(&self.params.nu.values)?;
        let p = SabrParams::new(alpha, self.params.beta, rho, nu);
        Some(SabrSimulator::new(p, self.params.forward_0, seed))
    }
}
impl SimulationModel for TimeDependentSabrSimulator {
    type State = SabrState;

    /// Initial state: forward F₀ and volatility α(0).
    fn initial_state(&self) -> Self::State {
        SabrState {
            forward: self.params.forward_0,
            vol: self.params.alpha.at(0.0),
        }
    }

    /// Advances one step of length `dt` that starts at time `t`.
    ///
    /// `step_with_noise` expects the step midpoint. Sampling the schedules at
    /// `t + dt/2` matches `simulate`. Sampling at the left edge could pick the
    /// segment before a knot when `t` carries rounding error.
    /// NOTE(review): assumes the `SimulationModel` driver passes `t` as the
    /// start of the step — confirm against the trait's caller.
    fn step(&mut self, state: &Self::State, t: f64, dt: f64) -> Self::State {
        let t_mid = t + 0.5 * dt;
        self.step_with_noise(state, t_mid, dt).0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::common::black_scholes::bs_implied_vol;
    use crate::models::forex::sabr::hagan_atm_vol;

    /// Builds parameters whose α, ρ and ν schedules are constant, each created
    /// with `PiecewiseConstant::constant(t, ·)`.
    fn flat_params(
        alpha: f64,
        rho: f64,
        nu: f64,
        beta: f64,
        f0: f64,
        t: f64,
    ) -> TimeDependentSabrParams {
        TimeDependentSabrParams::new(
            PiecewiseConstant::constant(t, alpha),
            PiecewiseConstant::constant(t, rho),
            PiecewiseConstant::constant(t, nu),
            beta,
            f0,
        )
    }

    /// Monte Carlo average of `payoff` over the terminal states. It divides by
    /// the actual number of paths, so the estimator stays unbiased if a test's
    /// path count changes.
    fn mc_mean(terms: &[SabrState], payoff: impl Fn(&SabrState) -> f64) -> f64 {
        terms.iter().map(payoff).sum::<f64>() / terms.len() as f64
    }

    #[test]
    fn flat_schedules_match_constant_param_simulator() {
        let p_td = flat_params(0.15, -0.30, 0.40, 0.5, 1.30, 1.0);
        let mut td = TimeDependentSabrSimulator::new(p_td.clone(), 2024);
        let p = SabrParams::new(0.15, 0.5, -0.30, 0.40);
        let mut cst = SabrSimulator::new(p, 1.30, 2024);
        let t_end = 1.0_f64;
        let n_steps = 200_usize;
        let n_paths = 50_usize;
        let td_terms = td.simulate(t_end, n_steps, n_paths);
        let cst_terms = cst.simulate(t_end, n_steps, n_paths);
        // Same seed and same draw order should give bit-identical paths.
        for (a, b) in td_terms.iter().zip(cst_terms.iter()) {
            assert_eq!(a, b, "td-sim diverges from constant sim");
        }
    }

    #[test]
    fn reduces_to_constant_detects_flatness() {
        let flat = flat_params(0.15, -0.30, 0.40, 0.5, 1.30, 1.0);
        assert!(
            TimeDependentSabrSimulator::new(flat, 0)
                .reduces_to_constant()
                .is_some()
        );
        let varying = TimeDependentSabrParams::new(
            PiecewiseConstant::new(vec![0.0, 0.5, 1.0], vec![0.15, 0.12]),
            PiecewiseConstant::new(vec![0.0, 0.5, 1.0], vec![-0.30, -0.30]),
            PiecewiseConstant::new(vec![0.0, 0.5, 1.0], vec![0.40, 0.40]),
            0.5,
            1.30,
        );
        assert!(
            TimeDependentSabrSimulator::new(varying, 0)
                .reduces_to_constant()
                .is_none()
        );
    }

    #[test]
    fn forward_is_martingale_under_time_dependent_schedules() {
        const N_PATHS: usize = 10_000;
        let knots = vec![0.0, 0.5, 1.0];
        let p = TimeDependentSabrParams::new(
            PiecewiseConstant::new(knots.clone(), vec![0.18, 0.12]),
            PiecewiseConstant::new(knots.clone(), vec![-0.40, -0.20]),
            PiecewiseConstant::new(knots, vec![0.50, 0.30]),
            0.5,
            1.30,
        );
        let f0 = p.forward_0;
        let mut sim = TimeDependentSabrSimulator::new(p, 99);
        let terms = sim.simulate(1.0, 200, N_PATHS);
        let mean = mc_mean(&terms, |s| s.forward);
        let rel = (mean - f0).abs() / f0;
        assert!(
            rel < 0.01,
            "E[F(T)] = {}, F₀ = {}, rel {:.4}",
            mean,
            f0,
            rel
        );
    }

    #[test]
    fn atm_iv_matches_effective_parameter_hagan() {
        use crate::models::forex::sabr_effective::{
            effective_correlation, effective_term_structure, effective_vol_vol,
        };
        const N_PATHS: usize = 40_000;
        let knots = vec![0.0, 0.5, 1.0];
        let alpha = PiecewiseConstant::new(knots.clone(), vec![0.15, 0.15]);
        let rho = PiecewiseConstant::new(knots.clone(), vec![-0.30, -0.30]);
        let nu = PiecewiseConstant::new(knots.clone(), vec![0.45, 0.30]);
        let p = TimeDependentSabrParams::new(alpha.clone(), rho.clone(), nu.clone(), 0.5, 1.30);
        let f0 = p.forward_0;
        let expiry = 1.0_f64;
        let mut sim = TimeDependentSabrSimulator::new(p, 31_337);
        let terms = sim.simulate(expiry, 400, N_PATHS);
        // Undiscounted ATM call: strike equals the initial forward.
        let mc_price = mc_mean(&terms, |s| (s.forward - f0).max(0.0));
        let mc_iv = bs_implied_vol(mc_price, f0, f0, expiry, 1.0, true).expect("BS inversion");
        let gamma_tilde = effective_vol_vol(&nu, &alpha, expiry);
        let omega_tilde = effective_term_structure(&nu, &alpha, expiry);
        let rho_tilde = effective_correlation(&nu, &alpha, &rho, expiry);
        let eff = SabrParams::new(omega_tilde, 0.5, rho_tilde, gamma_tilde);
        let hagan_iv = hagan_atm_vol(&eff, f0, expiry);
        let diff = (mc_iv - hagan_iv).abs();
        assert!(
            diff < 0.01,
            "MC ATM IV {} vs effective-Hagan {} (diff {:.4})",
            mc_iv,
            hagan_iv,
            diff
        );
    }

    #[test]
    #[should_panic(expected = "share knots")]
    fn misaligned_schedules_panic() {
        TimeDependentSabrParams::new(
            PiecewiseConstant::new(vec![0.0, 1.0, 2.0], vec![0.15, 0.15]),
            PiecewiseConstant::new(vec![0.0, 0.5, 2.0], vec![-0.30, -0.30]),
            PiecewiseConstant::new(vec![0.0, 1.0, 2.0], vec![0.40, 0.40]),
            0.5,
            1.30,
        );
    }
}