1use crate::types::{PatternId, SubstrateTime};
4use dashmap::DashMap;
5use petgraph::algo::dijkstra;
6use petgraph::graph::{DiGraph, NodeIndex};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum CausalConeType {
14 Past,
16 Future,
18 LightCone {
20 velocity: f32,
22 },
23}
24
/// Directed graph of cause → effect relationships between patterns, with
/// optional per-pattern timestamps.
///
/// Every mutating method takes `&self`: adjacency lists and timestamps live in
/// `DashMap`s, and a lazily built petgraph copy is kept behind an `RwLock`.
pub struct CausalGraph {
    /// Outgoing adjacency: cause → its direct effects (duplicates are kept).
    forward: DashMap<PatternId, Vec<PatternId>>,
    /// Incoming adjacency: effect → its direct causes (mirror of `forward`).
    backward: DashMap<PatternId, Vec<PatternId>>,
    /// Timestamps registered via `add_pattern`.
    timestamps: DashMap<PatternId, SubstrateTime>,
    /// Lazily built petgraph view of `forward` plus a pattern → node-index map.
    /// `None` means "not built / stale"; `add_edge` resets it to `None`.
    graph_cache:
        Arc<parking_lot::RwLock<Option<(DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>)>>>,
}
37
38impl CausalGraph {
39 pub fn new() -> Self {
41 Self {
42 forward: DashMap::new(),
43 backward: DashMap::new(),
44 timestamps: DashMap::new(),
45 graph_cache: Arc::new(parking_lot::RwLock::new(None)),
46 }
47 }
48
49 pub fn add_edge(&self, cause: PatternId, effect: PatternId) {
51 self.forward
53 .entry(cause)
54 .or_insert_with(Vec::new)
55 .push(effect);
56
57 self.backward
59 .entry(effect)
60 .or_insert_with(Vec::new)
61 .push(cause);
62
63 *self.graph_cache.write() = None;
65 }
66
67 pub fn add_pattern(&self, id: PatternId, timestamp: SubstrateTime) {
69 self.timestamps.insert(id, timestamp);
70 }
71
72 pub fn causes(&self, pattern: PatternId) -> Vec<PatternId> {
74 self.backward
75 .get(&pattern)
76 .map(|v| v.clone())
77 .unwrap_or_default()
78 }
79
80 pub fn effects(&self, pattern: PatternId) -> Vec<PatternId> {
82 self.forward
83 .get(&pattern)
84 .map(|v| v.clone())
85 .unwrap_or_default()
86 }
87
88 pub fn out_degree(&self, pattern: PatternId) -> usize {
90 self.forward.get(&pattern).map(|v| v.len()).unwrap_or(0)
91 }
92
93 pub fn in_degree(&self, pattern: PatternId) -> usize {
95 self.backward.get(&pattern).map(|v| v.len()).unwrap_or(0)
96 }
97
98 pub fn distance(&self, from: PatternId, to: PatternId) -> Option<usize> {
100 if from == to {
101 return Some(0);
102 }
103
104 let (graph, node_map) = {
106 let cache = self.graph_cache.read();
107 if let Some((g, m)) = cache.as_ref() {
108 (g.clone(), m.clone())
109 } else {
110 drop(cache);
111 let (g, m) = self.build_graph();
112 *self.graph_cache.write() = Some((g.clone(), m.clone()));
113 (g, m)
114 }
115 };
116
117 let from_idx = *node_map.get(&from)?;
119 let to_idx = *node_map.get(&to)?;
120
121 let distances = dijkstra(&graph, from_idx, Some(to_idx), |_| 1);
123
124 distances.get(&to_idx).copied()
125 }
126
127 fn build_graph(&self) -> (DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>) {
129 let mut graph = DiGraph::new();
130 let mut node_map = HashMap::new();
131
132 for entry in self.forward.iter() {
134 let id = *entry.key();
135 if !node_map.contains_key(&id) {
136 let idx = graph.add_node(id);
137 node_map.insert(id, idx);
138 }
139
140 for &effect in entry.value() {
141 if !node_map.contains_key(&effect) {
142 let idx = graph.add_node(effect);
143 node_map.insert(effect, idx);
144 }
145 }
146 }
147
148 for entry in self.forward.iter() {
150 let from = *entry.key();
151 let from_idx = node_map[&from];
152
153 for &to in entry.value() {
154 let to_idx = node_map[&to];
155 graph.add_edge(from_idx, to_idx, ());
156 }
157 }
158
159 (graph, node_map)
160 }
161
162 pub fn causal_past(&self, pattern: PatternId) -> Vec<PatternId> {
164 let mut result = Vec::new();
165 let mut visited = std::collections::HashSet::new();
166 let mut stack = vec![pattern];
167
168 while let Some(current) = stack.pop() {
169 if visited.contains(¤t) {
170 continue;
171 }
172 visited.insert(current);
173
174 if let Some(causes) = self.backward.get(¤t) {
175 for &cause in causes.iter() {
176 if !visited.contains(&cause) {
177 stack.push(cause);
178 result.push(cause);
179 }
180 }
181 }
182 }
183
184 result
185 }
186
187 pub fn causal_future(&self, pattern: PatternId) -> Vec<PatternId> {
189 let mut result = Vec::new();
190 let mut visited = std::collections::HashSet::new();
191 let mut stack = vec![pattern];
192
193 while let Some(current) = stack.pop() {
194 if visited.contains(¤t) {
195 continue;
196 }
197 visited.insert(current);
198
199 if let Some(effects) = self.forward.get(¤t) {
200 for &effect in effects.iter() {
201 if !visited.contains(&effect) {
202 stack.push(effect);
203 result.push(effect);
204 }
205 }
206 }
207 }
208
209 result
210 }
211
212 pub fn filter_by_light_cone(
214 &self,
215 reference: PatternId,
216 reference_time: SubstrateTime,
217 cone_type: CausalConeType,
218 candidates: &[PatternId],
219 ) -> Vec<PatternId> {
220 candidates
221 .iter()
222 .filter(|&&id| self.is_in_light_cone(id, reference, reference_time, cone_type))
223 .copied()
224 .collect()
225 }
226
227 fn is_in_light_cone(
229 &self,
230 pattern: PatternId,
231 _reference: PatternId,
232 reference_time: SubstrateTime,
233 cone_type: CausalConeType,
234 ) -> bool {
235 let pattern_time = match self.timestamps.get(&pattern) {
236 Some(t) => *t,
237 None => return false,
238 };
239
240 match cone_type {
241 CausalConeType::Past => pattern_time <= reference_time,
242 CausalConeType::Future => pattern_time >= reference_time,
243 CausalConeType::LightCone { velocity: _ } => {
244 let time_diff = (reference_time - pattern_time).abs();
247 let time_diff_secs = (time_diff.0 / 1_000_000_000).abs() as f32;
248
249 time_diff_secs >= 0.0 }
253 }
254 }
255
256 pub fn stats(&self) -> CausalGraphStats {
258 let num_nodes = self.timestamps.len();
259 let num_edges: usize = self.forward.iter().map(|e| e.value().len()).sum();
260
261 let avg_out_degree = if num_nodes > 0 {
262 num_edges as f32 / num_nodes as f32
263 } else {
264 0.0
265 };
266
267 CausalGraphStats {
268 num_nodes,
269 num_edges,
270 avg_out_degree,
271 }
272 }
273}
274
275impl Default for CausalGraph {
276 fn default() -> Self {
277 Self::new()
278 }
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct CausalGraphStats {
284 pub num_nodes: usize,
286 pub num_edges: usize,
288 pub avg_out_degree: f32,
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_graph_basic() {
        let graph = CausalGraph::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        let t1 = SubstrateTime::now();
        let t2 = SubstrateTime::now();
        let t3 = SubstrateTime::now();

        graph.add_pattern(p1, t1);
        graph.add_pattern(p2, t2);
        graph.add_pattern(p3, t3);

        // Chain: p1 -> p2 -> p3
        graph.add_edge(p1, p2);
        graph.add_edge(p2, p3);

        assert_eq!(graph.out_degree(p1), 1);
        assert_eq!(graph.in_degree(p2), 1);
        assert_eq!(graph.distance(p1, p3), Some(2));

        let past = graph.causal_past(p3);
        assert!(past.contains(&p1));
        assert!(past.contains(&p2));
    }

    #[test]
    fn test_causal_graph_queries() {
        let graph = CausalGraph::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();
        let now = SubstrateTime::now();
        for &p in &[p1, p2, p3] {
            graph.add_pattern(p, now);
        }

        graph.add_edge(p1, p2);
        graph.add_edge(p2, p3);

        // Direct neighbours and degrees.
        assert!(graph.effects(p1) == vec![p2]);
        assert!(graph.causes(p3) == vec![p2]);
        assert!(graph.causes(p1).is_empty());
        assert_eq!(graph.in_degree(p1), 0);
        assert_eq!(graph.out_degree(p3), 0);

        // Distances follow edge direction only.
        assert_eq!(graph.distance(p1, p1), Some(0));
        assert_eq!(graph.distance(p1, p2), Some(1));
        assert_eq!(graph.distance(p3, p1), None);

        // A new edge must invalidate the cached graph used by `distance`.
        graph.add_edge(p1, p3);
        assert_eq!(graph.distance(p1, p3), Some(1));

        let future = graph.causal_future(p1);
        assert_eq!(future.len(), 2);
        assert!(future.contains(&p2));
        assert!(future.contains(&p3));
        assert!(graph.causal_past(p1).is_empty());

        let stats = graph.stats();
        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 3);

        // Patterns without a timestamp never pass the cone filter.
        let stranger = PatternId::new();
        let kept = graph.filter_by_light_cone(p1, now, CausalConeType::Past, &[p1, p2, stranger]);
        assert_eq!(kept.len(), 2);
    }
}