use crate::traits::FloatExt;
/// Side of the swap from which a valuation is reported.
///
/// The direction does not change either leg's present value; it only decides
/// the sign of the net figure:
/// `ReceiveEquity` nets equity minus funding, `PayEquity` nets funding minus
/// equity.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TrsDirection {
    /// Receive the equity total return and pay funding (the default).
    #[default]
    ReceiveEquity,
    /// Pay the equity total return and receive funding.
    PayEquity,
}
/// One payment period of a total return swap schedule.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrsPeriod<T>
where
    T: FloatExt,
{
    /// Time at which the period ends. The valuation passes this value to the
    /// discount-factor function and uses it as the horizon of the
    /// total-return forward.
    /// NOTE(review): presumably a year fraction from the valuation date —
    /// confirm with callers.
    pub end_time: T,
    /// Accrual factor multiplying `funding_rate` (plus the swap spread) in
    /// the funding cashflow of this period.
    pub accrual: T,
    /// Funding rate for this period, before the swap spread is added.
    pub funding_rate: T,
}
/// Inputs describing a total return swap: an equity leg paying the period
/// return of a total-return forward against a funding leg paying
/// `funding_rate + spread` on the notional.
#[derive(Clone, Debug)]
pub struct TotalReturnSwap<T>
where
    T: FloatExt,
{
    /// Notional that scales the cashflows of both legs.
    pub notional: T,
    /// Starting level of the underlying; the first period's equity return is
    /// measured against it.
    pub spot: T,
    /// Continuously compounded growth rate of the total-return forward:
    /// the forward at time `t` is `spot * exp(equity_drift_rate * t)`.
    pub equity_drift_rate: T,
    /// Payment periods, processed in the order stored. Each period's equity
    /// return is measured from the forward at the previous period's end.
    pub schedule: Vec<TrsPeriod<T>>,
    /// Spread added to every period's funding rate.
    pub spread: T,
    /// Side from which the net present value is reported.
    pub direction: TrsDirection,
}
/// Result of valuing a [`TotalReturnSwap`].
#[derive(Clone, Debug)]
pub struct TrsValuation<T>
where
    T: FloatExt,
{
    /// Sum of discounted equity cashflows (not signed by direction).
    pub equity_leg_pv: T,
    /// Sum of discounted funding cashflows, spread included (not signed by
    /// direction).
    pub funding_leg_pv: T,
    /// Difference of the two leg values, signed by the swap's direction.
    pub net_pv: T,
    /// Spread at which both legs have equal present value; NaN when the
    /// discounted accrual sum is zero.
    pub fair_spread: T,
    /// Notional times the sum of `discount factor * accrual` over the
    /// schedule, i.e. the funding-leg PV per unit of spread.
    pub spread_annuity: T,
    /// Undiscounted equity cashflow of each period, in schedule order.
    pub equity_cashflows: Vec<T>,
    /// Undiscounted funding cashflow of each period, in schedule order.
    pub funding_cashflows: Vec<T>,
}
impl<T: FloatExt> TotalReturnSwap<T> {
    /// Total-return forward level at time `t`:
    /// `spot * exp(equity_drift_rate * t)`.
    pub fn total_return_forward(&self, t: T) -> T {
        self.spot * (self.equity_drift_rate * t).exp()
    }

    /// Values both legs of the swap against the discount-factor function `df`.
    ///
    /// For every schedule period, in order:
    /// * the equity cashflow is `notional * (F(end) / F(prev) - 1)`, where `F`
    ///   is [`Self::total_return_forward`] and the first `F(prev)` is `spot`;
    /// * the funding cashflow is
    ///   `notional * accrual * (funding_rate + spread)`.
    ///
    /// Both cashflows are discounted with `df(end_time)`. `df` is evaluated
    /// exactly once per period, so an expensive curve lookup is not repeated.
    ///
    /// # Returns
    ///
    /// A [`TrsValuation`] holding the two leg values, the direction-signed
    /// net value, the undiscounted cashflows, the spread annuity and the fair
    /// spread. The fair spread is computed per unit of notional, so it does
    /// not depend on `notional` (a zero notional still yields the schedule's
    /// fair spread). It is NaN when the sum of `df * accrual` is zero, which
    /// includes an empty schedule.
    pub fn value<F: Fn(T) -> T>(&self, df: F) -> TrsValuation<T> {
        let zero = T::zero();
        let one = T::one();
        let n = self.schedule.len();

        let mut prev_fwd = self.spot;
        let mut equity_pv = zero;
        let mut funding_pv = zero;
        let mut annuity = zero;
        // Per-unit-notional accumulators that feed only the fair spread.
        let mut unit_equity_pv = zero;
        let mut unit_funding_pv = zero;
        let mut equity_cf = Vec::with_capacity(n);
        let mut funding_cf = Vec::with_capacity(n);

        for period in &self.schedule {
            let fwd_now = self.total_return_forward(period.end_time);
            // Simple return of the total-return forward over this period.
            let r_eq = fwd_now / prev_fwd - one;
            let cf_eq = self.notional * r_eq;
            let cf_fund = self.notional * period.accrual * (period.funding_rate + self.spread);
            let disc = df(period.end_time);

            equity_pv += disc * cf_eq;
            funding_pv += disc * cf_fund;
            annuity += disc * period.accrual;
            unit_equity_pv += disc * r_eq;
            // Funding leg with the spread stripped out.
            unit_funding_pv += disc * period.accrual * period.funding_rate;

            equity_cf.push(cf_eq);
            funding_cf.push(cf_fund);
            prev_fwd = fwd_now;
        }

        // Solve equity = funding(no spread) + spread * annuity for the spread.
        let fair_spread = if annuity != zero {
            (unit_equity_pv - unit_funding_pv) / annuity
        } else {
            T::nan()
        };

        let net_pv = match self.direction {
            TrsDirection::ReceiveEquity => equity_pv - funding_pv,
            TrsDirection::PayEquity => funding_pv - equity_pv,
        };

        TrsValuation {
            equity_leg_pv: equity_pv,
            funding_leg_pv: funding_pv,
            net_pv,
            fair_spread,
            spread_annuity: annuity * self.notional,
            equity_cashflows: equity_cf,
            funding_cashflows: funding_cf,
        }
    }

    /// Values the swap with a flat continuously compounded discount rate,
    /// i.e. with discount factor `exp(-discount_rate * t)`.
    pub fn value_flat(&self, discount_rate: T) -> TrsValuation<T> {
        self.value(|t| (-discount_rate * t).exp())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `round(4 * maturity)` equal-length periods ending at multiples
    /// of `maturity / count`, each carrying the same funding `rate`.
    fn quarterly_schedule(maturity: f64, rate: f64) -> Vec<TrsPeriod<f64>> {
        let count = (maturity * 4.0).round() as usize;
        let step = maturity / count as f64;
        let mut periods = Vec::with_capacity(count);
        for k in 1..=count {
            periods.push(TrsPeriod {
                end_time: step * k as f64,
                accrual: step,
                funding_rate: rate,
            });
        }
        periods
    }

    /// Receive-equity swap on a spot of 100 with the given economics.
    fn receive_equity_swap(
        notional: f64,
        drift: f64,
        schedule: Vec<TrsPeriod<f64>>,
        spread: f64,
    ) -> TotalReturnSwap<f64> {
        TotalReturnSwap {
            notional,
            spot: 100.0,
            equity_drift_rate: drift,
            schedule,
            spread,
            direction: TrsDirection::ReceiveEquity,
        }
    }

    /// Drift, funding and discounting all at 4%: the fair spread is small
    /// but strictly positive.
    #[test]
    fn fair_spread_small_for_self_financing_continuous() {
        let swap = receive_equity_swap(1_000_000.0, 0.04, quarterly_schedule(1.0, 0.04), 0.0);
        let fair = swap.value_flat(0.04).fair_spread;
        assert!(fair.abs() < 5e-4, "fair_spread={}", fair);
        assert!(fair > 0.0, "TR ≥ simple-funding ⇒ fair spread positive");
    }

    /// With a positive drift the equity leg has positive present value.
    #[test]
    fn dividend_yield_does_not_enter_total_return() {
        let swap = receive_equity_swap(1.0, 0.05, quarterly_schedule(1.0, 0.05), 0.0);
        let equity_pv = swap.value_flat(0.05).equity_leg_pv;
        assert!(equity_pv > 0.0);
    }

    /// Flipping the direction negates the net present value.
    #[test]
    fn pay_vs_receive_have_opposite_signs() {
        let receiver =
            receive_equity_swap(1_000_000.0, 0.06, quarterly_schedule(1.0, 0.04), 0.005);
        let payer = TotalReturnSwap {
            direction: TrsDirection::PayEquity,
            ..receiver.clone()
        };
        let total = receiver.value_flat(0.04).net_pv + payer.value_flat(0.04).net_pv;
        assert!(total.abs() < 1e-9);
    }

    /// Re-valuing with the reported fair spread brings the net PV to zero.
    #[test]
    fn fair_spread_zeroes_net_pv() {
        let unspread = receive_equity_swap(1_000_000.0, 0.06, quarterly_schedule(2.0, 0.04), 0.0);
        let fair = unspread.value_flat(0.04).fair_spread;
        let at_fair = TotalReturnSwap {
            spread: fair,
            ..unspread
        };
        let residual = at_fair.value_flat(0.04).net_pv;
        assert!(
            residual.abs() < 1e-7,
            "net_pv at fair spread = {}",
            residual
        );
    }

    /// One cashflow per leg is reported for every schedule period.
    #[test]
    fn cashflows_match_period_count() {
        let swap = receive_equity_swap(1.0, 0.05, quarterly_schedule(1.0, 0.05), 0.0);
        let valuation = swap.value_flat(0.05);
        assert_eq!(valuation.equity_cashflows.len(), 4);
        assert_eq!(valuation.funding_cashflows.len(), 4);
    }

    /// Single-period swap: cashflows and fair spread match the closed form.
    #[test]
    fn matches_textbook_one_period_formula() {
        let rate: f64 = 0.05;
        let horizon: f64 = 1.0;
        let single_period = vec![TrsPeriod {
            end_time: horizon,
            accrual: horizon,
            funding_rate: rate,
        }];
        let valuation = receive_equity_swap(1.0, rate, single_period, 0.0).value_flat(rate);

        let equity_cf = (rate * horizon).exp() - 1.0;
        let funding_cf = rate * horizon;
        assert!((valuation.equity_cashflows[0] - equity_cf).abs() < 1e-12);
        assert!((valuation.funding_cashflows[0] - funding_cf).abs() < 1e-12);

        let closed_form_spread = (equity_cf - funding_cf) / horizon;
        assert!((valuation.fair_spread - closed_form_spread).abs() < 1e-12);
    }
}