Skip to main content

exo_temporal/
causal.rs

1//! Causal graph for tracking antecedent relationships
2
3use crate::types::{PatternId, SubstrateTime};
4use dashmap::DashMap;
5use petgraph::algo::dijkstra;
6use petgraph::graph::{DiGraph, NodeIndex};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Type of causal cone for queries
12#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum CausalConeType {
14    /// Past light cone (all events that could have influenced reference)
15    Past,
16    /// Future light cone (all events that reference could influence)
17    Future,
18    /// Relativistic light cone with velocity constraint
19    LightCone {
20        /// Velocity of causal influence (fraction of c)
21        velocity: f32,
22    },
23}
24
25/// Causal graph tracking antecedent relationships
26pub struct CausalGraph {
27    /// Forward edges: cause -> effects
28    forward: DashMap<PatternId, Vec<PatternId>>,
29    /// Backward edges: effect -> causes
30    backward: DashMap<PatternId, Vec<PatternId>>,
31    /// Pattern timestamps for light cone calculations
32    timestamps: DashMap<PatternId, SubstrateTime>,
33    /// Cached graph representation for path finding
34    graph_cache:
35        Arc<parking_lot::RwLock<Option<(DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>)>>>,
36}
37
38impl CausalGraph {
39    /// Create new causal graph
40    pub fn new() -> Self {
41        Self {
42            forward: DashMap::new(),
43            backward: DashMap::new(),
44            timestamps: DashMap::new(),
45            graph_cache: Arc::new(parking_lot::RwLock::new(None)),
46        }
47    }
48
49    /// Add causal edge: cause -> effect
50    pub fn add_edge(&self, cause: PatternId, effect: PatternId) {
51        // Add to forward edges
52        self.forward
53            .entry(cause)
54            .or_insert_with(Vec::new)
55            .push(effect);
56
57        // Add to backward edges
58        self.backward
59            .entry(effect)
60            .or_insert_with(Vec::new)
61            .push(cause);
62
63        // Invalidate cache
64        *self.graph_cache.write() = None;
65    }
66
67    /// Add pattern with timestamp
68    pub fn add_pattern(&self, id: PatternId, timestamp: SubstrateTime) {
69        self.timestamps.insert(id, timestamp);
70    }
71
72    /// Get direct causes of a pattern
73    pub fn causes(&self, pattern: PatternId) -> Vec<PatternId> {
74        self.backward
75            .get(&pattern)
76            .map(|v| v.clone())
77            .unwrap_or_default()
78    }
79
80    /// Get direct effects of a pattern
81    pub fn effects(&self, pattern: PatternId) -> Vec<PatternId> {
82        self.forward
83            .get(&pattern)
84            .map(|v| v.clone())
85            .unwrap_or_default()
86    }
87
88    /// Get out-degree (number of effects)
89    pub fn out_degree(&self, pattern: PatternId) -> usize {
90        self.forward.get(&pattern).map(|v| v.len()).unwrap_or(0)
91    }
92
93    /// Get in-degree (number of causes)
94    pub fn in_degree(&self, pattern: PatternId) -> usize {
95        self.backward.get(&pattern).map(|v| v.len()).unwrap_or(0)
96    }
97
98    /// Compute shortest path distance between two patterns
99    pub fn distance(&self, from: PatternId, to: PatternId) -> Option<usize> {
100        if from == to {
101            return Some(0);
102        }
103
104        // Build or retrieve cached graph
105        let (graph, node_map) = {
106            let cache = self.graph_cache.read();
107            if let Some((g, m)) = cache.as_ref() {
108                (g.clone(), m.clone())
109            } else {
110                drop(cache);
111                let (g, m) = self.build_graph();
112                *self.graph_cache.write() = Some((g.clone(), m.clone()));
113                (g, m)
114            }
115        };
116
117        // Get node indices
118        let from_idx = *node_map.get(&from)?;
119        let to_idx = *node_map.get(&to)?;
120
121        // Run Dijkstra's algorithm
122        let distances = dijkstra(&graph, from_idx, Some(to_idx), |_| 1);
123
124        distances.get(&to_idx).copied()
125    }
126
127    /// Build petgraph representation for path finding
128    fn build_graph(&self) -> (DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>) {
129        let mut graph = DiGraph::new();
130        let mut node_map = HashMap::new();
131
132        // Add all nodes
133        for entry in self.forward.iter() {
134            let id = *entry.key();
135            if !node_map.contains_key(&id) {
136                let idx = graph.add_node(id);
137                node_map.insert(id, idx);
138            }
139
140            for &effect in entry.value() {
141                if !node_map.contains_key(&effect) {
142                    let idx = graph.add_node(effect);
143                    node_map.insert(effect, idx);
144                }
145            }
146        }
147
148        // Add edges
149        for entry in self.forward.iter() {
150            let from = *entry.key();
151            let from_idx = node_map[&from];
152
153            for &to in entry.value() {
154                let to_idx = node_map[&to];
155                graph.add_edge(from_idx, to_idx, ());
156            }
157        }
158
159        (graph, node_map)
160    }
161
162    /// Get all patterns in causal past
163    pub fn causal_past(&self, pattern: PatternId) -> Vec<PatternId> {
164        let mut result = Vec::new();
165        let mut visited = std::collections::HashSet::new();
166        let mut stack = vec![pattern];
167
168        while let Some(current) = stack.pop() {
169            if visited.contains(&current) {
170                continue;
171            }
172            visited.insert(current);
173
174            if let Some(causes) = self.backward.get(&current) {
175                for &cause in causes.iter() {
176                    if !visited.contains(&cause) {
177                        stack.push(cause);
178                        result.push(cause);
179                    }
180                }
181            }
182        }
183
184        result
185    }
186
187    /// Get all patterns in causal future
188    pub fn causal_future(&self, pattern: PatternId) -> Vec<PatternId> {
189        let mut result = Vec::new();
190        let mut visited = std::collections::HashSet::new();
191        let mut stack = vec![pattern];
192
193        while let Some(current) = stack.pop() {
194            if visited.contains(&current) {
195                continue;
196            }
197            visited.insert(current);
198
199            if let Some(effects) = self.forward.get(&current) {
200                for &effect in effects.iter() {
201                    if !visited.contains(&effect) {
202                        stack.push(effect);
203                        result.push(effect);
204                    }
205                }
206            }
207        }
208
209        result
210    }
211
212    /// Filter patterns by light cone constraint
213    pub fn filter_by_light_cone(
214        &self,
215        reference: PatternId,
216        reference_time: SubstrateTime,
217        cone_type: CausalConeType,
218        candidates: &[PatternId],
219    ) -> Vec<PatternId> {
220        candidates
221            .iter()
222            .filter(|&&id| self.is_in_light_cone(id, reference, reference_time, cone_type))
223            .copied()
224            .collect()
225    }
226
227    /// Check if pattern is within light cone
228    fn is_in_light_cone(
229        &self,
230        pattern: PatternId,
231        _reference: PatternId,
232        reference_time: SubstrateTime,
233        cone_type: CausalConeType,
234    ) -> bool {
235        let pattern_time = match self.timestamps.get(&pattern) {
236            Some(t) => *t,
237            None => return false,
238        };
239
240        match cone_type {
241            CausalConeType::Past => pattern_time <= reference_time,
242            CausalConeType::Future => pattern_time >= reference_time,
243            CausalConeType::LightCone { velocity: _ } => {
244                // Simplified relativistic constraint
245                // In full implementation, would include spatial distance
246                let time_diff = (reference_time - pattern_time).abs();
247                let time_diff_secs = (time_diff.0 / 1_000_000_000).abs() as f32;
248
249                // For now, just use temporal constraint
250                // In full version: spatial_distance <= velocity * time_diff
251                time_diff_secs >= 0.0 // Always true for temporal-only check
252            }
253        }
254    }
255
256    /// Get statistics about the causal graph
257    pub fn stats(&self) -> CausalGraphStats {
258        let num_nodes = self.timestamps.len();
259        let num_edges: usize = self.forward.iter().map(|e| e.value().len()).sum();
260
261        let avg_out_degree = if num_nodes > 0 {
262            num_edges as f32 / num_nodes as f32
263        } else {
264            0.0
265        };
266
267        CausalGraphStats {
268            num_nodes,
269            num_edges,
270            avg_out_degree,
271        }
272    }
273}
274
275impl Default for CausalGraph {
276    fn default() -> Self {
277        Self::new()
278    }
279}
280
281/// Statistics about causal graph
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct CausalGraphStats {
284    /// Number of nodes
285    pub num_nodes: usize,
286    /// Number of edges
287    pub num_edges: usize,
288    /// Average out-degree
289    pub avg_out_degree: f32,
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the chain `first -> middle -> last` and checks degrees, distance and
    /// the transitive causal past.
    #[test]
    fn test_causal_graph_basic() {
        let causal = CausalGraph::new();

        let first = PatternId::new();
        let middle = PatternId::new();
        let last = PatternId::new();

        let first_time = SubstrateTime::now();
        let middle_time = SubstrateTime::now();
        let last_time = SubstrateTime::now();

        causal.add_pattern(first, first_time);
        causal.add_pattern(middle, middle_time);
        causal.add_pattern(last, last_time);

        // Chain: first -> middle -> last
        causal.add_edge(first, middle);
        causal.add_edge(middle, last);

        assert_eq!(causal.out_degree(first), 1);
        assert_eq!(causal.in_degree(middle), 1);
        assert_eq!(causal.distance(first, last), Some(2));

        // Both ancestors of `last` must show up in its causal past.
        let ancestors = causal.causal_past(last);
        assert!(ancestors.contains(&first));
        assert!(ancestors.contains(&middle));
    }
}