use ndarray::Array1;
use crate::error::{DigiFiError, ErrorTitle};
use crate::random_generators::{RandomGenerator, uniform_generators::FibonacciGenerator, standard_normal_generators::StandardNormalBoxMuller};
/// Selects how the drift component of a stochastic process is modelled.
pub enum StochasticDriftType {
    /// Drift given by a trend evaluated at time `t` (see [`TrendStationary`]).
    TrendStationary {
        trend: TrendStationary,
        /// NOTE(review): presumably the starting value of the process — confirm against callers.
        s_0: f64,
    },
    /// Drift given by an autoregression on previous values (see [`DifferenceStationary`]).
    DifferenceStationary {
        trend: DifferenceStationary,
    },
}
/// User-supplied trend function, used by [`TrendStationary::Custom`].
pub trait CustomTrendStationary {
    /// Evaluates the trend at time `t`, returning one value per path.
    fn trend_func(&self, n_paths: usize, t: f64) -> Result<Array1<f64>, DigiFiError>;

    /// Title attached to errors reported for this custom function.
    fn error_title(&self) -> String;

    /// Checks that `trend_func` returns exactly `n_paths` values.
    ///
    /// The function is probed once, at a time drawn from `FibonacciGenerator`.
    ///
    /// # Errors
    /// Propagates any error from the generator or from `trend_func`, and returns
    /// `DigiFiError::CustomFunctionLengthVal` when the output length differs from `n_paths`.
    fn validate(&self, n_paths: usize) -> Result<(), DigiFiError> {
        let probe_t: f64 = FibonacciGenerator::new_shuffle(1)?.generate()?[0];
        let produced_len: usize = self.trend_func(n_paths, probe_t)?.len();
        if produced_len == n_paths {
            Ok(())
        } else {
            Err(DigiFiError::CustomFunctionLengthVal { title: self.error_title() })
        }
    }
}
/// Trend used by a trend-stationary drift, evaluated at time `t` for every path.
pub enum TrendStationary {
    /// `a * t + b`.
    Linear {
        a: f64,
        b: f64,
    },
    /// `a * t^2 + b * t + c`; `validate` rejects `a == 0`.
    Quadratic {
        a: f64,
        b: f64,
        c: f64,
    },
    /// `b * exp(a * t)`.
    Exponential {
        a: f64,
        b: f64,
    },
    /// User-supplied trend function.
    Custom {
        f: Box<dyn CustomTrendStationary>,
    },
}
impl TrendStationary {
    /// Validates the trend parameters for `n_paths` simulated paths.
    ///
    /// `Linear` and `Exponential` have no constraints. `Quadratic` requires a non-zero `a`.
    /// `Custom` delegates to [`CustomTrendStationary::validate`].
    ///
    /// # Errors
    /// - `DigiFiError::ParameterConstraint` if a quadratic trend has `a == 0.0`.
    /// - Whatever the custom validator returns for `Custom`.
    pub fn validate(&self, n_paths: usize) -> Result<(), DigiFiError> {
        match self {
            Self::Linear { .. } | Self::Exponential { .. } => Ok(()),
            Self::Quadratic { a, .. } if *a == 0.0 => Err(DigiFiError::ParameterConstraint {
                title: Self::error_title(),
                constraint: "The argument `a` for quadratic trend stationary trend type must be non-zero.".to_owned(),
            }),
            Self::Quadratic { .. } => Ok(()),
            Self::Custom { f } => f.validate(n_paths),
        }
    }

    /// Evaluates the trend at time `t` and returns it for each of the `n_paths` paths.
    ///
    /// For the built-in variants every path receives the same value:
    /// - `Linear`: `a * t + b`
    /// - `Quadratic`: `a * t^2 + b * t + c`
    /// - `Exponential`: `b * exp(a * t)`
    ///
    /// `Custom` returns whatever its `trend_func` produces.
    ///
    /// # Errors
    /// Only the `Custom` variant can fail, by propagating its `trend_func` error.
    pub fn get_stationary_trend(&self, n_paths: usize, t: f64) -> Result<Array1<f64>, DigiFiError> {
        let level: f64 = match self {
            Self::Linear { a, b } => a * t + b,
            Self::Quadratic { a, b, c } => a * t.powi(2) + b * t + c,
            Self::Exponential { a, b } => b * (a * t).exp(),
            // Custom trends produce their own per-path array; no broadcasting needed.
            Self::Custom { f } => return f.trend_func(n_paths, t),
        };
        // Fill directly instead of multiplying a ones-array: one allocation, identical values.
        Ok(Array1::from_elem(n_paths, level))
    }
}
impl ErrorTitle for TrendStationary {
    /// Title attached to errors raised while validating or evaluating a [`TrendStationary`].
    fn error_title() -> String {
        "Trend Stationary Trend Type".to_owned()
    }
}
/// User-supplied error-term generator, used by [`StationaryError::Custom`].
pub trait CustomStationaryError {
    /// Produces the error term for a time step of length `dt`, one value per path.
    fn error_func(&self, n_paths: usize, dt: f64) -> Result<Array1<f64>, DigiFiError>;

    /// Title attached to errors reported for this custom function.
    fn error_title(&self) -> String;

    /// Checks that `error_func` returns exactly `n_paths` values.
    ///
    /// The function is probed once, with a step size drawn from `FibonacciGenerator`.
    ///
    /// # Errors
    /// Propagates any error from the generator or from `error_func`, and returns
    /// `DigiFiError::CustomFunctionLengthVal` when the output length differs from `n_paths`.
    fn validate(&self, n_paths: usize) -> Result<(), DigiFiError> {
        let probe_dt: f64 = FibonacciGenerator::new_shuffle(1)?.generate()?[0];
        match self.error_func(n_paths, probe_dt)?.len() {
            len if len == n_paths => Ok(()),
            _ => Err(DigiFiError::CustomFunctionLengthVal { title: self.error_title() }),
        }
    }
}
/// Random error term added to a stochastic process at each time step.
pub enum StationaryError {
    /// Gaussian error built from standard-normal draws and parameterised by `sigma`.
    /// The variant name keeps the historical spelling of "Wiener".
    Weiner { sigma: f64 },
    /// User-supplied error function.
    Custom { f: Box<dyn CustomStationaryError> },
}
impl StationaryError {
    /// Validates the error specification for `n_paths` simulated paths.
    ///
    /// `Weiner` has no constraints checked here. `Custom` delegates to
    /// [`CustomStationaryError::validate`].
    ///
    /// # Errors
    /// Whatever the custom validator returns for `Custom`.
    pub fn validate(&self, n_paths: usize) -> Result<(), DigiFiError> {
        // Explicit arms, so adding a variant forces this match to be revisited.
        match self {
            Self::Weiner { .. } => Ok(()),
            Self::Custom { f } => f.validate(n_paths),
        }
    }

    /// Samples the error term for one time step of length `dt`, one value per path.
    ///
    /// For `Weiner`, this is a Wiener-process increment scaled by `sigma`: `sigma * sqrt(dt) * Z`.
    /// Assuming `StandardNormalBoxMuller` yields `Z ~ N(0, 1)`, each value has variance
    /// `sigma^2 * dt`. A negative `dt` produces NaN values.
    ///
    /// # Errors
    /// Propagates errors from the standard-normal generator or from the custom error function.
    pub fn get_error(&self, n_paths: usize, dt: f64) -> Result<Array1<f64>, DigiFiError> {
        match self {
            Self::Weiner { sigma } => {
                let z: Array1<f64> = StandardNormalBoxMuller::new_shuffle(n_paths)?.generate()?;
                // Brownian increments over `dt` have standard deviation `sqrt(dt)`, so the
                // volatility `sigma` (not its square) multiplies `sqrt(dt)` (not `dt`).
                Ok(sigma * dt.sqrt() * z)
            },
            Self::Custom { f } => f.error_func(n_paths, dt),
        }
    }
}
/// Difference-stationary drift component: an autoregression on each path's previous values.
#[derive(Debug)]
pub struct DifferenceStationary {
    /// Number of simulated paths; `build` checks that `starting_values` holds this many arrays.
    n_paths: usize,
    /// Autoregression coefficients, dotted with each path's array of previous values.
    autoregression_params: Array1<f64>,
    /// Starting (previous) values, one array per path.
    starting_values: Vec<Array1<f64>>,
}
impl DifferenceStationary {
    /// Creates a difference-stationary (autoregressive) drift component.
    ///
    /// # Arguments
    /// - `n_paths`: number of simulated paths.
    /// - `autoregression_params`: autoregression coefficients; their count is the process order.
    /// - `starting_values`: one array per path with that path's previous values, each holding
    ///   as many entries as there are autoregression coefficients.
    ///
    /// # Errors
    /// - `DigiFiError::WrongLength` if `starting_values` does not contain exactly `n_paths` arrays.
    /// - `DigiFiError::Other` if any array's length differs from the process order.
    pub fn build(n_paths: usize, autoregression_params: Array1<f64>, starting_values: Vec<Array1<f64>>) -> Result<Self, DigiFiError> {
        Self::validate_values(n_paths, autoregression_params.len(), &starting_values)?;
        Ok(Self { n_paths, autoregression_params, starting_values })
    }

    /// Checks that `values` has one array per path and that each array holds `process_order` entries.
    ///
    /// The process order is the number of autoregression coefficients. Checking against it
    /// (not against `values.len()`) guarantees that the dot product in `get_autoregression`
    /// cannot panic on mismatched lengths.
    fn validate_values(n_paths: usize, process_order: usize, values: &[Array1<f64>]) -> Result<(), DigiFiError> {
        if values.len() != n_paths {
            return Err(DigiFiError::WrongLength { title: Self::error_title(), arg: "previous values".to_owned(), len: n_paths, });
        }
        if values.iter().any(|v| v.len() != process_order) {
            return Err(DigiFiError::Other {
                title: Self::error_title(),
                details: format!("All arrays of previous values must match the order of process, {}.", process_order),
            });
        }
        Ok(())
    }

    /// Returns a copy of the starting values, one array per path.
    ///
    /// The method name is misspelled ("strating"); it is kept as-is for existing callers.
    pub fn strating_values(&self) -> Vec<Array1<f64>> {
        self.starting_values.clone()
    }

    /// Computes the autoregressive component for each path.
    ///
    /// Each path's array of previous values is dotted with the autoregression coefficients,
    /// giving one value per path, in the same order as `previous_values`.
    ///
    /// # Errors
    /// Same validation errors as [`DifferenceStationary::build`].
    pub fn get_autoregression(&self, previous_values: &[Array1<f64>]) -> Result<Vec<f64>, DigiFiError> {
        Self::validate_values(self.n_paths, self.autoregression_params.len(), previous_values)?;
        Ok(previous_values
            .iter()
            .map(|lags| lags.dot(&self.autoregression_params))
            .collect())
    }
}
impl ErrorTitle for DifferenceStationary {
    /// Title attached to errors raised while validating a [`DifferenceStationary`].
    fn error_title() -> String {
        "Difference Stationary".to_owned()
    }
}