use crate::models::common::simulation::SimulationModel;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use rand_distr::StandardNormal;
use statrs::function::gamma::ln_gamma;
/// Parameters of a Cox–Ingersoll–Ross (square-root) process
/// `dσ = κ(θ − σ)dt + γ√σ dW`, started at `σ₀`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CirProcess {
    /// Mean-reversion speed κ (rate at which `E[σ_t]` decays towards θ).
    pub kappa: f64,
    /// Long-run level θ (limit of `E[σ_t]` as `t → ∞`).
    pub theta: f64,
    /// Volatility coefficient γ multiplying `√σ dW`.
    pub gamma: f64,
    /// Initial value σ₀ of the process.
    pub sigma_0: f64,
}
/// Closed-form moments of the CIR process `dσ = κ(θ − σ)dt + γ√σ dW`.
impl CirProcess {
    /// Expected value `E[σ_t] = σ₀·e^{−κt} + θ·(1 − e^{−κt})`.
    ///
    /// Returns `sigma_0` for `t <= 0`.
    pub fn mean(&self, t: f64) -> f64 {
        if t <= 0.0 {
            return self.sigma_0;
        }
        let decay = (-self.kappa * t).exp();
        self.sigma_0 * decay + self.theta * (1.0 - decay)
    }

    /// Variance
    /// `Var[σ_t] = σ₀(γ²/κ)e^{−κt}(1 − e^{−κt}) + θ(γ²/2κ)(1 − e^{−κt})²`.
    ///
    /// Returns `0` for `t <= 0`. For `κ = 0` the `κ → 0` limit `γ²σ₀t`
    /// is returned instead of the `∞·0 = NaN` the general formula yields.
    pub fn variance(&self, t: f64) -> f64 {
        if t <= 0.0 {
            return 0.0;
        }
        let g2 = self.gamma * self.gamma;
        if self.kappa == 0.0 {
            // Without mean reversion σ is a martingale: Var[σ_t] = γ²·σ₀·t.
            return self.sigma_0 * g2 * t;
        }
        let decay = (-self.kappa * t).exp();
        // 1 − e^{−κt} via expm1: plain subtraction cancels for small κt.
        let one_minus = -(-self.kappa * t).exp_m1();
        let g2_over_k = g2 / self.kappa;
        self.sigma_0 * g2_over_k * decay * one_minus
            + self.theta * (g2_over_k * 0.5) * one_minus * one_minus
    }

    /// `E[√σ_t]` with the default series tolerance `1e-14`.
    ///
    /// See [`CirProcess::sqrt_mean_tol`].
    pub fn sqrt_mean(&self, t: f64) -> f64 {
        self.sqrt_mean_tol(t, 1.0e-14)
    }

    /// `E[√σ_t]`, using that `σ_t = c·X` with `X` non-central χ² with
    /// `ℓ = 4κθ/γ²` degrees of freedom and non-centrality
    /// `ε = 4κσ₀e^{−κt} / (γ²(1 − e^{−κt}))`, `c = γ²(1 − e^{−κt})/(4κ)`.
    ///
    /// * `t <= 0` → `√max(σ₀, 0)`.
    /// * `|γ| < 1e-14` → deterministic limit `√max(E[σ_t], 0)`.
    /// * `ε > 500` → second-order delta method
    ///   `E[√σ] ≈ √m·(1 − v/(8m²))` with `m = E[σ_t]`, `v = Var[σ_t]`
    ///   (`tol` is ignored on this branch).
    /// * otherwise the Poisson-mixture series
    ///   `√(2c)·e^{−ε/2}·Σ_k (ε/2)^k/k! · Γ(ℓ/2 + k + ½)/Γ(ℓ/2 + k)`,
    ///   stopped once a term is `<= tol` times the partial sum, or after
    ///   about 2000 terms.
    ///
    /// θ is assumed non-negative; `θ = 0` is handled (the `k = 0`
    /// component, χ²₀ ≡ 0, is skipped).
    ///
    /// # Panics
    /// If `t > 0` and `kappa <= 0`.
    pub fn sqrt_mean_tol(&self, t: f64, tol: f64) -> f64 {
        if t <= 0.0 {
            return self.sigma_0.max(0.0).sqrt();
        }
        assert!(self.kappa > 0.0, "CIR: κ must be positive");
        if self.gamma.abs() < 1.0e-14 {
            return self.mean(t).max(0.0).sqrt();
        }
        let g2 = self.gamma * self.gamma;
        let decay = (-self.kappa * t).exp();
        // 1 − e^{−κt} via expm1 to keep relative precision for small κt.
        let one_minus = -(-self.kappa * t).exp_m1();
        let c = g2 * one_minus / (4.0 * self.kappa);
        let ell = 4.0 * self.kappa * self.theta / g2;
        let eps = 4.0 * self.kappa * self.sigma_0 * decay / (g2 * one_minus);
        if eps > 500.0 {
            // The series terms peak near e^{ε/2} and would overflow for very
            // large ε; use E[√σ] ≈ √m − v/(8 m^{3/2}) instead.
            let m = self.mean(t).max(1.0e-300);
            let v = self.variance(t);
            return (m.sqrt() * (1.0 - v / (8.0 * m * m))).max(0.0);
        }
        let half_eps = 0.5 * eps;
        let a = 0.5 * ell;
        let b = a + 0.5;
        // term_k = (ε/2)^k / k! · Γ(b + k) / Γ(a + k), built by recurrence.
        let (mut k, mut term) = if a > 0.0 {
            (0_u32, Self::gamma_half_ratio(a))
        } else {
            // θ = 0: the k = 0 component is χ²₀ ≡ 0 (contributes 0) and the
            // recurrence would divide by a + 0 = 0, so start at k = 1.
            (1_u32, half_eps * Self::gamma_half_ratio(1.0))
        };
        let mut sum = term;
        loop {
            let kf = f64::from(k);
            term *= half_eps * (b + kf) / ((kf + 1.0) * (a + kf));
            sum += term;
            k += 1;
            if term.abs() <= tol * sum.abs() || k > 2000 {
                break;
            }
        }
        (2.0 * c).sqrt() * (-half_eps).exp() * sum
    }

    /// Stationary limit `lim_{t→∞} E[√σ_t] = γ/√(2κ) · Γ(ℓ/2 + ½)/Γ(ℓ/2)`,
    /// `ℓ = 4κθ/γ²`.
    ///
    /// For `|γ| < 1e-14` the deterministic limit `√max(θ, 0)` is returned,
    /// consistent with [`CirProcess::sqrt_mean_tol`].
    ///
    /// # Panics
    /// If `kappa <= 0`, or if `gamma` is negative (beyond the `1e-14` band).
    pub fn sqrt_mean_infinity(&self) -> f64 {
        assert!(self.kappa > 0.0, "CIR: κ must be positive");
        if self.gamma.abs() < 1.0e-14 {
            return self.theta.max(0.0).sqrt();
        }
        assert!(self.gamma > 0.0, "CIR: γ must be positive");
        let ell = 4.0 * self.kappa * self.theta / (self.gamma * self.gamma);
        self.gamma / (2.0 * self.kappa).sqrt() * Self::gamma_half_ratio(0.5 * ell)
    }

    /// Exponential proxy `β₁ + β₂·e^{−β₃t}` for `E[√σ_t]`, with
    /// `β₁ = sqrt_mean_infinity()`, `β₂ = √max(σ₀, 0) − β₁` and `β₃ = κ`, so
    /// it matches `E[√σ_t]` exactly at `t = 0` and as `t → ∞`.
    ///
    /// # Panics
    /// Under the same conditions as [`CirProcess::sqrt_mean_infinity`].
    pub fn sqrt_mean_proxy(&self) -> SqrtMeanProxy {
        let beta1 = self.sqrt_mean_infinity();
        let beta2 = self.sigma_0.max(0.0).sqrt() - beta1;
        SqrtMeanProxy {
            beta1,
            beta2,
            beta3: self.kappa,
        }
    }

    /// Whether the Feller condition `2κθ >= γ²` holds.
    pub fn feller_satisfied(&self) -> bool {
        2.0 * self.kappa * self.theta >= self.gamma * self.gamma
    }

    /// `Γ(x + ½) / Γ(x)` for `x >= 0`.
    ///
    /// For large `x`, `ln Γ(x + ½) − ln Γ(x)` subtracts two huge, nearly
    /// equal numbers and loses most significant digits, so switch to the
    /// asymptotic series
    /// `ln(Γ(x + ½)/Γ(x)) = ½ln x − 1/(8x) + 1/(192x³) − 1/(640x⁵) + O(x⁻⁷)`,
    /// whose truncation error is far below f64 precision for `x >= 100`.
    fn gamma_half_ratio(x: f64) -> f64 {
        if x < 100.0 {
            (ln_gamma(x + 0.5) - ln_gamma(x)).exp()
        } else {
            let inv = x.recip();
            let inv2 = inv * inv;
            x.sqrt() * (-inv * (0.125 - inv2 * (1.0 / 192.0 - inv2 / 640.0))).exp()
        }
    }
}
/// Coefficients of the curve `β₁ + β₂·e^{−β₃t}` evaluated by
/// [`SqrtMeanProxy::eval`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SqrtMeanProxy {
    /// Constant level β₁ (value of the curve as `t → ∞` when `β₃ > 0`).
    pub beta1: f64,
    /// Amplitude β₂ of the exponential term.
    pub beta2: f64,
    /// Decay rate β₃ of the exponential term.
    pub beta3: f64,
}
impl SqrtMeanProxy {
    /// Evaluates `β₁ + β₂·e^{−β₃t}` at time `t`.
    pub fn eval(&self, t: f64) -> f64 {
        let damping = (-self.beta3 * t).exp();
        let transient = self.beta2 * damping;
        self.beta1 + transient
    }
}
/// Monte-Carlo path generator for a [`CirProcess`], driven by a seeded
/// ChaCha20 random-number generator.
pub struct CirSimulator {
    /// The process parameters used by every step.
    pub process: CirProcess,
    /// Source of the standard-normal shocks; private so only `step` draws.
    rng: ChaCha20Rng,
}
impl CirSimulator {
    /// Builds a simulator for `process`, seeding its ChaCha20 generator
    /// from `seed` so runs are reproducible for a given seed.
    pub fn new(process: CirProcess, seed: u64) -> Self {
        let rng = ChaCha20Rng::seed_from_u64(seed);
        Self { process, rng }
    }
}
impl SimulationModel for CirSimulator {
type State = f64;
fn initial_state(&self) -> Self::State {
self.process.sigma_0
}
fn step(&mut self, state: &Self::State, _t: f64, dt: f64) -> Self::State {
let z: f64 = self.rng.sample(StandardNormal);
let sigma = state.max(0.0);
(sigma
+ self.process.kappa * (self.process.theta - sigma) * dt
+ self.process.gamma * sigma.sqrt() * dt.sqrt() * z)
.max(0.0)
}
}
#[cfg(test)]
mod tests {
    use super::{CirProcess, CirSimulator};
    use crate::models::common::simulation::simulate_at_dates;
    use crate::time::daycounters::actual365fixed::Actual365Fixed;
    use chrono::NaiveDate;

    /// Reference parameter set shared by most tests
    /// (κ = 0.5, θ = 0.1, γ = 0.3, σ₀ = 0.1).
    fn grzelak_params() -> CirProcess {
        CirProcess {
            kappa: 0.5,
            theta: 0.1,
            gamma: 0.3,
            sigma_0: 0.1,
        }
    }

    #[test]
    fn mean_boundary_conditions() {
        let p = grzelak_params();
        assert!((p.mean(0.0) - p.sigma_0).abs() < 1e-15);
        assert!((p.mean(1_000.0) - p.theta).abs() < 1e-10);
    }

    #[test]
    fn variance_boundary_conditions() {
        let p = grzelak_params();
        assert!(p.variance(0.0).abs() < 1e-15);
        let expected = p.theta * p.gamma * p.gamma / (2.0 * p.kappa);
        assert!((p.variance(1_000.0) - expected).abs() < 1e-10);
    }

    #[test]
    fn sqrt_mean_at_zero_is_sqrt_sigma0() {
        let p = grzelak_params();
        assert!((p.sqrt_mean(0.0) - p.sigma_0.sqrt()).abs() < 1e-15);
    }

    #[test]
    fn sqrt_mean_bounded_above_by_sqrt_mean_of_sigma() {
        let p = grzelak_params();
        for &t in &[0.25_f64, 1.0, 5.0, 10.0, 30.0] {
            let sm = p.sqrt_mean(t);
            let m = p.mean(t);
            assert!(
                sm * sm <= m + 1e-12,
                "t={}: E[√σ]² = {} > E[σ] = {}",
                t,
                sm * sm,
                m
            );
            assert!(sm > 0.0);
        }
    }

    #[test]
    fn sqrt_mean_converges_to_infinity_limit() {
        let p = grzelak_params();
        let sm_inf = p.sqrt_mean_infinity();
        assert!((p.sqrt_mean(50.0) - sm_inf).abs() < 1e-8);
        assert!(sm_inf < p.theta.sqrt());
        assert!(sm_inf > 0.0);
    }

    #[test]
    fn proxy_matches_at_anchors_and_rough_interior() {
        let p = grzelak_params();
        let proxy = p.sqrt_mean_proxy();
        assert!((proxy.eval(0.0) - p.sigma_0.sqrt()).abs() < 1e-15);
        assert!((proxy.eval(100.0) - p.sqrt_mean_infinity()).abs() < 1e-12);
        for &t in &[0.25_f64, 1.0, 5.0, 10.0] {
            let exact = p.sqrt_mean(t);
            let approx = proxy.eval(t);
            let rel = (approx - exact).abs() / exact;
            assert!(
                rel < 0.05,
                "proxy error at t={}: {:.4}% (exact={}, proxy={})",
                t,
                rel * 100.0,
                exact,
                approx
            );
        }
    }

    #[test]
    fn grzelak_params_satisfy_feller_by_small_margin() {
        let p = grzelak_params();
        assert!(p.feller_satisfied());
        let two_k_theta = 2.0 * p.kappa * p.theta;
        let g2 = p.gamma * p.gamma;
        assert!((two_k_theta - g2 - 0.01).abs() < 1e-15);
    }

    #[test]
    fn feller_true_when_2ks_ge_g2() {
        let p = CirProcess {
            kappa: 2.0,
            theta: 0.04,
            gamma: 0.3,
            sigma_0: 0.04,
        };
        assert!(p.feller_satisfied());
    }

    #[test]
    fn tight_tolerance_matches_default_tolerance() {
        let p = grzelak_params();
        for &t in &[0.5_f64, 2.0, 7.5] {
            let default = p.sqrt_mean(t);
            let tight = p.sqrt_mean_tol(t, 1.0e-16);
            assert!(
                (default - tight).abs() < 1.0e-12,
                "t={}: default {} tight {}",
                t,
                default,
                tight
            );
        }
    }

    #[test]
    fn cir_simulator_mean_matches_closed_form() {
        let p = CirProcess {
            kappa: 1.5,
            theta: 0.04,
            gamma: 0.25,
            sigma_0: 0.04,
        };
        let mut sim = CirSimulator::new(p, 2024);
        let val = NaiveDate::from_ymd_opt(2025, 1, 1).unwrap();
        let horizon = NaiveDate::from_ymd_opt(2026, 1, 1).unwrap();
        let dc = Actual365Fixed::default();
        let paths = simulate_at_dates(&mut sim, val, &[horizon], 10_000, 1, &dc);
        let terminals = paths.states_at(horizon).unwrap();
        let mean: f64 = terminals.iter().sum::<f64>() / terminals.len() as f64;
        let expected = p.mean(1.0);
        assert!(
            (mean - expected).abs() < 1.0e-3,
            "MC mean {} vs closed form {}",
            mean,
            expected
        );
    }

    /// Checks the invariants promised by the name: the initial state, the
    /// zero-dt identity, flooring of negative inputs, non-negativity along a
    /// path, and reproducibility for a fixed seed.
    #[test]
    fn cir_simulator_trivial_step_invariants() {
        use crate::models::common::simulation::SimulationModel;
        let p = CirProcess {
            kappa: 1.0,
            theta: 0.1,
            gamma: 0.1,
            sigma_0: 0.05,
        };
        let mut sim = CirSimulator::new(p, 1);
        assert_eq!(sim.initial_state(), 0.05);

        // A zero-length step leaves a non-negative state unchanged …
        assert_eq!(sim.step(&0.05, 0.0, 0.0), 0.05);
        // … and floors a negative input state to zero.
        assert_eq!(sim.step(&-1.0, 0.0, 0.0), 0.0);

        // States stay finite and non-negative along a daily-step path.
        let dt = 1.0 / 252.0;
        let mut x = sim.initial_state();
        for i in 0..1_000 {
            x = sim.step(&x, f64::from(i) * dt, dt);
            assert!(x.is_finite() && x >= 0.0, "step {}: state {}", i, x);
        }

        // Identical seeds produce identical steps.
        let mut a = CirSimulator::new(p, 7);
        let mut b = CirSimulator::new(p, 7);
        assert_eq!(a.step(&0.05, 0.0, 0.01), b.step(&0.05, 0.0, 0.01));
    }
}