quant_metrics/cointegration/
mod.rs1use rust_decimal::Decimal;
13use thiserror::Error;
14
/// Minimum number of paired observations [`engle_granger`] accepts; shorter inputs are
/// rejected with [`CointegrationMathError::InsufficientData`].
pub const MIN_OBSERVATIONS: usize = 30;
17
18#[derive(Debug, Error)]
20pub enum CointegrationMathError {
21 #[error("insufficient data: need at least {minimum} observations, got {actual}")]
22 InsufficientData { actual: usize, minimum: usize },
23
24 #[error("series length mismatch: {len_a} vs {len_b}")]
25 LengthMismatch { len_a: usize, len_b: usize },
26
27 #[error("degenerate data: {0}")]
28 DegenerateData(String),
29}
30
/// Output of [`engle_granger`]: the regression fit, the unit-root statistic on its
/// residuals, and summary statistics of the residual spread.
#[derive(Debug, Clone)]
pub struct EngleGrangerResult {
    /// Dickey-Fuller style t-statistic computed on the regression residuals.
    pub adf_statistic: f64,

    /// Approximate p-value derived from `adf_statistic`.
    pub p_value: f64,

    /// Slope of the regression of the second price series on the first.
    pub beta: f64,

    /// Intercept of the regression of the second price series on the first.
    pub alpha: f64,

    /// Half-life of mean reversion of the residuals, in observation periods.
    /// `None` when the residuals do not yield a usable estimate.
    pub half_life: Option<f64>,

    /// Pearson correlation between the two raw price series.
    pub correlation: f64,

    /// Sample mean of the regression residuals (the spread).
    pub spread_mean: f64,

    /// Sample standard deviation of the regression residuals (the spread).
    pub spread_std: f64,
}
61
62pub fn engle_granger(
67 prices_a: &[f64],
68 prices_b: &[f64],
69) -> Result<EngleGrangerResult, CointegrationMathError> {
70 let n = prices_a.len();
71
72 if n != prices_b.len() {
73 return Err(CointegrationMathError::LengthMismatch {
74 len_a: n,
75 len_b: prices_b.len(),
76 });
77 }
78
79 if n < MIN_OBSERVATIONS {
80 return Err(CointegrationMathError::InsufficientData {
81 actual: n,
82 minimum: MIN_OBSERVATIONS,
83 });
84 }
85
86 let (alpha, beta) = ols_regression(prices_a, prices_b)?;
88
89 let residuals: Vec<f64> = prices_b
91 .iter()
92 .zip(prices_a.iter())
93 .map(|(y, x)| y - alpha - beta * x)
94 .collect();
95
96 let spread_mean = mean(&residuals);
97 let spread_std = std_dev(&residuals, spread_mean);
98
99 if spread_std < 1e-15 {
100 return Err(CointegrationMathError::DegenerateData(
101 "spread has zero variance".to_string(),
102 ));
103 }
104
105 let adf_statistic = adf_test_statistic(&residuals)?;
107
108 let p_value = adf_p_value(adf_statistic, n);
110
111 let half_life = ornstein_uhlenbeck_half_life(&residuals);
113
114 let correlation = pearson_correlation(prices_a, prices_b)?;
116
117 Ok(EngleGrangerResult {
118 adf_statistic,
119 p_value,
120 beta,
121 alpha,
122 half_life,
123 correlation,
124 spread_mean,
125 spread_std,
126 })
127}
128
129pub fn spread_stats_to_decimal(mean: f64, std: f64) -> (Decimal, Decimal) {
131 let ratio_mean = Decimal::from_f64_retain(mean).unwrap_or(Decimal::ZERO);
132 let ratio_std = Decimal::from_f64_retain(std).unwrap_or(Decimal::ZERO);
133 (ratio_mean, ratio_std)
134}
135
136fn ols_regression(x: &[f64], y: &[f64]) -> Result<(f64, f64), CointegrationMathError> {
143 let n = x.len() as f64;
144 let mean_x = x.iter().sum::<f64>() / n;
145 let mean_y = y.iter().sum::<f64>() / n;
146
147 let mut sum_dx_dy: f64 = 0.0;
148 let mut sum_dx2: f64 = 0.0;
149
150 for (xi, yi) in x.iter().zip(y.iter()) {
151 let dx = xi - mean_x;
152 let dy = yi - mean_y;
153 sum_dx_dy += dx * dy;
154 sum_dx2 += dx * dx;
155 }
156
157 if sum_dx2 < 1e-15 {
158 return Err(CointegrationMathError::DegenerateData(
159 "x series has zero variance".to_string(),
160 ));
161 }
162
163 let beta = sum_dx_dy / sum_dx2;
164 let alpha = mean_y - beta * mean_x;
165
166 Ok((alpha, beta))
167}
168
169fn adf_test_statistic(residuals: &[f64]) -> Result<f64, CointegrationMathError> {
179 let n = residuals.len();
180 if n < 3 {
181 return Err(CointegrationMathError::InsufficientData {
182 actual: n,
183 minimum: 3,
184 });
185 }
186
187 let mut sum_xy: f64 = 0.0;
190 let mut sum_x2: f64 = 0.0;
191 let mut sum_e2: f64 = 0.0;
192
193 for t in 1..n {
194 let x = residuals[t - 1];
195 let y = residuals[t] - residuals[t - 1];
196 sum_xy += x * y;
197 sum_x2 += x * x;
198 }
199
200 if sum_x2 < 1e-15 {
201 return Err(CointegrationMathError::DegenerateData(
202 "lagged residuals have zero variance".to_string(),
203 ));
204 }
205
206 let theta = sum_xy / sum_x2;
207
208 for t in 1..n {
210 let x = residuals[t - 1];
211 let y = residuals[t] - residuals[t - 1];
212 let e = y - theta * x;
213 sum_e2 += e * e;
214 }
215
216 let m = (n - 1) as f64; let sigma2 = sum_e2 / (m - 1.0); let se_theta = (sigma2 / sum_x2).sqrt();
219
220 if se_theta < 1e-15 {
221 return Err(CointegrationMathError::DegenerateData(
222 "standard error of theta is zero".to_string(),
223 ));
224 }
225
226 Ok(theta / se_theta)
227}
228
/// Maps a unit-root t-statistic to an approximate p-value.
///
/// The value is interpolated linearly between hard-coded anchor points
/// `(-3.90, 0.01)`, `(-3.34, 0.05)`, `(-3.04, 0.10)` and `(-2.50, 0.25)`. Statistics at
/// or below `-3.90` return `0.01`; above `-2.50` the value keeps rising by `0.25` per
/// unit and is capped at `0.50`. The sample size `_n` is accepted but not used.
///
/// NOTE(review): the anchors look like Engle-Granger critical values for a
/// two-variable regression; confirm against the table they were taken from.
fn adf_p_value(t_stat: f64, _n: usize) -> f64 {
    // Linear interpolation starting at `p_lo` for `t_lo` and gaining `p_span` by `t_hi`.
    let interpolate = |t_lo: f64, t_hi: f64, p_lo: f64, p_span: f64| {
        p_lo + p_span * (t_stat - t_lo) / (t_hi - t_lo)
    };

    match t_stat {
        t if t <= -3.90 => 0.01,
        t if t <= -3.34 => interpolate(-3.90, -3.34, 0.01, 0.04),
        t if t <= -3.04 => interpolate(-3.34, -3.04, 0.05, 0.05),
        t if t <= -2.50 => interpolate(-3.04, -2.50, 0.10, 0.15),
        _ => 0.50_f64
            .min(interpolate(-2.50, -1.50, 0.25, 0.25))
            .max(0.25),
    }
}
257
/// Estimates the mean-reversion half-life of a residual series, in observation periods.
///
/// Fits the no-intercept autoregression `e_t = phi * e_{t-1}` and returns
/// `-ln(2) / ln(phi)`.
///
/// Returns `None` when fewer than 3 residuals are given, when the lagged residuals have
/// a sum of squares below `1e-15`, when `phi` is not strictly between 0 and 1, or when
/// the resulting half-life is not a finite positive number.
fn ornstein_uhlenbeck_half_life(residuals: &[f64]) -> Option<f64> {
    if residuals.len() < 3 {
        return None;
    }

    // Each window is (e_{t-1}, e_t).
    let (lag_cross, lag_squares) =
        residuals
            .windows(2)
            .fold((0.0_f64, 0.0_f64), |(cross, squares), pair| {
                (cross + pair[1] * pair[0], squares + pair[0] * pair[0])
            });

    if lag_squares < 1e-15 {
        return None;
    }

    let phi = lag_cross / lag_squares;

    // Only 0 < phi < 1 describes decay toward the mean; a NaN `phi` fails this test too.
    let is_mean_reverting = phi > 0.0 && phi < 1.0;
    if !is_mean_reverting {
        return None;
    }

    let half_life = -(2.0_f64.ln()) / phi.ln();

    if half_life.is_finite() && half_life > 0.0 {
        Some(half_life)
    } else {
        None
    }
}
300
301pub fn pearson_correlation(x: &[f64], y: &[f64]) -> Result<f64, CointegrationMathError> {
305 let n = x.len() as f64;
306 let mean_x = mean(x);
307 let mean_y = mean(y);
308
309 let mut sum_xy: f64 = 0.0;
310 let mut sum_x2: f64 = 0.0;
311 let mut sum_y2: f64 = 0.0;
312
313 for (xi, yi) in x.iter().zip(y.iter()) {
314 let dx = xi - mean_x;
315 let dy = yi - mean_y;
316 sum_xy += dx * dy;
317 sum_x2 += dx * dx;
318 sum_y2 += dy * dy;
319 }
320
321 let denom = (sum_x2 * sum_y2).sqrt();
322 if denom < 1e-15 * n {
323 return Ok(0.0);
324 }
325
326 Ok(sum_xy / denom)
327}
328
/// Returns the arithmetic mean of `data`, or `0.0` for an empty slice.
fn mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        count => data.iter().sum::<f64>() / count as f64,
    }
}
337
/// Returns the sample standard deviation of `data` around the supplied centre `m`.
///
/// The sum of squared deviations is divided by `data.len() - 1`. `m` is used as given
/// and is not recomputed from `data`. Slices with fewer than two elements yield `0.0`.
fn std_dev(data: &[f64], m: f64) -> f64 {
    let count = data.len();
    if count < 2 {
        return 0.0;
    }

    let squared_deviations: f64 = data.iter().map(|value| (value - m).powi(2)).sum();
    let degrees_of_freedom = (count - 1) as f64;

    (squared_deviations / degrees_of_freedom).sqrt()
}
345
346#[cfg(test)]
347mod tests;