use crate::compat::Instant;
use rust_decimal::Decimal;
use rust_decimal::MathematicalOps;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use crate::error::CorpFinanceError;
use crate::types::*;
use crate::CorpFinanceResult;
use super::returns::ReturnFrequency;
/// Input to [`calculate_risk_metrics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskMetricsInput {
/// Periodic returns as decimal fractions (they are compounded as `1 + r`
/// when computing drawdown). At least three observations are required.
pub returns: Vec<Decimal>,
/// Sampling frequency of `returns`; its `periods_per_year()` value is used
/// to annualise volatility.
pub frequency: ReturnFrequency,
/// Confidence level for VaR / CVaR; must lie strictly between 0 and 1.
pub confidence_level: Rate,
/// Optional portfolio value, omitted from serialised output when `None`.
/// NOTE(review): `calculate_risk_metrics` in this file does not read this
/// field — the reported figures stay in return units.
#[serde(skip_serializing_if = "Option::is_none")]
pub portfolio_value: Option<Money>,
}
/// Risk statistics produced by [`calculate_risk_metrics`].
///
/// VaR and CVaR are per-period figures in return units, sign-flipped so that
/// a positive number represents a loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskMetricsOutput {
/// Parametric VaR: `z * std_dev - mean`, with `z` taken from the z-score
/// lookup for the requested confidence level.
pub var_parametric: Decimal,
/// Historical VaR: the negated return at index `floor((1 - confidence) * n)`
/// of the ascending-sorted sample.
pub var_historical: Decimal,
/// Conditional VaR: the negated mean of all returns at or below the
/// historical VaR threshold.
pub cvar: Decimal,
/// Largest peak-to-trough decline of the compounded return path, as a
/// fraction of the peak.
pub max_drawdown: Rate,
/// Number of periods between the peak and the trough of the maximum
/// drawdown.
pub max_drawdown_duration: u32,
/// Small-sample adjusted skewness estimate; zero when the standard
/// deviation is zero.
pub skewness: Decimal,
/// Small-sample adjusted excess kurtosis estimate; zero when there are
/// fewer than four observations or the standard deviation is zero.
pub kurtosis: Decimal,
/// Sample standard deviation scaled by the square root of the number of
/// periods per year.
pub annualised_volatility: Rate,
}
/// Computes VaR (parametric and historical), CVaR, maximum drawdown,
/// skewness, excess kurtosis and annualised volatility for a return series.
///
/// All VaR / CVaR figures are per-period and expressed in return units, with
/// the sign flipped so that a positive value is a loss.
///
/// # Errors
///
/// * [`CorpFinanceError::InsufficientData`] when fewer than three returns are
///   supplied.
/// * [`CorpFinanceError::InvalidInput`] when `confidence_level` is not
///   strictly between 0 and 1.
pub fn calculate_risk_metrics(
    input: &RiskMetricsInput,
) -> CorpFinanceResult<ComputationOutput<RiskMetricsOutput>> {
    let start = Instant::now();
    let warnings: Vec<String> = Vec::new();

    // --- Validation -------------------------------------------------------
    let n = input.returns.len();
    if n < 3 {
        return Err(CorpFinanceError::InsufficientData(
            "At least 3 return observations required for risk metrics".into(),
        ));
    }
    if input.confidence_level <= Decimal::ZERO || input.confidence_level >= Decimal::ONE {
        return Err(CorpFinanceError::InvalidInput {
            field: "confidence_level".into(),
            reason: "Confidence level must be between 0 and 1 (exclusive)".into(),
        });
    }

    // --- First and second moments ----------------------------------------
    let n_dec = Decimal::from(n as i64);
    let n1 = Decimal::from((n - 1) as i64);
    let periods = input.frequency.periods_per_year();

    let mean: Decimal = input.returns.iter().sum::<Decimal>() / n_dec;
    // Sample variance (n - 1 denominator).
    let variance = {
        let sum_sq: Decimal = input.returns.iter().map(|r| (r - mean) * (r - mean)).sum();
        sum_sq / n1
    };
    let std_dev = sqrt_decimal(variance);
    let annualised_volatility = std_dev * sqrt_decimal(periods);

    // --- Parametric VaR: loss at the z-quantile of a normal fit ----------
    let z_score = z_score_for_confidence(input.confidence_level);
    let var_parametric = -(mean - z_score * std_dev);

    // --- Historical VaR: empirical quantile of the sorted sample ---------
    let mut sorted = input.returns.clone();
    sorted.sort();
    // floor((1 - c) * n) is a non-negative integer, so its decimal string
    // parses as usize; the fallback and the clamp keep the index in bounds.
    let var_index = ((Decimal::ONE - input.confidence_level) * n_dec)
        .floor()
        .to_string()
        .parse::<usize>()
        .unwrap_or(0)
        .min(n - 1);
    let threshold = sorted[var_index];
    let var_historical = -threshold;

    // --- CVaR: average of the returns at or below the VaR threshold ------
    let tail: Vec<&Decimal> = sorted.iter().filter(|r| **r <= threshold).collect();
    let cvar = if tail.is_empty() {
        var_historical
    } else {
        let tail_sum: Decimal = tail.iter().copied().sum();
        -(tail_sum / Decimal::from(tail.len() as i64))
    };

    let (max_drawdown, max_drawdown_duration) = max_drawdown_with_duration(&input.returns);

    // --- Adjusted sample skewness ----------------------------------------
    //   G1 = n / ((n-1)(n-2)) * sum((x - mean)^3) / s^3
    // n >= 3 is guaranteed by the validation above.
    let skewness = if std_dev.is_zero() {
        Decimal::ZERO
    } else {
        let m3: Decimal = input.returns.iter().map(|r| (r - mean).powd(dec!(3))).sum();
        let adjustment = n_dec / (n1 * Decimal::from((n - 2) as i64));
        let sigma3 = std_dev * std_dev * std_dev;
        if sigma3.is_zero() {
            Decimal::ZERO
        } else {
            adjustment * m3 / sigma3
        }
    };

    // --- Adjusted sample excess kurtosis ---------------------------------
    //   G2 = n(n+1) / ((n-1)(n-2)(n-3)) * sum((x - mean)^4) / s^4
    //        - 3(n-1)^2 / ((n-2)(n-3))
    // `m4` is already the SUM of fourth-power deviations, so it must not be
    // multiplied by n again.
    let kurtosis = if n < 4 || std_dev.is_zero() {
        Decimal::ZERO
    } else {
        let m4: Decimal = input.returns.iter().map(|r| (r - mean).powd(dec!(4))).sum();
        let sigma4 = variance * variance;
        if sigma4.is_zero() {
            Decimal::ZERO
        } else {
            let n2 = Decimal::from((n - 2) as i64);
            let n3 = Decimal::from((n - 3) as i64);
            let factor1 = n_dec * (n_dec + Decimal::ONE) / (n1 * n2 * n3);
            let factor2 = dec!(3) * n1 * n1 / (n2 * n3);
            factor1 * (m4 / sigma4) - factor2
        }
    };

    let output = RiskMetricsOutput {
        var_parametric,
        var_historical,
        cvar,
        max_drawdown,
        max_drawdown_duration,
        skewness,
        kurtosis,
        annualised_volatility,
    };

    let elapsed = start.elapsed().as_micros() as u64;
    Ok(with_metadata(
        "Portfolio Risk Metrics (VaR, CVaR, Drawdown, Skewness, Kurtosis)",
        &serde_json::json!({
            "observations": n,
            "confidence_level": input.confidence_level.to_string(),
            "frequency": format!("{:?}", input.frequency),
        }),
        warnings,
        elapsed,
        output,
    ))
}
/// Returns the maximum drawdown of the compounded return path together with
/// the number of periods from the drawdown's peak to its trough.
///
/// The path starts at a wealth of 1 before the first return. Positions are
/// counted in periods since inception (0 = before any return, `i + 1` = after
/// `returns[i]`), so a drawdown whose peak is the starting capital is measured
/// from inception rather than from the end of the first period.
///
/// Returns `(0, 0)` when the path never falls below a previous peak.
fn max_drawdown_with_duration(returns: &[Decimal]) -> (Rate, u32) {
    let mut wealth = Decimal::ONE;
    let mut peak = Decimal::ONE;
    // Position of the running peak; 0 means the starting capital.
    let mut peak_pos: usize = 0;
    let mut max_dd = Decimal::ZERO;
    let mut max_dd_periods: usize = 0;

    for (i, r) in returns.iter().enumerate() {
        let pos = i + 1;
        wealth *= Decimal::ONE + r;

        if wealth > peak {
            peak = wealth;
            peak_pos = pos;
        }

        // `peak` starts at 1 and only ever increases; the guard is purely
        // defensive against division by zero.
        if !peak.is_zero() {
            let dd = (peak - wealth) / peak;
            // Strict comparison keeps the earliest occurrence of the maximum.
            if dd > max_dd {
                max_dd = dd;
                max_dd_periods = pos - peak_pos;
            }
        }
    }

    // Saturate instead of silently truncating on (unrealistically) long series.
    let duration = u32::try_from(max_dd_periods).unwrap_or(u32::MAX);
    (max_dd, duration)
}
fn z_score_for_confidence(confidence: Decimal) -> Decimal {
if confidence == dec!(0.90) {
return dec!(1.282);
}
if confidence == dec!(0.95) {
return dec!(1.645);
}
if confidence == dec!(0.975) {
return dec!(1.960);
}
if confidence == dec!(0.99) {
return dec!(2.326);
}
if confidence == dec!(0.995) {
return dec!(2.576);
}
if confidence >= dec!(0.95) && confidence <= dec!(0.99) {
let t = (confidence - dec!(0.95)) / dec!(0.04);
return dec!(1.645) + t * (dec!(2.326) - dec!(1.645));
}
if confidence >= dec!(0.90) && confidence < dec!(0.95) {
let t = (confidence - dec!(0.90)) / dec!(0.05);
return dec!(1.282) + t * (dec!(1.645) - dec!(1.282));
}
dec!(1.645)
}
/// Square root that never fails: non-positive inputs, and any input for which
/// `Decimal::sqrt` yields `None`, produce zero.
fn sqrt_decimal(val: Decimal) -> Decimal {
    if val > Decimal::ZERO {
        match val.sqrt() {
            Some(root) => root,
            None => Decimal::ZERO,
        }
    } else {
        Decimal::ZERO
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Twelve monthly returns mixing gains and losses, shared across tests.
    fn sample_returns() -> Vec<Decimal> {
        vec![
            dec!(0.05),
            dec!(-0.02),
            dec!(0.03),
            dec!(0.01),
            dec!(-0.01),
            dec!(0.04),
            dec!(0.02),
            dec!(-0.03),
            dec!(0.06),
            dec!(0.01),
            dec!(-0.02),
            dec!(0.03),
        ]
    }

    /// Builds a monthly-frequency input with no portfolio value attached.
    fn monthly_input(returns: Vec<Decimal>, confidence_level: Decimal) -> RiskMetricsInput {
        RiskMetricsInput {
            returns,
            frequency: ReturnFrequency::Monthly,
            confidence_level,
            portfolio_value: None,
        }
    }

    #[test]
    fn test_basic_risk_metrics() {
        let computed = calculate_risk_metrics(&monthly_input(sample_returns(), dec!(0.95))).unwrap();
        let metrics = &computed.result;
        assert!(metrics.var_parametric > Decimal::ZERO);
        assert!(metrics.var_historical > Decimal::ZERO);
        assert!(metrics.cvar >= metrics.var_historical);
        assert!(metrics.annualised_volatility > Decimal::ZERO);
    }

    #[test]
    fn test_cvar_gte_var() {
        let computed = calculate_risk_metrics(&monthly_input(sample_returns(), dec!(0.95))).unwrap();
        assert!(computed.result.cvar >= computed.result.var_historical);
    }

    #[test]
    fn test_max_drawdown() {
        let path = vec![dec!(0.10), dec!(-0.20), dec!(0.05), dec!(-0.15)];
        let computed = calculate_risk_metrics(&monthly_input(path, dec!(0.95))).unwrap();
        assert!(computed.result.max_drawdown > dec!(0.20));
    }

    #[test]
    fn test_higher_confidence_higher_var() {
        let series = sample_returns();
        let at_95 = calculate_risk_metrics(&monthly_input(series.clone(), dec!(0.95))).unwrap();
        let at_99 = calculate_risk_metrics(&monthly_input(series, dec!(0.99))).unwrap();
        assert!(at_99.result.var_parametric > at_95.result.var_parametric);
    }

    #[test]
    fn test_insufficient_data() {
        let too_short = monthly_input(vec![dec!(0.05), dec!(0.03)], dec!(0.95));
        assert!(calculate_risk_metrics(&too_short).is_err());
    }

    #[test]
    fn test_invalid_confidence() {
        let out_of_range = monthly_input(sample_returns(), dec!(1.5));
        assert!(calculate_risk_metrics(&out_of_range).is_err());
    }

    #[test]
    fn test_z_scores() {
        assert_eq!(z_score_for_confidence(dec!(0.95)), dec!(1.645));
        assert_eq!(z_score_for_confidence(dec!(0.99)), dec!(2.326));
    }
}