use argmin::{
core::{CostFunction, Executor, State},
solver::particleswarm::ParticleSwarm,
};
use derive_builder::Builder;
use plotly::{color::NamedColor, common::Marker, common::Mode, Plot, Scatter};
use std::{collections::BTreeMap, hash::Hash, iter::zip};
use time::Date;
use RustQuant_math::{
interpolation::{ExponentialInterpolator, Interpolator, LinearInterpolator},
InterpolationIndex,
};
use RustQuant_stochastics::{CurveModel, NelsonSiegelSvensson};
use RustQuant_time::{Calendar, DateRollingConvention, DayCountConvention};
/// Marker trait for types that can be used as the index (x-axis) of a curve.
///
/// It bundles the capabilities the curve code relies on: a total order,
/// hashing, cheap copying, and support for interpolation via
/// `InterpolationIndex`.
pub trait CurveIndex: Copy + Clone + Ord + Hash + InterpolationIndex {}

/// Blanket implementation: every type that satisfies the bundled bounds is
/// automatically a `CurveIndex`; no manual opt-in is needed.
impl<T: Copy + Clone + Ord + Hash + InterpolationIndex> CurveIndex for T {}
#[derive(Clone, Debug, PartialEq, Default)]
pub struct Curve<C>
where
C: CurveIndex,
{
pub nodes: BTreeMap<C, f64>,
}
/// Generates the inherent API of [`Curve`] for one concrete index type.
///
/// The single argument is the index type (for example `time::Date`). The
/// expansion is an `impl Curve<index>` block with constructors, accessors,
/// bracket lookup, interpolation and plotting helpers.
macro_rules! impl_curve {
    ($index:ty) => {
        impl Curve<$index> {
            /// Creates an empty curve.
            pub fn new() -> Self {
                Self {
                    nodes: BTreeMap::new(),
                }
            }

            /// Returns the smallest index on the curve, or `None` if it is empty.
            pub fn first_key(&self) -> Option<&$index> {
                self.nodes.keys().next()
            }

            /// Returns the largest index on the curve, or `None` if it is empty.
            pub fn last_key(&self) -> Option<&$index> {
                self.nodes.keys().next_back()
            }

            /// Returns all indices in ascending order.
            pub fn keys(&self) -> Vec<$index> {
                self.nodes.keys().copied().collect()
            }

            /// Returns all values, ordered by ascending index.
            pub fn values(&self) -> Vec<f64> {
                self.nodes.values().copied().collect()
            }

            /// Returns the number of nodes on the curve.
            pub fn len(&self) -> usize {
                self.nodes.len()
            }

            /// Returns `true` if the curve has no nodes.
            pub fn is_empty(&self) -> bool {
                self.nodes.is_empty()
            }

            /// Returns the value at the smallest index, or `None` if empty.
            pub fn first_value(&self) -> Option<&f64> {
                self.nodes.values().next()
            }

            /// Returns the value at the largest index, or `None` if empty.
            pub fn last_value(&self) -> Option<&f64> {
                self.nodes.values().next_back()
            }

            /// Inserts a node, overwriting any value already stored at `index`.
            pub fn insert(&mut self, index: $index, value: f64) {
                self.nodes.insert(index, value);
            }

            /// Returns a reference to the value stored at `index`, if any.
            pub fn get(&self, index: $index) -> Option<&f64> {
                self.nodes.get(&index)
            }

            /// Returns a mutable reference to the value stored at `index`, if any.
            pub fn get_mut(&mut self, index: $index) -> Option<&mut f64> {
                self.nodes.get_mut(&index)
            }

            /// Builds a curve by pairing `indices` with `values` position by position.
            ///
            /// If the slices differ in length the surplus elements of the
            /// longer one are ignored. If an index occurs more than once the
            /// last value paired with it wins.
            pub fn new_from_slice(indices: &[$index], values: &[f64]) -> Self {
                let mut curve = Self::new();
                // `extend` inserts pair by pair, so duplicate indices keep
                // last-write-wins semantics.
                curve
                    .nodes
                    .extend(indices.iter().copied().zip(values.iter().copied()));
                curve
            }

            /// Builds a curve by evaluating `f` at every index in `indices`.
            pub fn new_from_function<F>(f: F, indices: &[$index]) -> Self
            where
                F: Fn($index) -> f64,
            {
                let mut curve = Self::new();
                curve
                    .nodes
                    .extend(indices.iter().map(|&index| (index, f(index))));
                curve
            }

            /// Builds a curve that takes the same `value` at every index in `indices`.
            pub fn new_from_constant(value: f64, indices: &[$index]) -> Self {
                let mut curve = Self::new();
                curve
                    .nodes
                    .extend(indices.iter().map(|&index| (index, value)));
                curve
            }

            /// Returns the pair of curve indices that bracket `index`.
            ///
            /// * At or before the first node: `(first, first)`.
            /// * At or after the last node: `(last, last)`.
            /// * Otherwise: the closest node strictly below `index` and the
            ///   closest node at or above it (so an interior node is returned
            ///   as the right-hand bracket of itself).
            ///
            /// # Panics
            ///
            /// Panics if the curve is empty.
            pub fn get_brackets(&self, index: $index) -> ($index, $index) {
                let first = *self
                    .first_key()
                    .expect("get_brackets called on an empty curve");
                let last = *self
                    .last_key()
                    .expect("get_brackets called on an empty curve");

                if index <= first {
                    return (first, first);
                }
                if index >= last {
                    return (last, last);
                }

                // Here `first < index < last`, so both ranges are non-empty.
                let (&left, _) = self
                    .nodes
                    .range(..index)
                    .next_back()
                    .expect("a node below `index` exists because `index > first`");
                let (&right, _) = self
                    .nodes
                    .range(index..)
                    .next()
                    .expect("a node at or above `index` exists because `index < last`");

                (left, right)
            }

            /// Adds `shift` to every value on the curve (a parallel shift).
            pub fn shift(&mut self, shift: f64) {
                self.nodes.values_mut().for_each(|value| *value += shift);
            }

            /// Returns the curve value at `index`, linearly interpolating if
            /// `index` is not already a node.
            ///
            /// An interpolated value is also inserted into the curve, so the
            /// curve is modified when `index` was not present.
            ///
            /// # Panics
            ///
            /// Panics if the `LinearInterpolator` cannot be constructed from
            /// the current nodes or fails to interpolate at `index`.
            pub fn interpolate(&mut self, index: $index) -> f64 {
                if let Some(&value) = self.nodes.get(&index) {
                    return value;
                }

                let interpolator =
                    LinearInterpolator::new(self.keys(), self.values()).unwrap();
                let value = interpolator.interpolate(index).unwrap();
                self.insert(index, value);
                value
            }

            /// Inserts an interpolated value for every index in `indices`
            /// that is not already a node.
            ///
            /// The interpolator is built once from the nodes present on
            /// entry; values inserted during the call do not influence later
            /// ones.
            ///
            /// # Panics
            ///
            /// Panics if the `ExponentialInterpolator` cannot be constructed
            /// or fails to interpolate at one of the missing indices.
            pub fn interpolate_many(&mut self, indices: &[$index]) {
                // NOTE(review): this uses `ExponentialInterpolator` whereas
                // `interpolate` uses `LinearInterpolator` - confirm that the
                // difference is intentional.
                let interpolator =
                    ExponentialInterpolator::new(self.keys(), self.values()).unwrap();

                for &index in indices {
                    // The closure only runs when `index` is absent.
                    self.nodes
                        .entry(index)
                        .or_insert_with(|| interpolator.interpolate(index).unwrap());
                }
            }

            /// Shows the curve as a single lines-and-markers trace.
            pub fn plot(&self) {
                let mut plot = Plot::new();
                let xs = self
                    .nodes
                    .keys()
                    .map(|x| x.to_string())
                    .collect::<Vec<String>>();
                let ys = self.values();
                let trace = Scatter::new(xs, ys).mode(Mode::LinesMarkers);
                plot.add_trace(trace);
                plot.show();
            }

            /// Shows several curves in one figure, one trace per curve.
            pub fn plot_many(curves: &[Self]) {
                let mut plot = Plot::new();
                for curve in curves {
                    let xs = curve
                        .nodes
                        .keys()
                        .map(|x| x.to_string())
                        .collect::<Vec<String>>();
                    let ys = curve.values();
                    plot.add_trace(Scatter::new(xs, ys));
                }
                plot.show();
            }
        }
    };
}
// Instantiate the inherent `Curve` API (constructors, accessors,
// interpolation, plotting) for each index type from the `time` crate
// listed below.
impl_curve!(time::Date);
impl_curve!(time::Time);
impl_curve!(time::OffsetDateTime);
impl_curve!(time::PrimitiveDateTime);
/// Iteration cap passed to the optimiser state's `max_iters` when a curve is
/// fitted.
const CURVE_OPTIM_MAX_ITER: u64 = 69;
/// Swarm size passed as the second argument to `ParticleSwarm::new` when a
/// curve is fitted.
const CURVE_OPTIM_SWARM_SIZE: usize = 1000;
/// Common interface of the date-indexed rate curves.
///
/// The type parameter `C` does not appear in any method signature; the
/// implementations visible alongside this trait use it for the curve's
/// calendar type.
pub trait Curves<C> {
/// Creates a curve from parallel slices of dates and rates.
fn new(dates: &[Date], rates: &[f64]) -> Self;
/// Returns the first (earliest) date on the curve.
fn initial_date(&self) -> Date;
/// Returns the last (latest) date on the curve.
fn terminal_date(&self) -> Date;
/// Returns the rate for `date`.
///
/// Takes `&mut self` because the implementations generated by
/// `impl_specific_curve!` may fit the model and store the computed rate
/// when `date` is not already a node.
fn get_rate(&mut self, date: Date) -> f64;
/// Returns the rates for several dates, in the order given.
fn get_rates(&mut self, dates: &[Date]) -> Vec<f64>;
/// Inserts a rate at `date`.
fn insert_rate(&mut self, date: Date, rate: f64);
/// Fits the curve's model to its nodes; optimiser failures are returned
/// as `argmin::core::Error`.
fn fit(&mut self) -> Result<(), argmin::core::Error>;
/// Plots the curve.
fn plot(&self);
}
/// Implements argmin's `CostFunction` for a reference to a `Date`-indexed
/// curve type.
///
/// The first argument is the curve struct's name; the second is the name of
/// the `NelsonSiegelSvensson` method that produces the model value for a date.
macro_rules! impl_specific_curve_cost_function {
    ($curve:ident, $curve_function:ident) => {
        impl<C> CostFunction for &$curve<Date, C>
        where
            C: Calendar,
        {
            type Param = Vec<f64>;
            type Output = f64;

            /// Mean log-cosh loss between the curve's stored values and the
            /// model values implied by the six parameters in `p`.
            ///
            /// Panics if `p` holds fewer than six elements.
            fn cost(&self, p: &Self::Param) -> Result<Self::Output, argmin::core::Error> {
                let model = RustQuant_stochastics::NelsonSiegelSvensson::new(
                    p[0], p[1], p[2], p[3], p[4], p[5],
                );

                let node_count = self.curve.len() as f64;
                let dates = self.curve.keys();
                let observed = self.curve.values();

                // Sum ln(cosh(model - observed)) over the nodes, in index order.
                let total_loss: f64 = zip(dates.iter(), observed.iter())
                    .map(|(date, value)| (model.$curve_function(*date) - value).cosh().ln())
                    .sum();

                Ok(total_loss / node_count)
            }
        }
    };
}
/// Implements the `Curves` trait for a `Date`-indexed curve type.
///
/// The first argument is the curve struct's name (it must have the fields
/// `curve`, `calendar`, `day_count_convention`, `date_rolling_convention`,
/// `nss`, `fitted` and `fitted_curve`). The second is the name of the
/// `NelsonSiegelSvensson` method used to produce a model value for a date.
macro_rules! impl_specific_curve {
    ($curve:ident, $curve_function:ident) => {
        impl<C> Curves<C> for $curve<Date, C>
        where
            C: Calendar + Clone,
        {
            #[doc = concat!("Fit the ", stringify!($curve))]
            ///
            /// Runs a particle-swarm optimisation over the six model
            /// parameters, stores the best parameter set in `nss` and marks
            /// the curve as fitted. The elapsed optimiser time is printed to
            /// standard output.
            ///
            /// # Errors
            ///
            /// Returns the optimiser's error if the run fails.
            ///
            /// # Panics
            ///
            /// Panics if the optimiser finishes without a best parameter.
            fn fit(&mut self) -> Result<(), argmin::core::Error> {
                let zero = f64::EPSILON;

                // (lower, upper) bound per parameter, in the positional order
                // expected by `NelsonSiegelSvensson::new`; `unzip` yields the
                // (lower bounds, upper bounds) pair that the solver takes.
                let bounds: (Vec<f64>, Vec<f64>) = [
                    (zero, 0.3),
                    (-0.3, 0.3),
                    (-1.0, 1.0),
                    (-1.0, 1.0),
                    (zero, 5.0),
                    (zero, 5.0),
                ]
                .iter()
                .copied()
                .unzip();

                // The executor borrows the cost function for as long as
                // `result` lives, so optimise against a snapshot and leave
                // `self` free to be updated below.
                let model = self.clone();
                let solver = ParticleSwarm::new(bounds, CURVE_OPTIM_SWARM_SIZE);
                let executor = Executor::new(&model, solver)
                    .configure(|state| state.max_iters(CURVE_OPTIM_MAX_ITER));
                let result = executor.run()?;

                let params = result
                    .state()
                    .get_best_param()
                    .expect("particle swarm finished without a best parameter")
                    .position
                    .to_vec();

                self.nss = NelsonSiegelSvensson::new(
                    params[0], params[1], params[2], params[3], params[4], params[5],
                );
                self.fitted = true;

                println!("TIME: {:?}", result.state().get_time());

                Ok(())
            }

            #[doc = concat!("Creates a new ", stringify!($curve), " curve from a set of `Date`s and rates.")]
            ///
            /// # Panics
            ///
            /// Panics if `dates` and `rates` differ in length.
            fn new(dates: &[Date], rates: &[f64]) -> Self {
                assert_eq!(
                    dates.len(),
                    rates.len(),
                    "dates and rates must have the same length"
                );

                Self {
                    curve: Curve::<Date>::new_from_slice(dates, rates),
                    calendar: None,
                    day_count_convention: None,
                    date_rolling_convention: None,
                    nss: NelsonSiegelSvensson::default(),
                    fitted: false,
                    fitted_curve: None,
                }
            }

            #[doc = concat!("Get the initial date of the ", stringify!($curve))]
            ///
            /// # Panics
            ///
            /// Panics if the curve has no nodes.
            fn initial_date(&self) -> Date {
                *self.curve.first_key().unwrap()
            }

            #[doc = concat!("Get the terminal date of the ", stringify!($curve))]
            ///
            /// # Panics
            ///
            /// Panics if the curve has no nodes.
            fn terminal_date(&self) -> Date {
                *self.curve.last_key().unwrap()
            }

            #[doc = concat!("Insert a new rate into the ", stringify!($curve))]
            fn insert_rate(&mut self, date: Date, rate: f64) {
                self.curve.insert(date, rate);
            }

            #[doc = concat!("Get a rate from the ", stringify!($curve))]
            ///
            /// A date that is already a node returns its stored rate.
            /// Otherwise the model is fitted first if it has not been fitted
            /// yet, the model value for `date` is computed, stored both on
            /// the curve and on the fitted curve, and returned.
            ///
            /// # Panics
            ///
            /// Panics if fitting is required and fails.
            fn get_rate(&mut self, date: Date) -> f64 {
                if let Some(rate) = self.curve.get(date) {
                    return *rate;
                }

                if !self.fitted {
                    self.fit().expect("fitting the curve failed");
                    self.fitted_curve = Some(Curve::<Date>::new());
                }

                let rate = self.nss.$curve_function(date);
                self.insert_rate(date, rate);

                // `fitted` can be true while `fitted_curve` is still `None`
                // (after a direct call to `fit`, or for a curve assembled via
                // the builder), so create the fitted curve on demand instead
                // of unwrapping.
                self.fitted_curve
                    .get_or_insert_with(Curve::<Date>::new)
                    .insert(date, rate);

                rate
            }

            #[doc = concat!("Get multiple rates from the ", stringify!($curve))]
            ///
            /// Rates are returned in the order of `dates`; see `get_rate` for
            /// the per-date behaviour.
            fn get_rates(&mut self, dates: &[Date]) -> Vec<f64> {
                dates.iter().map(|&date| self.get_rate(date)).collect()
            }

            #[doc = concat!("Plot the ", stringify!($curve))]
            ///
            /// The curve's nodes are drawn as markers. If the curve has been
            /// fitted and model-generated rates have been recorded, they are
            /// added as a red "Fitted curve" trace.
            fn plot(&self) {
                let mut plot = Plot::new();

                let xs = self
                    .curve
                    .nodes
                    .keys()
                    .map(|x| x.to_string())
                    .collect::<Vec<String>>();
                let ys = self.curve.nodes.values().cloned().collect::<Vec<f64>>();
                let trace = Scatter::new(xs, ys)
                    .mode(Mode::Markers)
                    .name(stringify!($curve));
                plot.add_trace(trace);

                if self.fitted {
                    // A fitted model without any recorded rates has nothing
                    // extra to draw; skip the trace rather than panic.
                    if let Some(fitted_curve) = self.fitted_curve.as_ref() {
                        let xs = fitted_curve
                            .nodes
                            .keys()
                            .map(|x| x.to_string())
                            .collect::<Vec<String>>();
                        let ys = fitted_curve
                            .nodes
                            .values()
                            .cloned()
                            .collect::<Vec<f64>>();
                        let trace = Scatter::new(xs, ys)
                            .mode(Mode::LinesMarkers)
                            .marker(Marker::new().color(NamedColor::Red))
                            .name("Fitted curve");
                        plot.add_trace(trace);
                    }
                }

                plot.show();
            }
        }
    };
}
/// Discount curve: a `Curve` of values per index together with market
/// conventions and a Nelson-Siegel-Svensson model used to produce values for
/// dates that are not nodes.
#[derive(Builder, Clone, Debug)]
pub struct DiscountCurve<I, C>
where
I: CurveIndex,
C: Calendar,
{
/// The nodes of the curve. For `Date`-indexed curves these are the values
/// the cost function fits against and where `get_rate` looks rates up.
pub curve: Curve<I>,
/// Optional calendar. Set to `None` by `Curves::new`; not read by the code
/// visible alongside this definition.
pub calendar: Option<C>,
/// Optional day count convention. Set to `None` by `Curves::new`; not read
/// by the code visible alongside this definition.
pub day_count_convention: Option<DayCountConvention>,
/// Optional date rolling convention. Set to `None` by `Curves::new`; not
/// read by the code visible alongside this definition.
pub date_rolling_convention: Option<DateRollingConvention>,
/// Nelson-Siegel-Svensson model; replaced with the optimised parameters by
/// `fit`.
#[builder(default)]
pub nss: NelsonSiegelSvensson,
/// Whether the model has been fitted; set to `true` by `fit`.
#[builder(default = "false")]
pub fitted: bool,
/// Values generated from `nss` by `get_rate`; drawn as the "Fitted curve"
/// trace by `plot`.
#[builder(default)]
pub fitted_curve: Option<Curve<I>>,
}
// Wire up the optimiser cost function and the `Curves` implementation, using
// `NelsonSiegelSvensson::discount_factor` as the model function.
impl_specific_curve_cost_function!(DiscountCurve, discount_factor);
impl_specific_curve!(DiscountCurve, discount_factor);
/// Spot curve: a `Curve` of values per index together with market
/// conventions and a Nelson-Siegel-Svensson model used to produce values for
/// dates that are not nodes.
#[derive(Builder, Clone, Debug)]
pub struct SpotCurve<I, C>
where
I: CurveIndex,
C: Calendar,
{
/// The nodes of the curve. For `Date`-indexed curves these are the values
/// the cost function fits against and where `get_rate` looks rates up.
pub curve: Curve<I>,
/// Optional calendar. Set to `None` by `Curves::new`; not read by the code
/// visible alongside this definition.
pub calendar: Option<C>,
/// Optional day count convention. Set to `None` by `Curves::new`; not read
/// by the code visible alongside this definition.
pub day_count_convention: Option<DayCountConvention>,
/// Optional date rolling convention. Set to `None` by `Curves::new`; not
/// read by the code visible alongside this definition.
pub date_rolling_convention: Option<DateRollingConvention>,
/// Nelson-Siegel-Svensson model; replaced with the optimised parameters by
/// `fit`.
#[builder(default)]
pub nss: NelsonSiegelSvensson,
/// Whether the model has been fitted; set to `true` by `fit`.
#[builder(default = "false")]
pub fitted: bool,
/// Values generated from `nss` by `get_rate`; drawn as the "Fitted curve"
/// trace by `plot`.
#[builder(default)]
pub fitted_curve: Option<Curve<I>>,
}
// Wire up the optimiser cost function and the `Curves` implementation, using
// `NelsonSiegelSvensson::spot_rate` as the model function.
impl_specific_curve_cost_function!(SpotCurve, spot_rate);
impl_specific_curve!(SpotCurve, spot_rate);
/// Forward curve: a `Curve` of values per index together with market
/// conventions and a Nelson-Siegel-Svensson model used to produce values for
/// dates that are not nodes.
#[derive(Builder, Clone, Debug)]
pub struct ForwardCurve<I, C>
where
I: CurveIndex,
C: Calendar,
{
/// The nodes of the curve. For `Date`-indexed curves these are the values
/// the cost function fits against and where `get_rate` looks rates up.
pub curve: Curve<I>,
/// Optional calendar. Set to `None` by `Curves::new`; not read by the code
/// visible alongside this definition.
pub calendar: Option<C>,
/// Optional day count convention. Set to `None` by `Curves::new`; not read
/// by the code visible alongside this definition.
pub day_count_convention: Option<DayCountConvention>,
/// Optional date rolling convention. Set to `None` by `Curves::new`; not
/// read by the code visible alongside this definition.
pub date_rolling_convention: Option<DateRollingConvention>,
/// Nelson-Siegel-Svensson model; replaced with the optimised parameters by
/// `fit`.
#[builder(default)]
pub nss: NelsonSiegelSvensson,
/// Whether the model has been fitted; set to `true` by `fit`.
#[builder(default = "false")]
pub fitted: bool,
/// Values generated from `nss` by `get_rate`; drawn as the "Fitted curve"
/// trace by `plot`.
#[builder(default)]
pub fitted_curve: Option<Curve<I>>,
}
// Wire up the optimiser cost function and the `Curves` implementation, using
// `NelsonSiegelSvensson::forward_rate` as the model function.
impl_specific_curve_cost_function!(ForwardCurve, forward_rate);
impl_specific_curve!(ForwardCurve, forward_rate);
/// Flat curve: a single rate that applies to every date.
#[derive(Builder, Clone, Debug)]
pub struct FlatCurve<C>
where
C: Calendar,
{
/// The constant rate of the curve.
pub rate: f64,
/// Optional calendar; `new_flat_curve` leaves it as `None`.
pub calendar: Option<C>,
/// Optional day count convention; `new_flat_curve` leaves it as `None`.
pub day_count_convention: Option<DayCountConvention>,
/// Optional date rolling convention; `new_flat_curve` leaves it as `None`.
pub date_rolling_convention: Option<DateRollingConvention>,
}
impl<C: Calendar> FlatCurve<C> {
    /// Creates a flat curve at `rate` with no calendar or conventions set.
    pub fn new_flat_curve(rate: f64) -> Self {
        let (calendar, day_count_convention, date_rolling_convention) = (None, None, None);

        Self {
            rate,
            calendar,
            day_count_convention,
            date_rolling_convention,
        }
    }

    /// Returns the curve's constant rate.
    pub fn get_rate(&self) -> f64 {
        self.rate
    }

    /// Returns the rate for a date; the date is ignored because the curve
    /// is flat.
    pub fn get_rate_for_date(&self, _date: Date) -> f64 {
        self.get_rate()
    }

    /// Returns one copy of the constant rate per entry in `dates`.
    pub fn get_rates_for_dates(&self, dates: &[Date]) -> Vec<f64> {
        dates.iter().map(|_| self.rate).collect()
    }
}
// Test-only module for the curve types; it contains no tests as shown here.
#[cfg(test)]
mod tests_curves {
}