use crate::models::Greeks;
use statrs::distribution::{ContinuousCDF, Normal};
/// Computes the Black-Scholes Greeks of a European option on an underlying
/// that pays no dividends.
///
/// # Parameters
///
/// * `is_call` - `true` for a call option, `false` for a put.
/// * `spot_price` - current price of the underlying; must be finite and `> 0`.
/// * `strike_price` - strike of the option; must be finite and `> 0`.
/// * `time_to_expiry` - time until expiry, expressed in the same time unit as
///   the rate and volatility (years, given that theta is divided by 365 to
///   obtain a per-day figure); must be finite and `> 0`.
/// * `risk_free_rate` - continuously compounded rate as a fraction
///   (`0.05` for 5%); must be finite, may be negative.
/// * `implied_volatility` - volatility as a fraction (`0.2` for 20%); must be
///   finite and `> 0`.
///
/// # Returns
///
/// `Some(Greeks)` with:
///
/// * `delta` - sensitivity to the spot price,
/// * `gamma` - sensitivity of delta to the spot price,
/// * `theta` - time decay, divided by 365 (per calendar day),
/// * `vega` - sensitivity to volatility, divided by 100 (per one volatility point),
/// * `rho` - sensitivity to the rate, divided by 100 (per one rate point).
///
/// `None` when any input is NaN, infinite or outside the ranges above, when
/// the standard normal distribution cannot be constructed, or when any
/// computed Greek is not a finite number (numeric overflow at extreme inputs).
pub fn calculate_greeks(
    is_call: bool,
    spot_price: f64,
    strike_price: f64,
    time_to_expiry: f64,
    risk_free_rate: f64,
    implied_volatility: f64,
) -> Option<Greeks> {
    // `x <= 0.0` evaluates to false for NaN, so a "reject if non-positive" test
    // lets NaN through. Positivity is therefore checked in the affirmative and
    // paired with a finiteness check to keep NaN and infinities out.
    let is_positive_finite = |x: f64| x.is_finite() && x > 0.0;
    let inputs_valid = is_positive_finite(spot_price)
        && is_positive_finite(strike_price)
        && is_positive_finite(time_to_expiry)
        && is_positive_finite(implied_volatility)
        && risk_free_rate.is_finite();
    if !inputs_valid {
        return None;
    }

    let sqrt_t = time_to_expiry.sqrt();
    let d1 = (spot_price / strike_price).ln()
        + (risk_free_rate + implied_volatility.powi(2) / 2.0) * time_to_expiry;
    let d1 = d1 / (implied_volatility * sqrt_t);
    let d2 = d1 - implied_volatility * sqrt_t;

    let normal = Normal::new(0.0, 1.0).ok()?;
    let n_d1 = normal.cdf(d1);
    let n_d2 = normal.cdf(d2);
    // Standard normal density evaluated at d1.
    let phi_d1 = (-d1.powi(2) / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
    // Discount factor exp(-rT), shared by theta and rho.
    let discount = (-risk_free_rate * time_to_expiry).exp();

    let delta = if is_call { n_d1 } else { n_d1 - 1.0 };
    let gamma = phi_d1 / (spot_price * implied_volatility * sqrt_t);

    // The volatility-driven decay term is common to calls and puts; only the
    // interest-carry term differs.
    let theta_part1 = -(spot_price * phi_d1 * implied_volatility) / (2.0 * sqrt_t);
    let theta = if is_call {
        theta_part1 - risk_free_rate * strike_price * discount * n_d2
    } else {
        theta_part1 + risk_free_rate * strike_price * discount * (1.0 - n_d2)
    };
    // Annual theta scaled to a per-day figure.
    let theta = theta / 365.0;

    // Vega and rho are scaled by 1/100 so they refer to a one-point move.
    let vega = spot_price * phi_d1 * sqrt_t / 100.0;
    let rho = if is_call {
        strike_price * time_to_expiry * discount * n_d2 / 100.0
    } else {
        -strike_price * time_to_expiry * discount * (1.0 - n_d2) / 100.0
    };

    // Extreme but finite inputs can still overflow intermediate terms
    // (for example `exp(-rT)` with a very negative rate); never hand NaN or
    // infinity back to the caller as if it were a valid sensitivity.
    if ![delta, gamma, theta, vega, rho]
        .iter()
        .all(|value| value.is_finite())
    {
        return None;
    }

    Some(Greeks {
        delta,
        gamma,
        theta,
        vega,
        rho,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Risk-free rate used by every case that does not vary the rate.
    const RATE: f64 = 0.05;
    /// Volatility used by every case that does not vary the volatility.
    const VOL: f64 = 0.2;

    /// Greeks of a call priced at `RATE` / `VOL`; panics if the inputs are rejected.
    fn call(spot: f64, strike: f64, years: f64) -> Greeks {
        calculate_greeks(true, spot, strike, years, RATE, VOL).unwrap()
    }

    /// Greeks of a put priced at `RATE` / `VOL`; panics if the inputs are rejected.
    fn put(spot: f64, strike: f64, years: f64) -> Greeks {
        calculate_greeks(false, spot, strike, years, RATE, VOL).unwrap()
    }

    /// Asserts that a call with the given inputs (at `RATE`) yields `None`.
    fn assert_rejected(spot: f64, strike: f64, years: f64, vol: f64) {
        let outcome = calculate_greeks(true, spot, strike, years, RATE, vol);
        assert!(outcome.is_none());
    }

    // --- Basic sanity of an at-the-money option -------------------------------

    #[test]
    fn test_call_greeks() {
        let atm = call(100.0, 100.0, 1.0);
        assert!(0.4 < atm.delta && atm.delta < 0.7);
        assert!(atm.gamma > 0.0);
        assert!(atm.vega > 0.0);
    }

    #[test]
    fn test_put_greeks() {
        let atm = put(100.0, 100.0, 1.0);
        assert!(-0.7 < atm.delta && atm.delta < 0.0);
        assert!(atm.gamma > 0.0);
    }

    // --- Input validation -----------------------------------------------------

    #[test]
    fn test_invalid_inputs() {
        // Zero spot, then negative time to expiry.
        assert_rejected(0.0, 100.0, 1.0, VOL);
        assert_rejected(100.0, 100.0, -1.0, VOL);
    }

    #[test]
    fn test_zero_volatility_returns_none() {
        assert_rejected(100.0, 100.0, 1.0, 0.0);
    }

    #[test]
    fn test_zero_strike_returns_none() {
        assert_rejected(100.0, 0.0, 1.0, VOL);
    }

    #[test]
    fn test_zero_time_returns_none() {
        assert_rejected(100.0, 100.0, 0.0, VOL);
    }

    #[test]
    fn test_negative_spot_price_returns_none() {
        assert_rejected(-100.0, 100.0, 1.0, VOL);
    }

    #[test]
    fn test_negative_strike_returns_none() {
        assert_rejected(100.0, -100.0, 1.0, VOL);
    }

    #[test]
    fn test_negative_volatility_returns_none() {
        assert_rejected(100.0, 100.0, 1.0, -0.2);
    }

    // --- Delta ----------------------------------------------------------------

    #[test]
    fn test_deep_itm_call_delta() {
        let delta = call(150.0, 100.0, 0.5).delta;
        assert!(delta > 0.95, "Deep ITM call delta: {}", delta);
    }

    #[test]
    fn test_deep_otm_call_delta() {
        let delta = call(50.0, 100.0, 0.5).delta;
        assert!(delta < 0.05, "Deep OTM call delta: {}", delta);
    }

    #[test]
    fn test_put_call_parity_delta() {
        let call_delta = call(100.0, 100.0, 1.0).delta;
        let put_delta = put(100.0, 100.0, 1.0).delta;
        assert_relative_eq!(call_delta - put_delta, 1.0, epsilon = 0.01);
    }

    #[test]
    fn test_put_delta_less_than_call_delta() {
        let call_delta = call(100.0, 100.0, 1.0).delta;
        let put_delta = put(100.0, 100.0, 1.0).delta;
        assert!(call_delta > put_delta);
    }

    #[test]
    fn test_call_itm_delta_increases_with_spot() {
        let at_100 = call(100.0, 100.0, 1.0).delta;
        let at_110 = call(110.0, 100.0, 1.0).delta;
        let at_120 = call(120.0, 100.0, 1.0).delta;
        assert!(at_110 > at_100);
        assert!(at_120 > at_110);
    }

    // --- Gamma ----------------------------------------------------------------

    #[test]
    fn test_gamma_same_for_call_put() {
        let call_gamma = call(100.0, 100.0, 1.0).gamma;
        let put_gamma = put(100.0, 100.0, 1.0).gamma;
        assert_relative_eq!(call_gamma, put_gamma, epsilon = 0.0001);
    }

    #[test]
    fn test_short_expiry_high_gamma() {
        let near = call(100.0, 100.0, 0.1).gamma;
        let far = call(100.0, 100.0, 1.0).gamma;
        assert!(near > far, "Short gamma: {}, Long gamma: {}", near, far);
    }

    #[test]
    fn test_very_short_expiry() {
        // One calendar day to expiry must still be accepted.
        let one_day = calculate_greeks(true, 100.0, 100.0, 1.0 / 365.0, RATE, VOL);
        assert!(one_day.is_some());
        assert!(one_day.unwrap().gamma > 0.0);
    }

    #[test]
    fn test_very_long_expiry() {
        let near = call(100.0, 100.0, 0.1).gamma;
        let far = call(100.0, 100.0, 5.0).gamma;
        assert!(far < near);
    }

    #[test]
    fn test_atm_has_highest_gamma() {
        let at_the_money = call(100.0, 100.0, 0.5).gamma;
        let in_the_money = call(100.0, 90.0, 0.5).gamma;
        let out_of_the_money = call(100.0, 110.0, 0.5).gamma;
        assert!(at_the_money > in_the_money);
        assert!(at_the_money > out_of_the_money);
    }

    // --- Vega -----------------------------------------------------------------

    #[test]
    fn test_vega_same_for_call_put() {
        let call_vega = call(100.0, 100.0, 1.0).vega;
        let put_vega = put(100.0, 100.0, 1.0).vega;
        assert_relative_eq!(call_vega, put_vega, epsilon = 0.0001);
    }

    #[test]
    fn test_high_volatility() {
        let calm = calculate_greeks(true, 100.0, 100.0, 1.0, RATE, 0.2).unwrap();
        let turbulent = calculate_greeks(true, 100.0, 100.0, 1.0, RATE, 0.8).unwrap();
        assert!(calm.vega > 0.0);
        assert!(turbulent.vega > 0.0);
    }

    #[test]
    fn test_vega_highest_for_atm() {
        let at_the_money = call(100.0, 100.0, 1.0).vega;
        let in_the_money = call(100.0, 80.0, 1.0).vega;
        let out_of_the_money = call(100.0, 120.0, 1.0).vega;
        assert!(at_the_money > in_the_money);
        assert!(at_the_money > out_of_the_money);
    }

    // --- Theta ----------------------------------------------------------------

    #[test]
    fn test_theta_is_negative() {
        let theta = call(100.0, 100.0, 0.5).theta;
        assert!(theta < 0.0, "Theta should be negative: {}", theta);
    }

    #[test]
    fn test_theta_decay_over_time() {
        let far = call(100.0, 100.0, 1.0).theta;
        let near = call(100.0, 100.0, 0.1).theta;
        // Passes if the near-dated theta is lower, or larger in magnitude.
        assert!(near < far || near.abs() > far.abs());
    }

    // --- Rho ------------------------------------------------------------------

    #[test]
    fn test_rho_positive_for_calls() {
        assert!(call(100.0, 100.0, 1.0).rho > 0.0);
    }

    #[test]
    fn test_rho_negative_for_puts() {
        assert!(put(100.0, 100.0, 1.0).rho < 0.0);
    }

    #[test]
    fn test_higher_interest_rates_increase_call_rho() {
        let rho_at = |rate: f64| {
            calculate_greeks(true, 100.0, 100.0, 1.0, rate, VOL)
                .unwrap()
                .rho
        };
        let low_rate = rho_at(0.01);
        let high_rate = rho_at(0.10);
        assert!(high_rate > low_rate);
    }

    // --- Determinism ----------------------------------------------------------

    #[test]
    fn test_consistency_across_multiple_calls() {
        let first = call(100.0, 100.0, 0.5);
        let second = call(100.0, 100.0, 0.5);
        assert_eq!(first.delta, second.delta);
        assert_eq!(first.gamma, second.gamma);
        assert_eq!(first.theta, second.theta);
        assert_eq!(first.vega, second.vega);
        assert_eq!(first.rho, second.rho);
    }
}