// exo_temporal/causal.rs

//! Causal graph for tracking antecedent relationships

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use dashmap::DashMap;
use petgraph::algo::dijkstra;
use petgraph::graph::{DiGraph, NodeIndex};
use serde::{Deserialize, Serialize};

use crate::types::{PatternId, SubstrateTime};

11/// Type of causal cone for queries
12#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum CausalConeType {
14    /// Past light cone (all events that could have influenced reference)
15    Past,
16    /// Future light cone (all events that reference could influence)
17    Future,
18    /// Relativistic light cone with velocity constraint
19    LightCone {
20        /// Velocity of causal influence (fraction of c)
21        velocity: f32,
22    },
23}
24
25/// Causal graph tracking antecedent relationships
26pub struct CausalGraph {
27    /// Forward edges: cause -> effects
28    forward: DashMap<PatternId, Vec<PatternId>>,
29    /// Backward edges: effect -> causes
30    backward: DashMap<PatternId, Vec<PatternId>>,
31    /// Pattern timestamps for light cone calculations
32    timestamps: DashMap<PatternId, SubstrateTime>,
33    /// Cached graph representation for path finding
34    graph_cache: Arc<parking_lot::RwLock<Option<(DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>)>>>,
35}
36
37impl CausalGraph {
38    /// Create new causal graph
39    pub fn new() -> Self {
40        Self {
41            forward: DashMap::new(),
42            backward: DashMap::new(),
43            timestamps: DashMap::new(),
44            graph_cache: Arc::new(parking_lot::RwLock::new(None)),
45        }
46    }
47
48    /// Add causal edge: cause -> effect
49    pub fn add_edge(&self, cause: PatternId, effect: PatternId) {
50        // Add to forward edges
51        self.forward
52            .entry(cause)
53            .or_insert_with(Vec::new)
54            .push(effect);
55
56        // Add to backward edges
57        self.backward
58            .entry(effect)
59            .or_insert_with(Vec::new)
60            .push(cause);
61
62        // Invalidate cache
63        *self.graph_cache.write() = None;
64    }
65
66    /// Add pattern with timestamp
67    pub fn add_pattern(&self, id: PatternId, timestamp: SubstrateTime) {
68        self.timestamps.insert(id, timestamp);
69    }
70
71    /// Get direct causes of a pattern
72    pub fn causes(&self, pattern: PatternId) -> Vec<PatternId> {
73        self.backward
74            .get(&pattern)
75            .map(|v| v.clone())
76            .unwrap_or_default()
77    }
78
79    /// Get direct effects of a pattern
80    pub fn effects(&self, pattern: PatternId) -> Vec<PatternId> {
81        self.forward
82            .get(&pattern)
83            .map(|v| v.clone())
84            .unwrap_or_default()
85    }
86
87    /// Get out-degree (number of effects)
88    pub fn out_degree(&self, pattern: PatternId) -> usize {
89        self.forward
90            .get(&pattern)
91            .map(|v| v.len())
92            .unwrap_or(0)
93    }
94
95    /// Get in-degree (number of causes)
96    pub fn in_degree(&self, pattern: PatternId) -> usize {
97        self.backward
98            .get(&pattern)
99            .map(|v| v.len())
100            .unwrap_or(0)
101    }
102
103    /// Compute shortest path distance between two patterns
104    pub fn distance(&self, from: PatternId, to: PatternId) -> Option<usize> {
105        if from == to {
106            return Some(0);
107        }
108
109        // Build or retrieve cached graph
110        let (graph, node_map) = {
111            let cache = self.graph_cache.read();
112            if let Some((g, m)) = cache.as_ref() {
113                (g.clone(), m.clone())
114            } else {
115                drop(cache);
116                let (g, m) = self.build_graph();
117                *self.graph_cache.write() = Some((g.clone(), m.clone()));
118                (g, m)
119            }
120        };
121
122        // Get node indices
123        let from_idx = *node_map.get(&from)?;
124        let to_idx = *node_map.get(&to)?;
125
126        // Run Dijkstra's algorithm
127        let distances = dijkstra(&graph, from_idx, Some(to_idx), |_| 1);
128
129        distances.get(&to_idx).copied()
130    }
131
132    /// Build petgraph representation for path finding
133    fn build_graph(&self) -> (DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>) {
134        let mut graph = DiGraph::new();
135        let mut node_map = HashMap::new();
136
137        // Add all nodes
138        for entry in self.forward.iter() {
139            let id = *entry.key();
140            if !node_map.contains_key(&id) {
141                let idx = graph.add_node(id);
142                node_map.insert(id, idx);
143            }
144
145            for &effect in entry.value() {
146                if !node_map.contains_key(&effect) {
147                    let idx = graph.add_node(effect);
148                    node_map.insert(effect, idx);
149                }
150            }
151        }
152
153        // Add edges
154        for entry in self.forward.iter() {
155            let from = *entry.key();
156            let from_idx = node_map[&from];
157
158            for &to in entry.value() {
159                let to_idx = node_map[&to];
160                graph.add_edge(from_idx, to_idx, ());
161            }
162        }
163
164        (graph, node_map)
165    }
166
167    /// Get all patterns in causal past
168    pub fn causal_past(&self, pattern: PatternId) -> Vec<PatternId> {
169        let mut result = Vec::new();
170        let mut visited = std::collections::HashSet::new();
171        let mut stack = vec![pattern];
172
173        while let Some(current) = stack.pop() {
174            if visited.contains(&current) {
175                continue;
176            }
177            visited.insert(current);
178
179            if let Some(causes) = self.backward.get(&current) {
180                for &cause in causes.iter() {
181                    if !visited.contains(&cause) {
182                        stack.push(cause);
183                        result.push(cause);
184                    }
185                }
186            }
187        }
188
189        result
190    }
191
192    /// Get all patterns in causal future
193    pub fn causal_future(&self, pattern: PatternId) -> Vec<PatternId> {
194        let mut result = Vec::new();
195        let mut visited = std::collections::HashSet::new();
196        let mut stack = vec![pattern];
197
198        while let Some(current) = stack.pop() {
199            if visited.contains(&current) {
200                continue;
201            }
202            visited.insert(current);
203
204            if let Some(effects) = self.forward.get(&current) {
205                for &effect in effects.iter() {
206                    if !visited.contains(&effect) {
207                        stack.push(effect);
208                        result.push(effect);
209                    }
210                }
211            }
212        }
213
214        result
215    }
216
217    /// Filter patterns by light cone constraint
218    pub fn filter_by_light_cone(
219        &self,
220        reference: PatternId,
221        reference_time: SubstrateTime,
222        cone_type: CausalConeType,
223        candidates: &[PatternId],
224    ) -> Vec<PatternId> {
225        candidates
226            .iter()
227            .filter(|&&id| {
228                self.is_in_light_cone(id, reference, reference_time, cone_type)
229            })
230            .copied()
231            .collect()
232    }
233
234    /// Check if pattern is within light cone
235    fn is_in_light_cone(
236        &self,
237        pattern: PatternId,
238        _reference: PatternId,
239        reference_time: SubstrateTime,
240        cone_type: CausalConeType,
241    ) -> bool {
242        let pattern_time = match self.timestamps.get(&pattern) {
243            Some(t) => *t,
244            None => return false,
245        };
246
247        match cone_type {
248            CausalConeType::Past => pattern_time <= reference_time,
249            CausalConeType::Future => pattern_time >= reference_time,
250            CausalConeType::LightCone { velocity: _ } => {
251                // Simplified relativistic constraint
252                // In full implementation, would include spatial distance
253                let time_diff = (reference_time - pattern_time).abs();
254                let time_diff_secs = (time_diff.0 / 1_000_000_000).abs() as f32;
255
256                // For now, just use temporal constraint
257                // In full version: spatial_distance <= velocity * time_diff
258                time_diff_secs >= 0.0 // Always true for temporal-only check
259            }
260        }
261    }
262
263    /// Get statistics about the causal graph
264    pub fn stats(&self) -> CausalGraphStats {
265        let num_nodes = self.timestamps.len();
266        let num_edges: usize = self.forward.iter().map(|e| e.value().len()).sum();
267
268        let avg_out_degree = if num_nodes > 0 {
269            num_edges as f32 / num_nodes as f32
270        } else {
271            0.0
272        };
273
274        CausalGraphStats {
275            num_nodes,
276            num_edges,
277            avg_out_degree,
278        }
279    }
280}
281
282impl Default for CausalGraph {
283    fn default() -> Self {
284        Self::new()
285    }
286}
287
288/// Statistics about causal graph
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct CausalGraphStats {
291    /// Number of nodes
292    pub num_nodes: usize,
293    /// Number of edges
294    pub num_edges: usize,
295    /// Average out-degree
296    pub avg_out_degree: f32,
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_graph_basic() {
        let graph = CausalGraph::new();

        let (p1, p2, p3) = (PatternId::new(), PatternId::new(), PatternId::new());
        for id in [p1, p2, p3] {
            graph.add_pattern(id, SubstrateTime::now());
        }

        // Chain: p1 -> p2 -> p3
        graph.add_edge(p1, p2);
        graph.add_edge(p2, p3);

        assert_eq!(graph.out_degree(p1), 1);
        assert_eq!(graph.in_degree(p2), 1);
        assert_eq!(graph.distance(p1, p3), Some(2));

        let ancestors = graph.causal_past(p3);
        for expected in [p1, p2] {
            assert!(ancestors.contains(&expected));
        }
    }
}