use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use rand::Rng;
use rand_distr::StandardNormal;
use crate::arrow_util::{const_string, f64_array, record_batch, timestamp_grid_ms};
use crate::error::{DatagenError, DatagenResult};
use crate::rng::make_rng;
use crate::schema::price_schema;
/// Parameters for a [`GbmGenerator`] price path.
///
/// Usually built from [`Default`] plus struct-update syntax; parameter ranges are
/// checked by [`GbmGenerator::new`].
#[derive(Debug, Clone)]
pub struct GbmConfig {
    /// Starting price; always the first element of a simulated path.
    pub s0: f64,
    /// Drift coefficient (enters the per-step log return as `(mu - sigma^2 / 2) * dt`).
    pub mu: f64,
    /// Volatility coefficient; scales the standard-normal shock by `sigma * sqrt(dt)`.
    pub sigma: f64,
    /// Length of one simulation step, in the same time unit as `mu` and `sigma`.
    pub dt: f64,
    /// Number of steps; a path holds `n_steps + 1` prices including `s0`.
    pub n_steps: usize,
    /// Symbol written into every row of the generated batch.
    pub symbol: String,
    /// Timestamp of the first row, in milliseconds.
    pub start_ms: i64,
    /// Spacing between consecutive row timestamps, in milliseconds.
    pub step_ms: i64,
    /// RNG seed; `Some` makes paths reproducible, `None` defers to `make_rng`'s default.
    pub seed: Option<u64>,
}
impl Default for GbmConfig {
    /// 252 steps of `dt = 1/252` (presumably one trading year of daily bars), starting at
    /// 100.0 with `mu = 0.05` and `sigma = 0.2`. Timestamps start at 0 and are one day
    /// apart. The symbol is `"SYM"` and no seed is fixed.
    fn default() -> Self {
        const N_DAYS: usize = 252;
        const MS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
        Self {
            symbol: String::from("SYM"),
            s0: 100.0,
            mu: 0.05,
            sigma: 0.2,
            dt: 1.0 / N_DAYS as f64,
            n_steps: N_DAYS,
            start_ms: 0,
            step_ms: MS_PER_DAY,
            seed: None,
        }
    }
}
/// Simulates geometric-Brownian-motion price paths from a [`GbmConfig`].
///
/// Construct it with [`GbmGenerator::new`], which rejects invalid parameters.
pub struct GbmGenerator {
    /// Configuration accepted by `new`; private so it cannot be mutated past validation.
    cfg: GbmConfig,
}
impl GbmGenerator {
    /// Validates `cfg` and wraps it in a generator.
    ///
    /// # Errors
    ///
    /// Returns [`DatagenError::InvalidParameter`] in these cases:
    /// - any of `s0`, `mu`, `sigma` or `dt` is NaN or infinite;
    /// - `s0 <= 0`;
    /// - `sigma < 0`;
    /// - `dt <= 0`.
    ///
    /// Fields are checked in the order `s0`, `mu`, `sigma`, `dt`. Only the first failure
    /// is reported.
    pub fn new(cfg: GbmConfig) -> DatagenResult<Self> {
        // Every comparison against NaN is false, so range checks like `s0 <= 0.0` alone
        // would accept NaN (and infinities pass them too), poisoning every simulated
        // price. Finiteness is therefore checked explicitly before each range check.
        let checks = [
            (!cfg.s0.is_finite(), "s0 must be finite"),
            (cfg.s0 <= 0.0, "s0 must be > 0"),
            (!cfg.mu.is_finite(), "mu must be finite"),
            (!cfg.sigma.is_finite(), "sigma must be finite"),
            (cfg.sigma < 0.0, "sigma must be >= 0"),
            (!cfg.dt.is_finite(), "dt must be finite"),
            (cfg.dt <= 0.0, "dt must be > 0"),
        ];
        if let Some(&(_, msg)) = checks.iter().find(|(failed, _)| *failed) {
            return Err(DatagenError::InvalidParameter(msg.into()));
        }
        Ok(Self { cfg })
    }

    /// Simulates one path of `n_steps` steps and returns `n_steps + 1` prices,
    /// the first of which is `s0`.
    ///
    /// Each step multiplies the previous price by
    /// `exp((mu - sigma^2 / 2) * dt + sigma * sqrt(dt) * z)`, where `z` is a fresh
    /// standard-normal draw. This is the exact log-space update for geometric Brownian
    /// motion. Prices therefore stay positive unless the product overflows or underflows,
    /// which can happen for extreme parameters.
    ///
    /// With `seed: Some(_)` the path is reproducible. With `None`, seeding is left to
    /// `make_rng` (presumably from entropy; confirm).
    pub fn simulate(&self) -> Vec<f64> {
        let mut rng = make_rng(self.cfg.seed);
        let n = self.cfg.n_steps;
        let dt = self.cfg.dt;
        // Both terms are loop-invariant, so compute them once.
        let drift = (self.cfg.mu - 0.5 * self.cfg.sigma * self.cfg.sigma) * dt;
        let diffusion = self.cfg.sigma * dt.sqrt();
        let mut prices = Vec::with_capacity(n + 1);
        let mut s = self.cfg.s0;
        prices.push(s);
        for _ in 0..n {
            let z: f64 = rng.sample(StandardNormal);
            s *= (drift + diffusion * z).exp();
            prices.push(s);
        }
        prices
    }

    /// Simulates a path and packages it as a [`RecordBatch`] with one row per price.
    ///
    /// The columns are, in order:
    /// 1. a timestamp grid built from `start_ms` and `step_ms`;
    /// 2. the constant `symbol`;
    /// 3. the simulated prices.
    ///
    /// The schema comes from `price_schema()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the `record_batch` helper (presumably a schema/column
    /// mismatch; confirm against `arrow_util`).
    pub fn record_batch(&self) -> DatagenResult<RecordBatch> {
        let prices = self.simulate();
        let n = prices.len();
        let ts = timestamp_grid_ms(self.cfg.start_ms, self.cfg.step_ms, n);
        let sym = const_string(&self.cfg.symbol, n);
        let columns: Vec<ArrayRef> = vec![Arc::new(ts), Arc::new(sym), Arc::new(f64_array(prices))];
        record_batch(price_schema(), columns)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config with only `seed` and `n_steps` overridden.
    fn seeded(seed: u64, n_steps: usize) -> GbmConfig {
        GbmConfig {
            seed: Some(seed),
            n_steps,
            ..GbmConfig::default()
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let cfg = seeded(42, 10);
        let first = GbmGenerator::new(cfg.clone()).unwrap().simulate();
        let second = GbmGenerator::new(cfg).unwrap().simulate();
        assert_eq!(first, second);
        assert_eq!(first.len(), 11);
        assert!(first.iter().all(|p| p.is_finite() && *p > 0.0));
    }

    #[test]
    fn rejects_bad_params() {
        let bad_configs = [
            GbmConfig {
                s0: 0.0,
                ..GbmConfig::default()
            },
            GbmConfig {
                sigma: -1.0,
                ..GbmConfig::default()
            },
            GbmConfig {
                dt: 0.0,
                ..GbmConfig::default()
            },
        ];
        for cfg in bad_configs {
            assert!(GbmGenerator::new(cfg).is_err());
        }
    }

    #[test]
    fn record_batch_shape() {
        let generator = GbmGenerator::new(seeded(1, 5)).unwrap();
        let batch = generator.record_batch().unwrap();
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.num_rows(), 6);
    }
}