use std::fmt::Display;
use ndarray::Array1;
use ndarray::Array2;
use ndarray::ArrayView1;
use ndarray_linalg::LeastSquaresSvd;
/// Position state emitted for each observation by [`pairs_signals`].
///
/// The default value is [`PairsSignal::Flat`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum PairsSignal {
    /// No position is held.
    #[default]
    Flat,
    /// State entered from `Flat` when the z-score is at or below `-entry_z`.
    LongSpread,
    /// State entered from `Flat` when the z-score is at or above `entry_z`.
    ShortSpread,
}
/// Human-readable label for each signal state.
impl Display for PairsSignal {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write. The text is
        // written verbatim, so width/padding flags on the formatter are not applied.
        let label = match *self {
            Self::Flat => "Flat",
            Self::LongSpread => "Long spread",
            Self::ShortSpread => "Short spread",
        };
        f.write_str(label)
    }
}
/// Result bundle returned by [`pairs_signals`]: the fitted regression
/// coefficients, the spread and z-score series, and the per-observation signals.
#[derive(Debug, Clone)]
pub struct PairsStrategy {
/// Intercept of the least-squares fit of `y` on a constant and `x`.
pub alpha: f64,
/// Slope on `x` from the same fit (the hedge ratio).
pub beta: f64,
/// Regression residual per observation: `y[i] - alpha - beta * x[i]`.
pub spread: Array1<f64>,
/// Spread standardised as `(spread[i] - spread_mean) / spread_std`.
pub z_score: Array1<f64>,
/// Signal state recorded for each observation, in input order.
pub signals: Array1<PairsSignal>,
/// Arithmetic mean of `spread` over the whole sample.
pub spread_mean: f64,
/// Sample standard deviation of `spread` (divisor `n - 1`), floored at `1e-12`.
pub spread_std: f64,
}
pub fn pairs_signals(
y: ArrayView1<f64>,
x: ArrayView1<f64>,
entry_z: f64,
exit_z: f64,
) -> PairsStrategy {
assert_eq!(y.len(), x.len(), "y and x must have equal length");
assert!(entry_z > exit_z, "entry_z must exceed exit_z");
let n = y.len();
assert!(n >= 3, "need at least three observations");
let mut design = Array2::<f64>::zeros((n, 2));
for i in 0..n {
design[[i, 0]] = 1.0;
design[[i, 1]] = x[i];
}
let y_owned = y.to_owned();
let sol = design.least_squares(&y_owned).expect(
"pairs-trading hedge-ratio OLS failed (one of the inputs may be a constant series — \
try_pairs_signals returns Result for graceful handling)",
);
let alpha = sol.solution[0];
let beta = sol.solution[1];
let mut spread = Array1::<f64>::zeros(n);
for i in 0..n {
spread[i] = y[i] - alpha - beta * x[i];
}
let mean = spread.iter().sum::<f64>() / n as f64;
let var = spread.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n as f64 - 1.0);
let sd = var.sqrt().max(1e-12);
let z = spread.mapv(|v| (v - mean) / sd);
let mut signals = Array1::<PairsSignal>::from_elem(n, PairsSignal::Flat);
let mut state = PairsSignal::Flat;
for i in 0..n {
let zi = z[i];
state = match state {
PairsSignal::Flat => {
if zi >= entry_z {
PairsSignal::ShortSpread
} else if zi <= -entry_z {
PairsSignal::LongSpread
} else {
PairsSignal::Flat
}
}
PairsSignal::LongSpread => {
if zi.abs() <= exit_z {
PairsSignal::Flat
} else {
PairsSignal::LongSpread
}
}
PairsSignal::ShortSpread => {
if zi.abs() <= exit_z {
PairsSignal::Flat
} else {
PairsSignal::ShortSpread
}
}
};
signals[i] = state;
}
PairsStrategy {
alpha,
beta,
spread,
z_score: z,
signals,
spread_mean: mean,
spread_std: sd,
}
}
#[cfg(test)]
mod tests {
    use ndarray::array;
    use stochastic_rs_distributions::normal::SimdNormal;

    use super::*;

    /// `y` is exactly `2 * x`, so no observation should leave the flat state.
    #[test]
    fn signals_flat_when_spread_within_band() {
        let y = array![1.0_f64, 1.01, 1.0, 1.02, 1.005];
        let x = array![0.5_f64, 0.51, 0.5, 0.52, 0.505];
        let s = pairs_signals(y.view(), x.view(), 2.0, 0.5);
        assert!(s.signals.iter().all(|sig| *sig == PairsSignal::Flat));
    }

    /// A single large downward shock to `y` must register as a long-spread signal.
    #[test]
    fn long_spread_triggered_by_negative_extreme() {
        const LEN: usize = 200;
        const DISLOCATION: usize = 150;

        let noise_src = SimdNormal::<f64>::with_seed(0.0, 0.01, 1);
        let mut noise = vec![0.0_f64; LEN];
        noise_src.fill_slice_fast(&mut noise);

        let x_vals: Vec<f64> = (0..LEN).map(|i| 100.0 + 0.05 * i as f64).collect();
        let mut y_vals: Vec<f64> = x_vals
            .iter()
            .zip(&noise)
            .map(|(&xi, &eps)| 2.0 * xi + 1.0 + eps)
            .collect();
        y_vals[DISLOCATION] -= 5.0;

        let (y, x) = (Array1::from(y_vals), Array1::from(x_vals));
        let s = pairs_signals(y.view(), x.view(), 1.5, 0.5);
        assert_eq!(s.signals[DISLOCATION], PairsSignal::LongSpread);
    }

    /// With small noise around `y = 0.3 + 1.7 * x`, the fit should recover both coefficients.
    #[test]
    fn beta_recovered_from_linear_relationship() {
        const LEN: usize = 500;

        let noise_src = SimdNormal::<f64>::with_seed(0.0, 0.005, 7);
        let mut noise = vec![0.0_f64; LEN];
        noise_src.fill_slice_fast(&mut noise);

        let x_vals: Vec<f64> = (0..LEN).map(|i| (i as f64) * 0.01).collect();
        let y_vals: Vec<f64> = x_vals
            .iter()
            .zip(&noise)
            .map(|(&xi, &eps)| 0.3 + 1.7 * xi + eps)
            .collect();

        let (y, x) = (Array1::from(y_vals), Array1::from(x_vals));
        let s = pairs_signals(y.view(), x.view(), 2.0, 0.5);
        assert!((s.beta - 1.7).abs() < 0.01);
        assert!((s.alpha - 0.3).abs() < 0.05);
    }
}