use std::fmt::Display;
use ndarray::Array1;
use crate::curves::DiscountCurve;
use crate::traits::FloatExt;
/// A shock applied to a single scalar risk factor.
///
/// The variant decides how the carried value combines with the base value
/// when the shock is applied (see `Shock::apply`).
#[derive(Debug, Clone, Copy)]
pub enum Shock<T: FloatExt> {
/// Adds the carried value to the base value.
Additive(T),
/// Multiplies the base value by the carried value.
Multiplicative(T),
/// Replaces the base value with the carried value outright.
Level(T),
}
impl<T: FloatExt> Shock<T> {
pub fn apply(&self, x: T) -> T {
match *self {
Self::Additive(v) => x + v,
Self::Multiplicative(v) => x * v,
Self::Level(v) => v,
}
}
}
impl<T: FloatExt> Display for Shock<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Additive(v) => write!(f, "+{v:?}"),
Self::Multiplicative(v) => write!(f, "×{v:?}"),
Self::Level(v) => write!(f, "={v:?}"),
}
}
}
/// A shift applied to the zero rates of a discount curve.
///
/// All variants operate on the per-pillar zero rates derived in
/// `CurveShift::apply`; the shifted rates are then turned back into a curve.
#[derive(Debug, Clone)]
pub enum CurveShift<T: FloatExt> {
/// Adds the same amount to the zero rate at every pillar.
Parallel(T),
/// Adds a shift that varies linearly with pillar time: `short_shift` at
/// time zero, `long_shift` at the time of the last pillar.
Twist { short_shift: T, long_shift: T },
/// Adds `amount` to the zero rate of the single pillar whose time is
/// closest to `pillar`.
KeyRate { pillar: T, amount: T },
/// Adds one delta per pillar; the array length must equal the number of
/// pillars on the curve being shifted.
AtPillars(Array1<T>),
}
impl<T: FloatExt> CurveShift<T> {
pub fn apply(&self, base: &DiscountCurve<T>) -> DiscountCurve<T> {
let points = base.points();
let n = points.len();
let times: Array1<T> = Array1::from_iter(points.iter().map(|p| p.time));
let base_rates: Array1<T> = Array1::from_iter(points.iter().map(|p| {
if p.time > T::zero() {
-p.discount_factor.ln() / p.time
} else {
T::zero()
}
}));
let shifted_rates = match self {
Self::Parallel(amount) => base_rates.mapv(|r| r + *amount),
Self::Twist {
short_shift,
long_shift,
} => {
let t_max = times[n - 1].max(T::min_positive_val());
base_rates
.iter()
.zip(times.iter())
.map(|(&r, &t)| {
let w = t / t_max;
let shift = *short_shift * (T::one() - w) + *long_shift * w;
r + shift
})
.collect::<Array1<T>>()
}
Self::KeyRate { pillar, amount } => {
let mut out = base_rates.clone();
let idx = nearest_pillar_index(×, *pillar);
out[idx] += *amount;
out
}
Self::AtPillars(deltas) => {
assert_eq!(
deltas.len(),
n,
"AtPillars shift length {} must match curve pillar count {n}",
deltas.len()
);
&base_rates + deltas
}
};
DiscountCurve::from_zero_rates(×, &shifted_rates, base.method())
}
}
/// Returns the index of the entry in `times` closest to `target`.
///
/// Closeness is the absolute difference. On a tie the earliest index wins,
/// because a later candidate only replaces the current best when it is
/// strictly closer.
///
/// # Panics
///
/// Panics if `times` is empty.
fn nearest_pillar_index<T: FloatExt>(times: &Array1<T>, target: T) -> usize {
    let first_gap = (times[0] - target).abs();
    let (closest, _) = times.iter().enumerate().skip(1).fold(
        (0usize, first_gap),
        |(winner, smallest_gap), (idx, &t)| {
            let gap = (t - target).abs();
            if gap < smallest_gap {
                (idx, gap)
            } else {
                (winner, smallest_gap)
            }
        },
    );
    closest
}
/// A named collection of scalar shocks and curve shifts.
#[derive(Debug, Clone)]
pub struct Scenario<T: FloatExt> {
/// Name of the scenario; copied into each `ScenarioResult` produced for it.
pub name: String,
/// Free-form labels attached to the scenario.
pub tags: Vec<String>,
/// `(factor name, shock)` pairs, kept in insertion order.
pub shocks: Vec<(String, Shock<T>)>,
/// `(curve key, shift)` pairs, kept in insertion order.
pub curve_shifts: Vec<(String, CurveShift<T>)>,
}
impl<T: FloatExt> Scenario<T> {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
tags: Vec::new(),
shocks: Vec::new(),
curve_shifts: Vec::new(),
}
}
pub fn with_shock(mut self, factor: impl Into<String>, shock: Shock<T>) -> Self {
self.shocks.push((factor.into(), shock));
self
}
pub fn with_curve_shift(mut self, curve_key: impl Into<String>, shift: CurveShift<T>) -> Self {
self.curve_shifts.push((curve_key.into(), shift));
self
}
pub fn tagged(mut self, tag: impl Into<String>) -> Self {
self.tags.push(tag.into());
self
}
pub fn resolve_scalar(&self, factor: &str, base: T) -> T {
let mut x = base;
for (name, shock) in &self.shocks {
if name == factor {
x = shock.apply(x);
}
}
x
}
pub fn resolve_curve(&self, curve_key: &str, base: &DiscountCurve<T>) -> DiscountCurve<T> {
let mut curve = base.clone();
for (key, shift) in &self.curve_shifts {
if key == curve_key {
curve = shift.apply(&curve);
}
}
curve
}
}
/// Outcome of valuing one scenario in a stress test.
#[derive(Debug, Clone)]
pub struct ScenarioResult<T: FloatExt> {
/// Name of the scenario that produced this result.
pub name: String,
/// Value under base (unshocked) conditions.
pub base_value: T,
/// Value under the scenario.
pub shocked_value: T,
/// `shocked_value - base_value` as computed by `StressTest::run`.
pub pnl: T,
}
/// An ordered list of scenarios to be valued together.
#[derive(Debug, Clone)]
pub struct StressTest<T: FloatExt> {
// Scenarios in the order they will be evaluated by `run`.
scenarios: Vec<Scenario<T>>,
}
impl<T: FloatExt> StressTest<T> {
pub fn new(scenarios: Vec<Scenario<T>>) -> Self {
Self { scenarios }
}
pub fn push(&mut self, scenario: Scenario<T>) {
self.scenarios.push(scenario);
}
pub fn scenarios(&self) -> &[Scenario<T>] {
&self.scenarios
}
pub fn run<F, G>(&self, base_value: F, mut scenario_value: G) -> Vec<ScenarioResult<T>>
where
F: Fn() -> T,
G: FnMut(&Scenario<T>) -> T,
{
let base = base_value();
self
.scenarios
.iter()
.map(|s| {
let shocked = scenario_value(s);
ScenarioResult {
name: s.name.clone(),
base_value: base,
shocked_value: shocked,
pnl: shocked - base,
}
})
.collect()
}
}