use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rand_distributions::Distribution;
/// A snapshot of a synthetic limit order book.
#[derive(Debug, Clone)]
pub struct OrderBook {
    /// Bid levels as `(price, quantity)`, sorted best (highest price) first.
    pub bids: Vec<(f64, f64)>,
    /// Ask levels as `(price, quantity)`, sorted best (lowest price) first.
    pub asks: Vec<(f64, f64)>,
    /// Midpoint between the best bid and best ask prices.
    pub mid_price: f64,
    /// Best ask price minus best bid price.
    pub spread: f64,
}
/// Build a deterministic RNG from an explicit seed so every generator in
/// this module is reproducible run-to-run for the same inputs.
fn make_rng(seed: u64) -> StdRng {
    StdRng::seed_from_u64(seed)
}
/// Construct a zero-mean normal distribution with standard deviation `std`.
///
/// # Errors
/// Returns `ComputationError` when the underlying constructor rejects the
/// parameters (e.g. a negative or non-finite standard deviation).
fn normal_dist(
    std: f64,
) -> Result<scirs2_core::random::rand_distributions::Normal<f64>> {
    match scirs2_core::random::rand_distributions::Normal::new(0.0_f64, std) {
        Ok(dist) => Ok(dist),
        Err(e) => Err(DatasetsError::ComputationError(format!(
            "Normal distribution creation failed: {e}"
        ))),
    }
}
/// Construct a uniform distribution over `[lo, hi)`.
///
/// # Errors
/// Returns `ComputationError` when the underlying constructor rejects the
/// bounds (e.g. `lo >= hi` or non-finite values).
fn uniform_dist(
    lo: f64,
    hi: f64,
) -> Result<scirs2_core::random::rand_distributions::Uniform<f64>> {
    match scirs2_core::random::rand_distributions::Uniform::new(lo, hi) {
        Ok(dist) => Ok(dist),
        Err(e) => Err(DatasetsError::ComputationError(format!(
            "Uniform distribution creation failed: {e}"
        ))),
    }
}
/// Compute the lower-triangular Cholesky factor `L` of a symmetric
/// positive-definite matrix `a`, such that `a = L * L^T`
/// (Cholesky–Banachiewicz, row by row).
///
/// # Errors
/// - `InvalidFormat` if `a` is not square.
/// - `ComputationError` if a pivot is non-positive (matrix not positive
///   definite) or a diagonal of `L` is numerically zero.
fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(DatasetsError::InvalidFormat(
            "cholesky_lower: matrix must be square".to_string(),
        ));
    }
    let mut lower = Array2::zeros((n, n));
    for row in 0..n {
        for col in 0..=row {
            // Inner product of the partial rows already computed.
            let partial: f64 = (0..col).map(|k| lower[[row, k]] * lower[[col, k]]).sum();
            if row == col {
                let pivot = a[[row, row]] - partial;
                if pivot <= 0.0 {
                    return Err(DatasetsError::ComputationError(
                        "cholesky_lower: matrix is not positive definite".to_string(),
                    ));
                }
                lower[[row, col]] = pivot.sqrt();
            } else {
                let diag = lower[[col, col]];
                if diag.abs() < 1e-15 {
                    return Err(DatasetsError::ComputationError(
                        "cholesky_lower: near-zero diagonal in L".to_string(),
                    ));
                }
                lower[[row, col]] = (a[[row, col]] - partial) / diag;
            }
        }
    }
    Ok(lower)
}
/// Simulate geometric Brownian motion price paths.
///
/// Each path starts at `s0` and evolves by the exact log-Euler update
/// `S_{t+1} = S_t * exp((mu - sigma^2/2) * dt + sigma * sqrt(dt) * Z)`.
///
/// Returns an `(n_paths, n_steps + 1)` array; column 0 holds `s0`.
///
/// # Errors
/// `InvalidFormat` if `s0`, `sigma`, or `dt` is non-positive, or if
/// `n_steps` or `n_paths` is zero.
pub fn gbm_prices(
    s0: f64,
    mu: f64,
    sigma: f64,
    dt: f64,
    n_steps: usize,
    n_paths: usize,
    seed: u64,
) -> Result<Array2<f64>> {
    // Validate everything up front, in the same order as before, so no RNG
    // state is consumed on the error path.
    let checks = [
        (s0 <= 0.0, "gbm_prices: s0 must be > 0"),
        (sigma <= 0.0, "gbm_prices: sigma must be > 0"),
        (dt <= 0.0, "gbm_prices: dt must be > 0"),
        (n_steps == 0, "gbm_prices: n_steps must be > 0"),
        (n_paths == 0, "gbm_prices: n_paths must be > 0"),
    ];
    for (failed, message) in checks {
        if failed {
            return Err(DatasetsError::InvalidFormat(message.to_string()));
        }
    }
    let log_drift = (mu - 0.5 * sigma * sigma) * dt;
    let vol_step = sigma * dt.sqrt();
    let mut rng = make_rng(seed);
    let z_dist = normal_dist(1.0)?;
    let mut paths = Array2::zeros((n_paths, n_steps + 1));
    for path in 0..n_paths {
        // Carry the running price instead of re-reading the array each step.
        let mut price = s0;
        paths[[path, 0]] = price;
        for t in 1..=n_steps {
            let shock = z_dist.sample(&mut rng);
            price *= (log_drift + vol_step * shock).exp();
            paths[[path, t]] = price;
        }
    }
    Ok(paths)
}
/// Simulate correlated geometric Brownian motion for several assets.
///
/// Cross-asset correlation is imposed on the log-return shocks through the
/// Cholesky factor of the covariance matrix
/// `cov[i][j] = correlation[i][j] * sigma[i] * sigma[j]`.
///
/// Returns an `(n_steps + 1, n_assets)` array; row 0 holds `s0`.
///
/// # Errors
/// `InvalidFormat` for mismatched input lengths/dimensions, non-positive
/// `s0`, `sigma`, or `dt`, or zero `n_steps`; `ComputationError` if the
/// covariance matrix is not positive definite.
pub fn correlated_gbm(
    s0: &[f64],
    mu: &[f64],
    correlation: &Array2<f64>,
    sigma: &[f64],
    dt: f64,
    n_steps: usize,
    seed: u64,
) -> Result<Array2<f64>> {
    let n_assets = s0.len();
    if mu.len() != n_assets || sigma.len() != n_assets {
        return Err(DatasetsError::InvalidFormat(
            "correlated_gbm: s0, mu, and sigma must have the same length".to_string(),
        ));
    }
    if correlation.nrows() != n_assets || correlation.ncols() != n_assets {
        return Err(DatasetsError::InvalidFormat(
            "correlated_gbm: correlation matrix dimensions must match n_assets".to_string(),
        ));
    }
    if let Some(i) = s0.iter().position(|&s| s <= 0.0) {
        return Err(DatasetsError::InvalidFormat(format!(
            "correlated_gbm: s0[{i}] must be > 0"
        )));
    }
    if let Some(i) = sigma.iter().position(|&v| v <= 0.0) {
        return Err(DatasetsError::InvalidFormat(format!(
            "correlated_gbm: sigma[{i}] must be > 0"
        )));
    }
    if dt <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "correlated_gbm: dt must be > 0".to_string(),
        ));
    }
    if n_steps == 0 {
        return Err(DatasetsError::InvalidFormat(
            "correlated_gbm: n_steps must be > 0".to_string(),
        ));
    }
    // Covariance from the correlation matrix and per-asset volatilities.
    let mut cov = Array2::zeros((n_assets, n_assets));
    for (i, &sig_i) in sigma.iter().enumerate() {
        for (j, &sig_j) in sigma.iter().enumerate() {
            cov[[i, j]] = correlation[[i, j]] * sig_i * sig_j;
        }
    }
    let chol = cholesky_lower(&cov)?;
    let mut rng = make_rng(seed);
    let z_dist = normal_dist(1.0)?;
    let mut out = Array2::zeros((n_steps + 1, n_assets));
    for (a, &start) in s0.iter().enumerate() {
        out[[0, a]] = start;
    }
    let sqrt_dt = dt.sqrt();
    for step in 0..n_steps {
        // One independent standard-normal draw per asset, then correlate
        // by multiplying with the Cholesky factor.
        let z_raw: Vec<f64> = (0..n_assets).map(|_| z_dist.sample(&mut rng)).collect();
        for a in 0..n_assets {
            let correlated: f64 = z_raw
                .iter()
                .enumerate()
                .map(|(k, z)| chol[[a, k]] * z)
                .sum();
            let dw = correlated * sqrt_dt;
            let drift = (mu[a] - 0.5 * sigma[a] * sigma[a]) * dt;
            out[[step + 1, a]] = out[[step, a]] * (drift + dw).exp();
        }
    }
    Ok(out)
}
/// Generate a synthetic limit order book centered on `mid_price`.
///
/// Produces `n_levels` bid and ask levels spaced `tick_size` apart. Level
/// quantities grow geometrically with depth (factor 1.5 per level) and are
/// perturbed by independent multiplicative noise drawn from [0.8, 1.2).
/// The returned `mid_price` and `spread` are recomputed from the best
/// bid/ask actually generated.
///
/// # Errors
/// `InvalidFormat` if `mid_price` or `tick_size` is non-positive, or if
/// `n_levels` is zero.
pub fn synthetic_order_book(
    mid_price: f64,
    n_levels: usize,
    tick_size: f64,
    seed: u64,
) -> Result<OrderBook> {
    if mid_price <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_order_book: mid_price must be > 0".to_string(),
        ));
    }
    if n_levels == 0 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_order_book: n_levels must be >= 1".to_string(),
        ));
    }
    if tick_size <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_order_book: tick_size must be > 0".to_string(),
        ));
    }
    let mut rng = make_rng(seed);
    let noise = uniform_dist(0.8, 1.2)?;
    let best_bid = mid_price - 0.5 * tick_size;
    let best_ask = mid_price + 0.5 * tick_size;
    let mut bids = Vec::with_capacity(n_levels);
    let mut asks = Vec::with_capacity(n_levels);
    for k in 0..n_levels {
        // Depth-dependent base size: deeper levels carry more quantity.
        let base_qty = 10.0 * (1.5_f64).powi(k as i32);
        // BUG FIX: ask quantity was previously `bid_qty * noise`, which
        // applied noise twice to asks and tied each ask size to the bid
        // size at the same level. Each side now gets its own independent
        // draw on the base quantity (same number and order of RNG draws).
        let bid_qty = base_qty * noise.sample(&mut rng);
        let ask_qty = base_qty * noise.sample(&mut rng);
        bids.push((best_bid - k as f64 * tick_size, bid_qty));
        asks.push((best_ask + k as f64 * tick_size, ask_qty));
    }
    // Levels are generated in order already; sort defensively so the
    // invariant (bids descending, asks ascending) always holds.
    bids.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    asks.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let best_bid_px = bids.first().map(|b| b.0).unwrap_or(0.0);
    let best_ask_px = asks.first().map(|a| a.0).unwrap_or(0.0);
    Ok(OrderBook {
        bids,
        asks,
        mid_price: (best_bid_px + best_ask_px) / 2.0,
        spread: best_ask_px - best_bid_px,
    })
}
/// Compute log returns `ln(p_t / p_{t-1})` from a price series.
///
/// Returns an array with one fewer element than `prices`.
///
/// # Errors
/// `InvalidFormat` if fewer than 2 prices are given, or if any price is
/// non-positive (the logarithm would be undefined).
pub fn log_returns(prices: &Array1<f64>) -> Result<Array1<f64>> {
    let n = prices.len();
    if n < 2 {
        return Err(DatasetsError::InvalidFormat(
            "log_returns: prices must have at least 2 elements".to_string(),
        ));
    }
    let mut out = Vec::with_capacity(n - 1);
    // Pair each price with its successor; `idx + 1` is the successor index.
    for (idx, (prev, curr)) in prices.iter().zip(prices.iter().skip(1)).enumerate() {
        let i = idx + 1;
        if *prev <= 0.0 || *curr <= 0.0 {
            return Err(DatasetsError::InvalidFormat(format!(
                "log_returns: non-positive price at index {}: prev={prev}, curr={curr}",
                i
            )));
        }
        out.push((curr / prev).ln());
    }
    Ok(Array1::from_vec(out))
}
/// Sample mean and standard deviation (n-1 denominator) of one window.
/// Shared by `rolling_volatility` and `rolling_sharpe` so the window
/// statistics are computed in exactly one place.
fn window_stats(window: &[f64]) -> (f64, f64) {
    let len = window.len() as f64;
    let mean = window.iter().sum::<f64>() / len;
    let var = window.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (len - 1.0);
    (mean, var.sqrt())
}

/// Rolling sample standard deviation of `returns` over a trailing window.
///
/// Output has the same length as `returns`; the first `window - 1`
/// positions are `NaN` because no full window exists there yet.
///
/// # Errors
/// `InvalidFormat` if `window < 2` or `returns` is empty.
pub fn rolling_volatility(returns: &Array1<f64>, window: usize) -> Result<Array1<f64>> {
    if window < 2 {
        return Err(DatasetsError::InvalidFormat(
            "rolling_volatility: window must be >= 2".to_string(),
        ));
    }
    if returns.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "rolling_volatility: returns must not be empty".to_string(),
        ));
    }
    let n = returns.len();
    let data: Vec<f64> = returns.to_vec();
    let mut out = vec![f64::NAN; n];
    for i in (window - 1)..n {
        let (_mean, std_dev) = window_stats(&data[i + 1 - window..=i]);
        out[i] = std_dev;
    }
    Ok(Array1::from_vec(out))
}

/// Rolling Sharpe ratio `(mean - risk_free) / std` of `returns` over a
/// trailing window.
///
/// Output has the same length as `returns`; the first `window - 1`
/// positions are `NaN`. Windows with near-zero standard deviation
/// (< 1e-15) yield 0.0 rather than dividing by ~zero.
///
/// # Errors
/// `InvalidFormat` if `window < 2` or `returns` is empty.
pub fn rolling_sharpe(returns: &Array1<f64>, window: usize, risk_free: f64) -> Result<Array1<f64>> {
    if window < 2 {
        return Err(DatasetsError::InvalidFormat(
            "rolling_sharpe: window must be >= 2".to_string(),
        ));
    }
    if returns.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "rolling_sharpe: returns must not be empty".to_string(),
        ));
    }
    let n = returns.len();
    let data: Vec<f64> = returns.to_vec();
    let mut out = vec![f64::NAN; n];
    for i in (window - 1)..n {
        let (mean, std_dev) = window_stats(&data[i + 1 - window..=i]);
        out[i] = if std_dev < 1e-15 {
            0.0
        } else {
            (mean - risk_free) / std_dev
        };
    }
    Ok(Array1::from_vec(out))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Output shape contract: (n_paths, n_steps + 1); column 0 is s0.
    #[test]
    fn test_gbm_shape() {
        let paths = gbm_prices(100.0, 0.05, 0.2, 1.0 / 252.0, 252, 10, 42)
            .expect("gbm failed");
        assert_eq!(paths.nrows(), 10);
        assert_eq!(paths.ncols(), 253);
    }

    // Every path must start exactly at s0.
    #[test]
    fn test_gbm_initial_price() {
        let paths = gbm_prices(50.0, 0.0, 0.1, 0.01, 100, 5, 1).expect("gbm failed");
        for p in 0..5 {
            assert!((paths[[p, 0]] - 50.0).abs() < 1e-12, "Initial price wrong");
        }
    }

    // GBM updates multiply by exp(...), so prices can never become <= 0.
    #[test]
    fn test_gbm_positive_prices() {
        let paths = gbm_prices(100.0, 0.05, 0.4, 1.0 / 252.0, 252, 50, 7)
            .expect("gbm failed");
        for p in 0..paths.nrows() {
            for t in 0..paths.ncols() {
                assert!(paths[[p, t]] > 0.0, "Non-positive price at ({p},{t})");
            }
        }
    }

    // Terminal log returns should average (mu - sigma^2/2) * T.
    // Tolerance 0.03 is generous relative to the Monte Carlo standard error
    // at 10_000 paths (roughly sigma * sqrt(T) / sqrt(n_paths) ≈ 0.002).
    #[test]
    fn test_gbm_log_normal_distribution() {
        let n_paths = 10_000;
        let mu = 0.05_f64;
        let sigma = 0.2_f64;
        let dt = 1.0 / 252.0_f64;
        let n_steps = 252;
        let paths = gbm_prices(100.0, mu, sigma, dt, n_steps, n_paths, 99)
            .expect("gbm failed");
        let t = n_steps as f64 * dt;
        let expected_mean = (mu - 0.5 * sigma * sigma) * t;
        let log_returns_final: Vec<f64> = (0..n_paths)
            .map(|p| (paths[[p, n_steps]] / paths[[p, 0]]).ln())
            .collect();
        let sample_mean =
            log_returns_final.iter().sum::<f64>() / log_returns_final.len() as f64;
        let tol = 0.03;
        assert!(
            (sample_mean - expected_mean).abs() < tol,
            "sample_mean={sample_mean:.4}, expected={expected_mean:.4}"
        );
    }

    // Same seed must produce bit-identical output.
    #[test]
    fn test_gbm_reproducibility() {
        let a = gbm_prices(100.0, 0.0, 0.2, 0.01, 50, 3, 42).expect("gbm failed");
        let b = gbm_prices(100.0, 0.0, 0.2, 0.01, 50, 3, 42).expect("gbm failed");
        assert_eq!(a, b);
    }

    #[test]
    fn test_gbm_error_negative_s0() {
        assert!(gbm_prices(-100.0, 0.05, 0.2, 0.01, 10, 1, 0).is_err());
    }

    // sigma == 0 is rejected (validation requires sigma > 0).
    #[test]
    fn test_gbm_error_zero_sigma() {
        assert!(gbm_prices(100.0, 0.05, 0.0, 0.01, 10, 1, 0).is_err());
    }

    // Output shape contract: (n_steps + 1, n_assets); row 0 holds s0.
    #[test]
    fn test_correlated_gbm_shape() {
        let corr = array![[1.0, 0.6], [0.6, 1.0]];
        let paths = correlated_gbm(
            &[100.0, 50.0],
            &[0.05, 0.08],
            &corr,
            &[0.2, 0.3],
            1.0 / 252.0,
            252,
            1,
        )
        .expect("correlated_gbm failed");
        assert_eq!(paths.nrows(), 253);
        assert_eq!(paths.ncols(), 2);
    }

    // Row 0 of the output must equal the starting prices.
    #[test]
    fn test_correlated_gbm_initial_prices() {
        let corr = array![[1.0, 0.5], [0.5, 1.0]];
        let s0 = &[80.0, 120.0];
        let paths = correlated_gbm(s0, &[0.0, 0.0], &corr, &[0.1, 0.15], 0.01, 100, 3)
            .expect("correlated_gbm failed");
        assert!((paths[[0, 0]] - s0[0]).abs() < 1e-12);
        assert!((paths[[0, 1]] - s0[1]).abs() < 1e-12);
    }

    // Multiplicative exp(...) updates keep every asset price positive.
    #[test]
    fn test_correlated_gbm_positive_prices() {
        let corr = array![[1.0, 0.8], [0.8, 1.0]];
        let paths =
            correlated_gbm(&[50.0, 75.0], &[0.1, 0.05], &corr, &[0.25, 0.2], 0.01, 200, 5)
                .expect("correlated_gbm failed");
        for i in 0..paths.nrows() {
            for j in 0..paths.ncols() {
                assert!(paths[[i, j]] > 0.0, "Non-positive price at ({i},{j})");
            }
        }
    }

    // Sample correlation of the simulated log returns should recover the
    // input correlation (0.9) within Monte Carlo noise at 2000 steps.
    #[test]
    fn test_correlated_gbm_correlation_preserved() {
        let n_steps = 2000;
        let corr = array![[1.0, 0.9], [0.9, 1.0]];
        let paths = correlated_gbm(
            &[100.0, 100.0],
            &[0.0, 0.0],
            &corr,
            &[0.2, 0.2],
            1.0 / 252.0,
            n_steps,
            42,
        )
        .expect("correlated_gbm failed");
        let n = paths.nrows() - 1;
        let r0: Vec<f64> = (0..n).map(|t| (paths[[t + 1, 0]] / paths[[t, 0]]).ln()).collect();
        let r1: Vec<f64> = (0..n).map(|t| (paths[[t + 1, 1]] / paths[[t, 1]]).ln()).collect();
        let mean0 = r0.iter().sum::<f64>() / n as f64;
        let mean1 = r1.iter().sum::<f64>() / n as f64;
        let cov: f64 =
            r0.iter().zip(r1.iter()).map(|(a, b)| (a - mean0) * (b - mean1)).sum::<f64>()
                / n as f64;
        let var0 = r0.iter().map(|a| (a - mean0).powi(2)).sum::<f64>() / n as f64;
        let var1 = r1.iter().map(|b| (b - mean1).powi(2)).sum::<f64>() / n as f64;
        let measured_corr = cov / (var0.sqrt() * var1.sqrt());
        assert!(
            (measured_corr - 0.9).abs() < 0.05,
            "Expected correlation ≈ 0.9, got {measured_corr:.4}"
        );
    }

    // s0 has length 1 while mu/sigma have length 2 — must be rejected.
    #[test]
    fn test_correlated_gbm_error_mismatched_inputs() {
        let corr = array![[1.0, 0.5], [0.5, 1.0]];
        assert!(correlated_gbm(
            &[100.0],
            &[0.05, 0.08],
            &corr,
            &[0.2, 0.3],
            0.01,
            10,
            0
        )
        .is_err());
    }

    #[test]
    fn test_order_book_levels() {
        let ob = synthetic_order_book(100.0, 5, 0.01, 42).expect("ob failed");
        assert_eq!(ob.bids.len(), 5);
        assert_eq!(ob.asks.len(), 5);
    }

    #[test]
    fn test_order_book_spread_positive() {
        let ob = synthetic_order_book(200.0, 10, 0.5, 0).expect("ob failed");
        assert!(ob.spread > 0.0, "Spread must be positive");
    }

    // Bids are best-first (descending price), asks best-first (ascending).
    #[test]
    fn test_order_book_bid_ask_sorted() {
        let ob = synthetic_order_book(100.0, 5, 0.01, 7).expect("ob failed");
        for w in ob.bids.windows(2) {
            assert!(w[0].0 >= w[1].0, "Bids not descending");
        }
        for w in ob.asks.windows(2) {
            assert!(w[0].0 <= w[1].0, "Asks not ascending");
        }
    }

    #[test]
    fn test_order_book_best_bid_below_ask() {
        let ob = synthetic_order_book(100.0, 3, 0.01, 1).expect("ob failed");
        let best_bid = ob.bids[0].0;
        let best_ask = ob.asks[0].0;
        assert!(best_bid < best_ask, "Best bid must be below best ask");
    }

    #[test]
    fn test_order_book_error_zero_levels() {
        assert!(synthetic_order_book(100.0, 0, 0.01, 0).is_err());
    }

    // n prices produce n - 1 returns.
    #[test]
    fn test_log_returns_shape() {
        let prices = array![100.0, 110.0, 105.0, 115.0];
        let r = log_returns(&prices).expect("log_returns failed");
        assert_eq!(r.len(), 3);
    }

    // Prices exp(a*i) have constant log return a between steps.
    #[test]
    fn test_log_returns_exponential_growth() {
        let a = 0.1_f64;
        let n = 10_usize;
        let prices = Array1::from_vec((0..n).map(|i| (a * i as f64).exp()).collect());
        let r = log_returns(&prices).expect("log_returns failed");
        for i in 0..r.len() {
            assert!((r[i] - a).abs() < 1e-10, "log_return[{i}] = {} expected {a}", r[i]);
        }
    }

    #[test]
    fn test_log_returns_error_too_short() {
        let prices = array![100.0];
        assert!(log_returns(&prices).is_err());
    }

    #[test]
    fn test_log_returns_error_non_positive() {
        let prices = array![100.0, 0.0, 50.0];
        assert!(log_returns(&prices).is_err());
    }

    // Output length equals input length (leading positions are NaN).
    #[test]
    fn test_rolling_volatility_shape() {
        let prices = array![100.0, 102.0, 101.0, 103.0, 105.0, 104.0];
        let r = log_returns(&prices).expect("lr");
        let vol = rolling_volatility(&r, 3).expect("rvol");
        assert_eq!(vol.len(), r.len());
    }

    // The first window - 1 entries have no full window and must be NaN.
    #[test]
    fn test_rolling_volatility_initial_nan() {
        let prices = array![100.0, 102.0, 101.0, 103.0, 105.0];
        let r = log_returns(&prices).expect("lr");
        let vol = rolling_volatility(&r, 3).expect("rvol");
        assert!(vol[0].is_nan());
        assert!(vol[1].is_nan());
        assert!(!vol[2].is_nan());
    }

    // A constant return series has zero sample standard deviation.
    #[test]
    fn test_rolling_volatility_constant_returns_zero() {
        let r = Array1::from_vec(vec![0.01; 10]);
        let vol = rolling_volatility(&r, 4).expect("rvol");
        for i in 3..10 {
            assert!(vol[i].abs() < 1e-10, "vol[{i}]={}", vol[i]);
        }
    }

    // window must be >= 2 for a sample standard deviation to exist.
    #[test]
    fn test_rolling_volatility_error_window_too_small() {
        let r = Array1::from_vec(vec![0.01; 10]);
        assert!(rolling_volatility(&r, 1).is_err());
    }

    #[test]
    fn test_rolling_sharpe_shape() {
        let prices = array![100.0, 102.0, 101.0, 103.0, 105.0, 104.0];
        let r = log_returns(&prices).expect("lr");
        let sh = rolling_sharpe(&r, 3, 0.0).expect("sharpe");
        assert_eq!(sh.len(), r.len());
    }

    #[test]
    fn test_rolling_sharpe_initial_nan() {
        let r = Array1::from_vec(vec![0.01, -0.01, 0.02, 0.01, -0.005]);
        let sh = rolling_sharpe(&r, 3, 0.0).expect("sharpe");
        assert!(sh[0].is_nan());
        assert!(sh[1].is_nan());
        assert!(!sh[2].is_nan());
    }

    // All-positive returns with zero risk-free rate give positive Sharpe.
    #[test]
    fn test_rolling_sharpe_positive_for_positive_returns() {
        let r = Array1::from_vec(vec![0.01, 0.02, 0.015, 0.012, 0.018]);
        let sh = rolling_sharpe(&r, 3, 0.0).expect("sharpe");
        for i in 2..sh.len() {
            assert!(sh[i] > 0.0, "Expected positive Sharpe at {i}, got {}", sh[i]);
        }
    }
}