// do_memory_mcp/patterns/predictive/causal.rs
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, instrument, warn};

use super::forecasting::types::PredictiveConfig;
7
8pub struct CausalAnalyzer {
9 config: PredictiveConfig,
10}
11
12impl CausalAnalyzer {
13 pub fn new() -> Result<Self> {
15 Self::with_config(PredictiveConfig::default())
16 }
17
18 pub fn with_config(config: PredictiveConfig) -> Result<Self> {
20 Ok(Self { config })
21 }
22
23 #[instrument(skip(self, data))]
25 pub fn analyze_causality(&self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<CausalResult>> {
26 if !self.config.enable_causal_inference {
27 return Ok(Vec::new());
28 }
29
30 let mut results = Vec::new();
31 let variables: Vec<&String> = data.keys().collect();
32
33 info!(
34 "Analyzing causal relationships between {} variables",
35 variables.len()
36 );
37
38 let pairs: Vec<_> = variables
40 .iter()
41 .enumerate()
42 .flat_map(|(i, &var1)| variables[i + 1..].iter().map(move |&var2| (var1, var2)))
43 .collect();
44
45 for (var1, var2) in pairs {
46 if let (Some(data1), Some(data2)) = (data.get(var1), data.get(var2)) {
47 if let Some(causal_result) =
48 self.analyze_pair_causality(var1, var2, data1, data2)?
49 {
50 results.push(causal_result);
51 }
52 }
53 }
54
55 debug!("Found {} causal relationships", results.len());
56 Ok(results)
57 }
58
59 fn analyze_pair_causality(
61 &self,
62 cause: &str,
63 effect: &str,
64 cause_data: &[f64],
65 effect_data: &[f64],
66 ) -> Result<Option<CausalResult>> {
67 if cause_data.len() != effect_data.len() || cause_data.len() < 10 {
68 return Ok(None);
69 }
70
71 let correlation = self.calculate_correlation(cause_data, effect_data)?;
74
75 let max_lag = 5.min(cause_data.len() / 4);
77 let mut max_cross_corr: f64 = 0.0;
78 let mut best_lag = 0;
79
80 for lag in 1..=max_lag {
81 if let Some(cross_corr) = self.cross_correlation(cause_data, effect_data, lag) {
82 if cross_corr.abs() > max_cross_corr.abs() {
83 max_cross_corr = cross_corr;
84 best_lag = lag;
85 }
86 }
87 }
88
89 let relationship_type = if max_cross_corr.abs() > 0.7 && best_lag > 0 {
91 CausalType::Direct
92 } else if correlation.abs() > 0.5 {
93 CausalType::Indirect
94 } else if correlation.abs() < 0.2 {
95 CausalType::None
96 } else {
97 CausalType::Spurious
98 };
99
100 let n = cause_data.len() as f64;
102 let t_stat = correlation.abs() * ((n - 2.0) / (1.0 - correlation * correlation)).sqrt();
103 let p_value = 2.0 * (1.0 - Self::normal_cdf(t_stat));
104 let significant = p_value < 0.05;
105
106 let strength = correlation.abs().min(1.0);
107
108 let se = (1.0 - correlation * correlation) / (n - 2.0).sqrt();
110 let margin = 1.96 * se;
111 let confidence_interval = (
112 (correlation - margin).max(-1.0),
113 (correlation + margin).min(1.0),
114 );
115
116 Ok(Some(CausalResult {
117 cause: cause.to_string(),
118 effect: effect.to_string(),
119 strength,
120 significant,
121 relationship_type,
122 confidence_interval,
123 }))
124 }
125
126 fn calculate_correlation(&self, x: &[f64], y: &[f64]) -> Result<f64> {
128 if x.len() != y.len() {
129 return Err(anyhow!("Data lengths don't match"));
130 }
131
132 let n = x.len() as f64;
133 let sum_x: f64 = x.iter().sum();
134 let sum_y: f64 = y.iter().sum();
135 let sum_xy: f64 = x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum();
136 let sum_x2: f64 = x.iter().map(|&a| a * a).sum();
137 let sum_y2: f64 = y.iter().map(|&a| a * a).sum();
138
139 let numerator = n * sum_xy - sum_x * sum_y;
140 let denominator = ((n * sum_x2 - sum_x * sum_x) * (n * sum_y2 - sum_y * sum_y)).sqrt();
141
142 if denominator == 0.0 {
143 Ok(0.0)
144 } else {
145 Ok(numerator / denominator)
146 }
147 }
148
149 fn cross_correlation(&self, x: &[f64], y: &[f64], lag: usize) -> Option<f64> {
151 if lag >= x.len() || lag >= y.len() {
152 return None;
153 }
154
155 let x_slice = &x[lag..];
156 let y_slice = &y[..y.len() - lag];
157
158 self.calculate_correlation(x_slice, y_slice).ok()
159 }
160
161 fn normal_cdf(x: f64) -> f64 {
163 0.5 * (1.0 + Self::erf(x / 2.0_f64.sqrt()))
164 }
165
166 fn erf(x: f64) -> f64 {
168 let sign = if x < 0.0 { -1.0 } else { 1.0 };
169 let x = x.abs();
170
171 let a1 = 0.254829592;
172 let a2 = -0.284496736;
173 let a3 = 1.421413741;
174 let a4 = -1.453152027;
175 let a5 = 1.061405429;
176 let p = 0.3275911;
177
178 let t = 1.0 / (1.0 + p * x);
179 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
180
181 sign * y
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct CausalResult {
188 pub cause: String,
190 pub effect: String,
192 pub strength: f64,
194 pub significant: bool,
196 pub relationship_type: CausalType,
198 pub confidence_interval: (f64, f64),
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize)]
204pub enum CausalType {
205 Direct,
207 Indirect,
209 Spurious,
211 None,
213}