use std::fmt::Display;
use ndarray::Array1;
use ndarray::ArrayView1;
use crate::traits::FloatExt;
/// Decay kernel `G(lag)` that weights past signed volumes in
/// [`propagator_price_impact`] and [`propagator_impact_path`].
///
/// `PowerLaw` is the default variant.
#[derive(Debug, Clone, Copy, Default)]
pub enum ImpactKernel<T: FloatExt> {
    /// `g0 * (1 + lag)^(-beta)`.
    #[default]
    PowerLaw,
    /// `g0 * exp(-beta * lag)`.
    Exponential,
    /// User-supplied function of the lag alone; `g0` and `beta` are not applied.
    Custom(fn(T) -> T),
}
impl<T: FloatExt> Display for ImpactKernel<T> {
    /// Writes a short human-readable name for the kernel variant
    /// (`"Power-law"`, `"Exponential"` or `"Custom"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::PowerLaw => "Power-law",
            Self::Exponential => "Exponential",
            Self::Custom(_) => "Custom",
        };
        f.write_str(label)
    }
}
impl<T: FloatExt> ImpactKernel<T> {
    /// Evaluates the kernel at `lag`.
    ///
    /// * `PowerLaw`: `g0 * (1 + lag)^(-beta)`
    /// * `Exponential`: `g0 * exp(-beta * lag)`
    /// * `Custom(f)`: `f(lag)`; `g0` and `beta` are not used for this variant.
    pub fn evaluate(&self, lag: T, g0: T, beta: T) -> T {
        match self {
            Self::Custom(func) => func(lag),
            Self::Exponential => {
                let exponent = -beta * lag;
                g0 * exponent.exp()
            }
            Self::PowerLaw => {
                let base = T::one() + lag;
                g0 * base.powf(-beta)
            }
        }
    }
}
/// Total price impact at the last time step of `signed_volumes`.
///
/// Computes `sum_s G(n - 1 - s) * signed_volumes[s]` over all `s`, where `G` is
/// `kernel` evaluated with `g0` and `beta`, and `n` is the sequence length.
/// The most recent volume therefore gets lag 0. An empty sequence yields zero.
pub fn propagator_price_impact<T: FloatExt>(
    signed_volumes: ArrayView1<T>,
    kernel: ImpactKernel<T>,
    g0: T,
    beta: T,
) -> T {
    if signed_volumes.is_empty() {
        return T::zero();
    }
    let last = signed_volumes.len() - 1;
    let mut impact = T::zero();
    // Accumulate oldest-to-newest so the summation order matches the lag definition.
    for (s, &volume) in signed_volumes.iter().enumerate() {
        impact += kernel.evaluate(T::from_usize_(last - s), g0, beta) * volume;
    }
    impact
}
/// Cumulative price-impact path of a sequence of signed volumes.
///
/// Element `t` of the result is `sum_{s <= t} G(t - s) * signed_volumes[s]`,
/// where `G` is `kernel` evaluated with `g0` and `beta`. This equals
/// `propagator_price_impact` applied to the prefix `signed_volumes[..=t]`.
/// An empty input yields an empty output.
///
/// The kernel depends only on the lag, so it is evaluated once per lag (`n` calls).
/// The convolution itself is still `O(n^2)` multiply-adds.
pub fn propagator_impact_path<T: FloatExt>(
    signed_volumes: ArrayView1<T>,
    kernel: ImpactKernel<T>,
    g0: T,
    beta: T,
) -> Array1<T> {
    let n = signed_volumes.len();
    // weights[lag] = G(lag). This hoists the expensive powf/exp calls out of the double loop.
    let weights: Vec<T> = (0..n)
        .map(|lag| kernel.evaluate(T::from_usize_(lag), g0, beta))
        .collect();
    Array1::<T>::from_shape_fn(n, |t| {
        let mut acc = T::zero();
        // Reversing weights[..=t] pairs G(t - s) with signed_volumes[s] for
        // s = 0..=t. This keeps the original oldest-to-newest summation order.
        for (&w, &v) in weights[..=t].iter().rev().zip(signed_volumes.iter()) {
            acc += w * v;
        }
        acc
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison.
    fn approx(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol
    }

    #[test]
    fn power_law_kernel_starts_at_g0() {
        let k = ImpactKernel::<f64>::PowerLaw;
        assert!(approx(k.evaluate(0.0, 0.5, 0.7), 0.5, 1e-12));
    }

    #[test]
    fn exponential_kernel_decays_as_expected() {
        let k = ImpactKernel::<f64>::Exponential;
        let v = k.evaluate(2.0, 1.0, 0.5);
        assert!(approx(v, (-1.0_f64).exp(), 1e-12));
    }

    #[test]
    fn custom_kernel_passes_through() {
        let k = ImpactKernel::<f64>::Custom(|x| x * x);
        assert!(approx(k.evaluate(3.0, 0.0, 0.0), 9.0, 1e-12));
    }

    #[test]
    fn impact_path_matches_per_step_total() {
        let v = ndarray::array![1.0_f64, -1.0, 1.0, 1.0];
        let path = propagator_impact_path(v.view(), ImpactKernel::PowerLaw, 1.0, 0.5);
        for t in 0..v.len() {
            let expected =
                propagator_price_impact(v.slice(ndarray::s![..=t]), ImpactKernel::PowerLaw, 1.0, 0.5);
            assert!(approx(path[t], expected, 1e-12));
        }
    }

    #[test]
    fn empty_sequence_yields_zero_impact() {
        let v = Array1::<f64>::zeros(0);
        let p = propagator_price_impact(v.view(), ImpactKernel::PowerLaw, 1.0, 0.5);
        assert!(approx(p, 0.0, 1e-12));
    }

    #[test]
    fn impact_path_of_empty_sequence_is_empty() {
        let v = Array1::<f64>::zeros(0);
        let path = propagator_impact_path(v.view(), ImpactKernel::Exponential, 1.0, 0.5);
        assert_eq!(path.len(), 0);
    }

    #[test]
    fn impact_path_of_single_trade_traces_kernel() {
        // A lone trade at t = 0 should produce a path equal to volume * G(t).
        let v = ndarray::array![2.0_f64, 0.0, 0.0, 0.0];
        let k = ImpactKernel::<f64>::Exponential;
        let path = propagator_impact_path(v.view(), k, 0.8, 0.3);
        for t in 0..v.len() {
            assert!(approx(path[t], 2.0 * k.evaluate(t as f64, 0.8, 0.3), 1e-12));
        }
    }
}