use crate::Decimal;
use crate::types::decimal::{decimal_ln, decimal_sqrt};
use crate::types::error::{MMError, MMResult};
#[derive(Debug, Clone, Copy, Default)]
pub struct VolatilityEstimator {
annualization_factor: Option<Decimal>,
}
impl VolatilityEstimator {
#[must_use]
pub fn new() -> Self {
Self {
annualization_factor: None, }
}
#[must_use]
pub fn with_annualization_factor(factor: Decimal) -> Self {
Self {
annualization_factor: Some(factor),
}
}
fn get_annualization_factor(&self) -> MMResult<Decimal> {
if let Some(factor) = self.annualization_factor {
Ok(factor)
} else {
decimal_sqrt(Decimal::from(252))
}
}
pub fn calculate_simple(&self, prices: &[Decimal]) -> MMResult<Decimal> {
if prices.len() < 2 {
return Err(MMError::InvalidMarketState(
"need at least 2 prices to calculate volatility".to_string(),
));
}
let mut log_returns = Vec::with_capacity(prices.len() - 1);
for i in 1..prices.len() {
if prices[i] <= Decimal::ZERO || prices[i - 1] <= Decimal::ZERO {
return Err(MMError::InvalidMarketState(
"prices must be positive".to_string(),
));
}
let ratio = prices[i] / prices[i - 1];
let log_return = decimal_ln(ratio)?;
log_returns.push(log_return);
}
let sum: Decimal = log_returns.iter().sum();
let mean = sum / Decimal::from(log_returns.len());
let squared_deviations: Decimal = log_returns
.iter()
.map(|&r| {
let dev = r - mean;
dev * dev
})
.sum();
let variance = squared_deviations / Decimal::from(log_returns.len() - 1);
let std_dev = decimal_sqrt(variance)?;
let annualization_factor = self.get_annualization_factor()?;
Ok(std_dev * annualization_factor)
}
pub fn calculate_ewma(&self, prices: &[Decimal], lambda: Decimal) -> MMResult<Decimal> {
if prices.len() < 2 {
return Err(MMError::InvalidMarketState(
"need at least 2 prices for EWMA".to_string(),
));
}
if lambda <= Decimal::ZERO || lambda >= Decimal::ONE {
return Err(MMError::InvalidConfiguration(
"lambda must be between 0 and 1".to_string(),
));
}
let mut log_returns = Vec::with_capacity(prices.len() - 1);
for i in 1..prices.len() {
if prices[i] <= Decimal::ZERO || prices[i - 1] <= Decimal::ZERO {
return Err(MMError::InvalidMarketState(
"prices must be positive".to_string(),
));
}
let ratio = prices[i] / prices[i - 1];
let log_return = decimal_ln(ratio)?;
log_returns.push(log_return);
}
let initial_variance = if log_returns.len() >= 5 {
let first_returns = &log_returns[0..5];
let mean: Decimal = first_returns.iter().sum::<Decimal>() / Decimal::from(5);
let squared_devs: Decimal = first_returns
.iter()
.map(|&r| {
let dev = r - mean;
dev * dev
})
.sum();
squared_devs / Decimal::from(4)
} else {
let mean: Decimal =
log_returns.iter().sum::<Decimal>() / Decimal::from(log_returns.len());
let squared_devs: Decimal = log_returns
.iter()
.map(|&r| {
let dev = r - mean;
dev * dev
})
.sum();
squared_devs / Decimal::from(log_returns.len() - 1)
};
let one_minus_lambda = Decimal::ONE - lambda;
let mut variance = initial_variance;
for &log_return in &log_returns {
let squared_return = log_return * log_return;
variance = lambda * variance + one_minus_lambda * squared_return;
}
let std_dev = decimal_sqrt(variance)?;
let annualization_factor = self.get_annualization_factor()?;
Ok(std_dev * annualization_factor)
}
pub fn calculate_parkinson(
&self,
high_prices: &[Decimal],
low_prices: &[Decimal],
) -> MMResult<Decimal> {
if high_prices.len() != low_prices.len() {
return Err(MMError::InvalidMarketState(
"high and low price vectors must have same length".to_string(),
));
}
if high_prices.is_empty() {
return Err(MMError::InvalidMarketState(
"need at least 1 price pair for Parkinson estimator".to_string(),
));
}
let mut sum_squared_ratios = Decimal::ZERO;
for i in 0..high_prices.len() {
if high_prices[i] <= Decimal::ZERO || low_prices[i] <= Decimal::ZERO {
return Err(MMError::InvalidMarketState(
"prices must be positive".to_string(),
));
}
if high_prices[i] < low_prices[i] {
return Err(MMError::InvalidMarketState(
"high price must be >= low price".to_string(),
));
}
let ratio = high_prices[i] / low_prices[i];
let log_ratio = decimal_ln(ratio)?;
sum_squared_ratios += log_ratio * log_ratio;
}
let ln_2 = decimal_ln(Decimal::from(2))?;
let n = Decimal::from(high_prices.len());
let four = Decimal::from(4);
let variance = sum_squared_ratios / (four * n * ln_2);
let std_dev = decimal_sqrt(variance)?;
let annualization_factor = self.get_annualization_factor()?;
Ok(std_dev * annualization_factor)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dec;
#[test]
fn test_volatility_estimator_new() {
let estimator = VolatilityEstimator::new();
assert!(estimator.annualization_factor.is_none());
}
#[test]
fn test_volatility_estimator_with_custom_factor() {
let factor = dec!(100);
let estimator = VolatilityEstimator::with_annualization_factor(factor);
assert_eq!(estimator.annualization_factor, Some(factor));
}
#[test]
fn test_calculate_simple_valid() {
let estimator = VolatilityEstimator::new();
let prices = vec![
dec!(100.0),
dec!(101.0),
dec!(99.5),
dec!(100.5),
dec!(102.0),
];
let vol = estimator.calculate_simple(&prices).unwrap();
assert!(vol > Decimal::ZERO);
assert!(vol < Decimal::ONE);
}
#[test]
fn test_calculate_simple_insufficient_data() {
let estimator = VolatilityEstimator::new();
let prices = vec![dec!(100.0)];
let result = estimator.calculate_simple(&prices);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
MMError::InvalidMarketState(_)
));
}
#[test]
fn test_calculate_simple_negative_price() {
let estimator = VolatilityEstimator::new();
let prices = vec![dec!(100.0), dec!(-50.0), dec!(75.0)];
let result = estimator.calculate_simple(&prices);
assert!(result.is_err());
}
#[test]
fn test_calculate_ewma_valid() {
let estimator = VolatilityEstimator::new();
let prices = vec![
dec!(100.0),
dec!(101.5),
dec!(99.0),
dec!(102.0),
dec!(101.0),
];
let vol = estimator.calculate_ewma(&prices, dec!(0.94)).unwrap();
assert!(vol > Decimal::ZERO);
assert!(vol < Decimal::ONE);
}
#[test]
fn test_calculate_ewma_invalid_lambda() {
let estimator = VolatilityEstimator::new();
let prices = vec![dec!(100.0), dec!(101.0), dec!(102.0)];
let result = estimator.calculate_ewma(&prices, Decimal::ZERO);
assert!(result.is_err());
let result = estimator.calculate_ewma(&prices, Decimal::ONE);
assert!(result.is_err());
let result = estimator.calculate_ewma(&prices, dec!(1.5));
assert!(result.is_err());
}
#[test]
fn test_calculate_parkinson_valid() {
let estimator = VolatilityEstimator::new();
let highs = vec![dec!(102.0), dec!(103.0), dec!(101.5)];
let lows = vec![dec!(99.0), dec!(100.0), dec!(98.5)];
let vol = estimator.calculate_parkinson(&highs, &lows).unwrap();
assert!(vol > Decimal::ZERO);
}
#[test]
fn test_calculate_parkinson_mismatched_lengths() {
let estimator = VolatilityEstimator::new();
let highs = vec![dec!(102.0), dec!(103.0)];
let lows = vec![dec!(99.0)];
let result = estimator.calculate_parkinson(&highs, &lows);
assert!(result.is_err());
}
#[test]
fn test_calculate_parkinson_high_less_than_low() {
let estimator = VolatilityEstimator::new();
let highs = vec![dec!(100.0), dec!(99.0)];
let lows = vec![dec!(101.0), dec!(100.0)];
let result = estimator.calculate_parkinson(&highs, &lows);
assert!(result.is_err());
}
#[test]
fn test_ewma_vs_simple() {
let estimator = VolatilityEstimator::new();
let prices = vec![
dec!(100.0),
dec!(100.5),
dec!(101.0),
dec!(105.0),
dec!(95.0),
];
let simple_vol = estimator.calculate_simple(&prices).unwrap();
let ewma_vol = estimator.calculate_ewma(&prices, dec!(0.94)).unwrap();
assert!(simple_vol > Decimal::ZERO);
assert!(ewma_vol > Decimal::ZERO);
}
#[test]
fn test_calculate_ewma_insufficient_data() {
let estimator = VolatilityEstimator::new();
let prices = vec![dec!(100.0)];
let result = estimator.calculate_ewma(&prices, dec!(0.94));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
MMError::InvalidMarketState(_)
));
}
#[test]
fn test_calculate_ewma_negative_price() {
let estimator = VolatilityEstimator::new();
let prices = vec![dec!(100.0), dec!(-50.0), dec!(75.0)];
let result = estimator.calculate_ewma(&prices, dec!(0.94));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
MMError::InvalidMarketState(_)
));
}
#[test]
fn test_calculate_ewma_with_many_prices() {
let estimator = VolatilityEstimator::new();
let prices = vec![
dec!(100.0),
dec!(101.0),
dec!(102.0),
dec!(101.5),
dec!(103.0),
dec!(104.0),
];
let vol = estimator.calculate_ewma(&prices, dec!(0.94)).unwrap();
assert!(vol > Decimal::ZERO);
}
#[test]
fn test_calculate_parkinson_empty_vectors() {
let estimator = VolatilityEstimator::new();
let highs: Vec<Decimal> = vec![];
let lows: Vec<Decimal> = vec![];
let result = estimator.calculate_parkinson(&highs, &lows);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
MMError::InvalidMarketState(_)
));
}
#[test]
fn test_calculate_parkinson_negative_price() {
let estimator = VolatilityEstimator::new();
let highs = vec![dec!(102.0), dec!(-103.0)];
let lows = vec![dec!(99.0), dec!(-105.0)];
let result = estimator.calculate_parkinson(&highs, &lows);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
MMError::InvalidMarketState(_)
));
}
#[test]
fn test_calculate_parkinson_zero_price() {
let estimator = VolatilityEstimator::new();
let highs = vec![dec!(102.0), Decimal::ZERO];
let lows = vec![dec!(99.0), dec!(98.0)];
let result = estimator.calculate_parkinson(&highs, &lows);
assert!(result.is_err());
}
}