Skip to main content

do_memory_mcp/patterns/predictive/
causal.rs

1use anyhow::{Result, anyhow};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use tracing::{debug, info, instrument, warn};
5
6use super::forecasting::types::PredictiveConfig;
7
8pub struct CausalAnalyzer {
9    config: PredictiveConfig,
10}
11
12impl CausalAnalyzer {
13    /// Create a new causal analyzer
14    pub fn new() -> Result<Self> {
15        Self::with_config(PredictiveConfig::default())
16    }
17
18    /// Create a new causal analyzer with custom config
19    pub fn with_config(config: PredictiveConfig) -> Result<Self> {
20        Ok(Self { config })
21    }
22
23    /// Analyze causal relationships between variables
24    #[instrument(skip(self, data))]
25    pub fn analyze_causality(&self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<CausalResult>> {
26        if !self.config.enable_causal_inference {
27            return Ok(Vec::new());
28        }
29
30        let mut results = Vec::new();
31        let variables: Vec<&String> = data.keys().collect();
32
33        info!(
34            "Analyzing causal relationships between {} variables",
35            variables.len()
36        );
37
38        // Analyze pairwise causal relationships
39        let pairs: Vec<_> = variables
40            .iter()
41            .enumerate()
42            .flat_map(|(i, &var1)| variables[i + 1..].iter().map(move |&var2| (var1, var2)))
43            .collect();
44
45        for (var1, var2) in pairs {
46            if let (Some(data1), Some(data2)) = (data.get(var1), data.get(var2)) {
47                if let Some(causal_result) =
48                    self.analyze_pair_causality(var1, var2, data1, data2)?
49                {
50                    results.push(causal_result);
51                }
52            }
53        }
54
55        debug!("Found {} causal relationships", results.len());
56        Ok(results)
57    }
58
59    /// Analyze causality between a pair of variables
60    fn analyze_pair_causality(
61        &self,
62        cause: &str,
63        effect: &str,
64        cause_data: &[f64],
65        effect_data: &[f64],
66    ) -> Result<Option<CausalResult>> {
67        if cause_data.len() != effect_data.len() || cause_data.len() < 10 {
68            return Ok(None);
69        }
70
71        // Simplified Granger causality test
72        // In practice, you'd use proper time series causality tests
73        let correlation = self.calculate_correlation(cause_data, effect_data)?;
74
75        // Calculate cross-correlation at different lags
76        let max_lag = 5.min(cause_data.len() / 4);
77        let mut max_cross_corr: f64 = 0.0;
78        let mut best_lag = 0;
79
80        for lag in 1..=max_lag {
81            if let Some(cross_corr) = self.cross_correlation(cause_data, effect_data, lag) {
82                if cross_corr.abs() > max_cross_corr.abs() {
83                    max_cross_corr = cross_corr;
84                    best_lag = lag;
85                }
86            }
87        }
88
89        // Determine causal relationship type
90        let relationship_type = if max_cross_corr.abs() > 0.7 && best_lag > 0 {
91            CausalType::Direct
92        } else if correlation.abs() > 0.5 {
93            CausalType::Indirect
94        } else if correlation.abs() < 0.2 {
95            CausalType::None
96        } else {
97            CausalType::Spurious
98        };
99
100        // Calculate significance (simplified)
101        let n = cause_data.len() as f64;
102        let t_stat = correlation.abs() * ((n - 2.0) / (1.0 - correlation * correlation)).sqrt();
103        let p_value = 2.0 * (1.0 - Self::normal_cdf(t_stat));
104        let significant = p_value < 0.05;
105
106        let strength = correlation.abs().min(1.0);
107
108        // Confidence interval (simplified)
109        let se = (1.0 - correlation * correlation) / (n - 2.0).sqrt();
110        let margin = 1.96 * se;
111        let confidence_interval = (
112            (correlation - margin).max(-1.0),
113            (correlation + margin).min(1.0),
114        );
115
116        Ok(Some(CausalResult {
117            cause: cause.to_string(),
118            effect: effect.to_string(),
119            strength,
120            significant,
121            relationship_type,
122            confidence_interval,
123        }))
124    }
125
126    /// Calculate Pearson correlation
127    fn calculate_correlation(&self, x: &[f64], y: &[f64]) -> Result<f64> {
128        if x.len() != y.len() {
129            return Err(anyhow!("Data lengths don't match"));
130        }
131
132        let n = x.len() as f64;
133        let sum_x: f64 = x.iter().sum();
134        let sum_y: f64 = y.iter().sum();
135        let sum_xy: f64 = x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum();
136        let sum_x2: f64 = x.iter().map(|&a| a * a).sum();
137        let sum_y2: f64 = y.iter().map(|&a| a * a).sum();
138
139        let numerator = n * sum_xy - sum_x * sum_y;
140        let denominator = ((n * sum_x2 - sum_x * sum_x) * (n * sum_y2 - sum_y * sum_y)).sqrt();
141
142        if denominator == 0.0 {
143            Ok(0.0)
144        } else {
145            Ok(numerator / denominator)
146        }
147    }
148
149    /// Calculate cross-correlation at a specific lag
150    fn cross_correlation(&self, x: &[f64], y: &[f64], lag: usize) -> Option<f64> {
151        if lag >= x.len() || lag >= y.len() {
152            return None;
153        }
154
155        let x_slice = &x[lag..];
156        let y_slice = &y[..y.len() - lag];
157
158        self.calculate_correlation(x_slice, y_slice).ok()
159    }
160
161    /// Normal cumulative distribution function
162    fn normal_cdf(x: f64) -> f64 {
163        0.5 * (1.0 + Self::erf(x / 2.0_f64.sqrt()))
164    }
165
166    /// Error function
167    fn erf(x: f64) -> f64 {
168        let sign = if x < 0.0 { -1.0 } else { 1.0 };
169        let x = x.abs();
170
171        let a1 = 0.254829592;
172        let a2 = -0.284496736;
173        let a3 = 1.421413741;
174        let a4 = -1.453152027;
175        let a5 = 1.061405429;
176        let p = 0.3275911;
177
178        let t = 1.0 / (1.0 + p * x);
179        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
180
181        sign * y
182    }
183}
184
185/// Causal inference results
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct CausalResult {
188    /// Cause variable
189    pub cause: String,
190    /// Effect variable
191    pub effect: String,
192    /// Causal strength (0.0 to 1.0)
193    pub strength: f64,
194    /// Statistical significance
195    pub significant: bool,
196    /// Causal relationship type
197    pub relationship_type: CausalType,
198    /// Confidence interval
199    pub confidence_interval: (f64, f64),
200}
201
202/// Types of causal relationships
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub enum CausalType {
205    /// Direct causation
206    Direct,
207    /// Indirect causation through mediators
208    Indirect,
209    /// Spurious correlation
210    Spurious,
211    /// No causal relationship
212    None,
213}