use crate::time::daycounters::DayCounters;
use chrono::{Days, NaiveDate};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use rand_distr::StandardNormal;
/// A model whose state can be advanced one time step at a time.
///
/// `simulate_at_dates` drives implementations of this trait: it asks for the
/// starting state once per path and then calls [`SimulationModel::step`]
/// repeatedly along the time grid.
pub trait SimulationModel {
    /// Value carried along a path; it is cloned whenever a snapshot is stored.
    type State: Clone;

    /// Returns the state every path starts from.
    fn initial_state(&self) -> Self::State;

    /// Advances `state` over one step of length `dt` and returns the new state.
    ///
    /// `t` is the time the caller associates with the step (`simulate_at_dates`
    /// passes the midpoint of the step, measured as a year fraction from the
    /// valuation date). The receiver is `&mut self` so that implementations can
    /// update internal data such as a random number generator.
    fn step(&mut self, state: &Self::State, t: f64, dt: f64) -> Self::State;
}
/// Simulated states recorded on a set of calendar dates.
#[derive(Debug, Clone)]
pub struct DatedPaths<S> {
    /// Date the simulation starts from.
    pub valuation_date: NaiveDate,
    /// Dates on which states were recorded. The lookup methods binary-search
    /// this vector, so it is expected to be in ascending order.
    pub observation_dates: Vec<NaiveDate>,
    /// One inner vector per path; entry `i` of a path is the state recorded for
    /// `observation_dates[i]`.
    pub paths: Vec<Vec<S>>,
}
impl<S: Clone> DatedPaths<S> {
    /// Returns the number of simulated paths.
    pub fn n_paths(&self) -> usize {
        self.paths.len()
    }

    /// Returns a copy of every path's state on `date`.
    ///
    /// Yields `None` when `date` is not one of the observation dates, or when
    /// some path holds no state for that date (instead of panicking on an
    /// out-of-range index).
    pub fn states_at(&self, date: NaiveDate) -> Option<Vec<S>> {
        let idx = self.observation_index(date)?;
        // Collecting into `Option<Vec<_>>` stops at the first missing state.
        self.paths.iter().map(|p| p.get(idx).cloned()).collect()
    }

    /// Applies `extract` to every path's state on `date` and collects the results.
    ///
    /// Yields `None` under the same conditions as [`DatedPaths::states_at`].
    /// `extract` is called once per path, in path order.
    pub fn sample<F>(&self, date: NaiveDate, mut extract: F) -> Option<Vec<f64>>
    where
        F: FnMut(&S) -> f64,
    {
        let idx = self.observation_index(date)?;
        self.paths
            .iter()
            .map(|p| p.get(idx).map(&mut extract))
            .collect()
    }

    /// Position of `date` within `observation_dates`, found by binary search
    /// (which relies on the dates being sorted in ascending order).
    fn observation_index(&self, date: NaiveDate) -> Option<usize> {
        self.observation_dates.binary_search(&date).ok()
    }
}
/// Simulates `n_paths` paths of `model` and records the state on every
/// observation date.
///
/// The time grid starts at `valuation_date`, advances in steps of at most
/// `max_step_days` calendar days up to the latest observation date, and also
/// contains every observation date itself, so states are recorded exactly on
/// those dates. Grid dates are converted to year fractions from
/// `valuation_date` with `day_counter`; each step calls `model.step` with the
/// midpoint time and the length of the step.
///
/// The returned [`DatedPaths`] holds the observation dates sorted ascending
/// (duplicates are kept) and, for every path, one state per observation date.
/// Paths are generated one after another with the same `model`, so a model that
/// owns its random number generator produces a reproducible sequence.
///
/// A grid step whose year fraction is not positive is not simulated, but an
/// observation falling on that grid date is still recorded (with the state
/// carried over unchanged), so every path always has exactly
/// `observation_dates.len()` entries.
///
/// # Panics
///
/// Panics if `n_paths` is zero, if `max_step_days` is zero, if
/// `observation_dates` is empty, or if any observation date is not strictly
/// after `valuation_date`.
pub fn simulate_at_dates<M: SimulationModel>(
    model: &mut M,
    valuation_date: NaiveDate,
    observation_dates: &[NaiveDate],
    n_paths: usize,
    max_step_days: u32,
    day_counter: &dyn DayCounters,
) -> DatedPaths<M::State> {
    assert!(n_paths > 0, "n_paths must be > 0");
    assert!(max_step_days > 0, "max_step_days must be > 0");
    assert!(
        !observation_dates.is_empty(),
        "need at least one observation date"
    );
    assert!(
        observation_dates.iter().all(|d| *d > valuation_date),
        "observation_dates must be > valuation_date"
    );

    let mut obs: Vec<NaiveDate> = observation_dates.to_vec();
    obs.sort_unstable();
    let last = *obs
        .last()
        .expect("observation_dates was checked to be non-empty");

    // Regular grid from the valuation date to the last observation date.
    let mut grid: Vec<NaiveDate> = vec![valuation_date];
    let mut cur = valuation_date;
    while cur < last {
        // A step that would leave the representable date range necessarily
        // lands beyond `last`, so it is clamped to `last` rather than treated
        // as an error.
        cur = cur
            .checked_add_days(Days::new(u64::from(max_step_days)))
            .map_or(last, |next| next.min(last));
        grid.push(cur);
    }
    // Make sure every observation date is an exact grid point.
    grid.extend_from_slice(&obs);
    grid.sort_unstable();
    grid.dedup();

    let obs_indices: Vec<usize> = obs
        .iter()
        .map(|d| {
            grid.binary_search(d)
                .expect("observation date missing from grid")
        })
        .collect();

    // NOTE(review): a year fraction that cannot be computed is replaced by 0.0,
    // as before; confirm whether such a failure should be surfaced instead.
    let yf: Vec<f64> = grid
        .iter()
        .map(|d| day_counter.year_fraction(valuation_date, *d).unwrap_or(0.0))
        .collect();

    let mut paths: Vec<Vec<M::State>> = Vec::with_capacity(n_paths);
    for _ in 0..n_paths {
        // Allocate per path: cloning an empty `Vec` would drop the reserved capacity.
        let mut path: Vec<M::State> = Vec::with_capacity(obs.len());
        let mut state = model.initial_state();
        let mut next_obs = 0_usize;
        for (offset, pair) in yf.windows(2).enumerate() {
            let grid_idx = offset + 1;
            let (t_prev, t_next) = (pair[0], pair[1]);
            let dt = t_next - t_prev;
            let degenerate_step = dt <= 0.0;
            if !degenerate_step {
                let t_mid = 0.5 * (t_next + t_prev);
                state = model.step(&state, t_mid, dt);
            }
            // Record the observation even when the step itself was skipped;
            // otherwise the path would end up shorter than `obs`.
            while next_obs < obs_indices.len() && obs_indices[next_obs] == grid_idx {
                path.push(state.clone());
                next_obs += 1;
            }
        }
        paths.push(path);
    }

    DatedPaths {
        valuation_date,
        observation_dates: obs,
        paths,
    }
}
/// Brownian motion with constant drift and volatility, driven by its own
/// seeded random number generator.
pub struct BrownianMotion {
    /// Starting value of every path.
    pub x_0: f64,
    /// Drift coefficient; a step of length `dt` adds `drift * dt`.
    pub drift: f64,
    /// Volatility coefficient; scales the `sqrt(dt)`-sized normal shock.
    pub volatility: f64,
    /// Source of the standard normal draws used by `step`.
    rng: ChaCha20Rng,
}
impl BrownianMotion {
    /// Creates a standard Brownian motion (start 0, drift 0, volatility 1)
    /// whose random number generator is seeded with `seed`.
    pub fn new(seed: u64) -> Self {
        let rng = ChaCha20Rng::seed_from_u64(seed);
        Self {
            x_0: 0.0,
            drift: 0.0,
            volatility: 1.0,
            rng,
        }
    }

    /// Returns the model with its drift replaced by `mu`.
    pub fn with_drift(self, mu: f64) -> Self {
        Self { drift: mu, ..self }
    }

    /// Returns the model with its volatility replaced by `sigma`.
    pub fn with_volatility(self, sigma: f64) -> Self {
        Self {
            volatility: sigma,
            ..self
        }
    }

    /// Returns the model with its starting value replaced by `x_0`.
    pub fn with_initial(self, x_0: f64) -> Self {
        Self { x_0, ..self }
    }
}
impl SimulationModel for BrownianMotion {
    type State = f64;

    /// Every path starts at `x_0`.
    fn initial_state(&self) -> Self::State {
        self.x_0
    }

    /// Adds `drift * dt` plus a normal shock scaled by `volatility * sqrt(dt)`.
    ///
    /// The time argument is not used. One standard normal variate is drawn from
    /// the model's generator per call.
    fn step(&mut self, state: &Self::State, _t: f64, dt: f64) -> Self::State {
        let shock: f64 = self.rng.sample(StandardNormal);
        let drift_term = self.drift * dt;
        let diffusion_term = self.volatility * dt.sqrt() * shock;
        *state + drift_term + diffusion_term
    }
}
/// Geometric Brownian motion with constant drift and volatility, driven by its
/// own seeded random number generator.
pub struct GeometricBrownianMotion {
    /// Starting value of every path.
    pub s_0: f64,
    /// Drift coefficient of the process.
    pub drift: f64,
    /// Volatility coefficient of the process.
    pub volatility: f64,
    /// Source of the standard normal draws used by `step`.
    rng: ChaCha20Rng,
}
impl GeometricBrownianMotion {
    /// Creates a model starting at `s_0` with the given `drift` and
    /// `volatility`; the random number generator is seeded with `seed`.
    pub fn new(s_0: f64, drift: f64, volatility: f64, seed: u64) -> Self {
        let rng = ChaCha20Rng::seed_from_u64(seed);
        Self {
            s_0,
            drift,
            volatility,
            rng,
        }
    }
}
impl SimulationModel for GeometricBrownianMotion {
    type State = f64;

    /// Every path starts at `s_0`.
    fn initial_state(&self) -> Self::State {
        self.s_0
    }

    /// Multiplies the state by `exp((drift - volatility^2 / 2) * dt +
    /// volatility * sqrt(dt) * z)`, where `z` is a fresh standard normal draw.
    ///
    /// The time argument is not used.
    fn step(&mut self, state: &Self::State, _t: f64, dt: f64) -> Self::State {
        let shock: f64 = self.rng.sample(StandardNormal);
        let half_variance = 0.5 * self.volatility * self.volatility;
        let log_drift = (self.drift - half_variance) * dt;
        let diffusion = self.volatility * dt.sqrt() * shock;
        let log_return = log_drift + diffusion;
        *state * log_return.exp()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::daycounters::actual365fixed::Actual365Fixed;

    /// Builds a date, panicking if the year/month/day combination is invalid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Arithmetic mean of `values`.
    fn average(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    /// The result exposes the requested dates and one state per path on each.
    #[test]
    fn dated_paths_basic_construction_and_lookup() {
        let valuation = ymd(2025, 1, 1);
        let first = ymd(2025, 4, 1);
        let second = ymd(2025, 7, 1);
        let day_counter = Actual365Fixed::default();
        let mut bm = BrownianMotion::new(42);

        let simulated =
            simulate_at_dates(&mut bm, valuation, &[first, second], 100, 1, &day_counter);

        assert_eq!(simulated.n_paths(), 100);
        assert_eq!(simulated.observation_dates, vec![first, second]);
        let at_first = simulated.states_at(first).unwrap();
        assert_eq!(at_first.len(), 100);
        let at_second = simulated.sample(second, |x| *x).unwrap();
        assert_eq!(at_second.len(), 100);
    }

    /// Sample mean and variance of a standard Brownian motion at the horizon
    /// are checked against 0 and 1.
    #[test]
    fn brownian_motion_variance_matches_theory() {
        let valuation = ymd(2025, 1, 1);
        let horizon = ymd(2026, 1, 1);
        let day_counter = Actual365Fixed::default();
        let mut bm = BrownianMotion::new(12345);

        let simulated = simulate_at_dates(&mut bm, valuation, &[horizon], 5_000, 1, &day_counter);
        let xs = simulated.sample(horizon, |x| *x).unwrap();

        let mean = average(&xs);
        let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64;
        assert!(mean.abs() < 0.05, "mean {} not near 0", mean);
        assert!((var - 1.0).abs() < 0.1, "var {} not near 1", var);
    }

    /// Observation dates that do not fall on the regular 7-day grid are still
    /// recorded, one state per date.
    #[test]
    fn observation_dates_captured_exactly_no_interpolation() {
        let valuation = ymd(2025, 1, 1);
        let first = ymd(2025, 1, 10);
        let second = ymd(2025, 3, 15);
        let day_counter = Actual365Fixed::default();
        let mut bm = BrownianMotion::new(7);

        let simulated = simulate_at_dates(&mut bm, valuation, &[first, second], 5, 7, &day_counter);

        assert_eq!(simulated.observation_dates, vec![first, second]);
        for states in &simulated.paths {
            assert_eq!(states.len(), 2);
        }
    }

    /// The sample mean of the terminal values is compared with `s0 * exp(mu)`.
    #[test]
    fn gbm_mean_matches_theory() {
        let valuation = ymd(2025, 1, 1);
        let horizon = ymd(2026, 1, 1);
        let day_counter = Actual365Fixed::default();
        let s0 = 100.0_f64;
        let mu = 0.05_f64;
        let sigma = 0.20_f64;
        let mut gbm = GeometricBrownianMotion::new(s0, mu, sigma, 2024);

        let simulated = simulate_at_dates(&mut gbm, valuation, &[horizon], 20_000, 1, &day_counter);
        let terminal = simulated.states_at(horizon).unwrap();

        let mean: f64 = average(&terminal);
        let expected = s0 * mu.exp();
        assert!(
            (mean / expected - 1.0).abs() < 0.02,
            "GBM mean {} vs {:.4} expected",
            mean,
            expected
        );
    }
}