1use crate::types::{PatternId, SubstrateTime};
4use dashmap::DashMap;
5use petgraph::graph::{DiGraph, NodeIndex};
6use petgraph::algo::dijkstra;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum CausalConeType {
14 Past,
16 Future,
18 LightCone {
20 velocity: f32,
22 },
23}
24
25pub struct CausalGraph {
27 forward: DashMap<PatternId, Vec<PatternId>>,
29 backward: DashMap<PatternId, Vec<PatternId>>,
31 timestamps: DashMap<PatternId, SubstrateTime>,
33 graph_cache: Arc<parking_lot::RwLock<Option<(DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>)>>>,
35}
36
37impl CausalGraph {
38 pub fn new() -> Self {
40 Self {
41 forward: DashMap::new(),
42 backward: DashMap::new(),
43 timestamps: DashMap::new(),
44 graph_cache: Arc::new(parking_lot::RwLock::new(None)),
45 }
46 }
47
48 pub fn add_edge(&self, cause: PatternId, effect: PatternId) {
50 self.forward
52 .entry(cause)
53 .or_insert_with(Vec::new)
54 .push(effect);
55
56 self.backward
58 .entry(effect)
59 .or_insert_with(Vec::new)
60 .push(cause);
61
62 *self.graph_cache.write() = None;
64 }
65
66 pub fn add_pattern(&self, id: PatternId, timestamp: SubstrateTime) {
68 self.timestamps.insert(id, timestamp);
69 }
70
71 pub fn causes(&self, pattern: PatternId) -> Vec<PatternId> {
73 self.backward
74 .get(&pattern)
75 .map(|v| v.clone())
76 .unwrap_or_default()
77 }
78
79 pub fn effects(&self, pattern: PatternId) -> Vec<PatternId> {
81 self.forward
82 .get(&pattern)
83 .map(|v| v.clone())
84 .unwrap_or_default()
85 }
86
87 pub fn out_degree(&self, pattern: PatternId) -> usize {
89 self.forward
90 .get(&pattern)
91 .map(|v| v.len())
92 .unwrap_or(0)
93 }
94
95 pub fn in_degree(&self, pattern: PatternId) -> usize {
97 self.backward
98 .get(&pattern)
99 .map(|v| v.len())
100 .unwrap_or(0)
101 }
102
103 pub fn distance(&self, from: PatternId, to: PatternId) -> Option<usize> {
105 if from == to {
106 return Some(0);
107 }
108
109 let (graph, node_map) = {
111 let cache = self.graph_cache.read();
112 if let Some((g, m)) = cache.as_ref() {
113 (g.clone(), m.clone())
114 } else {
115 drop(cache);
116 let (g, m) = self.build_graph();
117 *self.graph_cache.write() = Some((g.clone(), m.clone()));
118 (g, m)
119 }
120 };
121
122 let from_idx = *node_map.get(&from)?;
124 let to_idx = *node_map.get(&to)?;
125
126 let distances = dijkstra(&graph, from_idx, Some(to_idx), |_| 1);
128
129 distances.get(&to_idx).copied()
130 }
131
132 fn build_graph(&self) -> (DiGraph<PatternId, ()>, HashMap<PatternId, NodeIndex>) {
134 let mut graph = DiGraph::new();
135 let mut node_map = HashMap::new();
136
137 for entry in self.forward.iter() {
139 let id = *entry.key();
140 if !node_map.contains_key(&id) {
141 let idx = graph.add_node(id);
142 node_map.insert(id, idx);
143 }
144
145 for &effect in entry.value() {
146 if !node_map.contains_key(&effect) {
147 let idx = graph.add_node(effect);
148 node_map.insert(effect, idx);
149 }
150 }
151 }
152
153 for entry in self.forward.iter() {
155 let from = *entry.key();
156 let from_idx = node_map[&from];
157
158 for &to in entry.value() {
159 let to_idx = node_map[&to];
160 graph.add_edge(from_idx, to_idx, ());
161 }
162 }
163
164 (graph, node_map)
165 }
166
167 pub fn causal_past(&self, pattern: PatternId) -> Vec<PatternId> {
169 let mut result = Vec::new();
170 let mut visited = std::collections::HashSet::new();
171 let mut stack = vec![pattern];
172
173 while let Some(current) = stack.pop() {
174 if visited.contains(¤t) {
175 continue;
176 }
177 visited.insert(current);
178
179 if let Some(causes) = self.backward.get(¤t) {
180 for &cause in causes.iter() {
181 if !visited.contains(&cause) {
182 stack.push(cause);
183 result.push(cause);
184 }
185 }
186 }
187 }
188
189 result
190 }
191
192 pub fn causal_future(&self, pattern: PatternId) -> Vec<PatternId> {
194 let mut result = Vec::new();
195 let mut visited = std::collections::HashSet::new();
196 let mut stack = vec![pattern];
197
198 while let Some(current) = stack.pop() {
199 if visited.contains(¤t) {
200 continue;
201 }
202 visited.insert(current);
203
204 if let Some(effects) = self.forward.get(¤t) {
205 for &effect in effects.iter() {
206 if !visited.contains(&effect) {
207 stack.push(effect);
208 result.push(effect);
209 }
210 }
211 }
212 }
213
214 result
215 }
216
217 pub fn filter_by_light_cone(
219 &self,
220 reference: PatternId,
221 reference_time: SubstrateTime,
222 cone_type: CausalConeType,
223 candidates: &[PatternId],
224 ) -> Vec<PatternId> {
225 candidates
226 .iter()
227 .filter(|&&id| {
228 self.is_in_light_cone(id, reference, reference_time, cone_type)
229 })
230 .copied()
231 .collect()
232 }
233
234 fn is_in_light_cone(
236 &self,
237 pattern: PatternId,
238 _reference: PatternId,
239 reference_time: SubstrateTime,
240 cone_type: CausalConeType,
241 ) -> bool {
242 let pattern_time = match self.timestamps.get(&pattern) {
243 Some(t) => *t,
244 None => return false,
245 };
246
247 match cone_type {
248 CausalConeType::Past => pattern_time <= reference_time,
249 CausalConeType::Future => pattern_time >= reference_time,
250 CausalConeType::LightCone { velocity: _ } => {
251 let time_diff = (reference_time - pattern_time).abs();
254 let time_diff_secs = (time_diff.0 / 1_000_000_000).abs() as f32;
255
256 time_diff_secs >= 0.0 }
260 }
261 }
262
263 pub fn stats(&self) -> CausalGraphStats {
265 let num_nodes = self.timestamps.len();
266 let num_edges: usize = self.forward.iter().map(|e| e.value().len()).sum();
267
268 let avg_out_degree = if num_nodes > 0 {
269 num_edges as f32 / num_nodes as f32
270 } else {
271 0.0
272 };
273
274 CausalGraphStats {
275 num_nodes,
276 num_edges,
277 avg_out_degree,
278 }
279 }
280}
281
282impl Default for CausalGraph {
283 fn default() -> Self {
284 Self::new()
285 }
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct CausalGraphStats {
291 pub num_nodes: usize,
293 pub num_edges: usize,
295 pub avg_out_degree: f32,
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Chain `first -> middle -> last`: checks degrees, hop distance and the transitive past.
    #[test]
    fn test_causal_graph_basic() {
        let graph = CausalGraph::new();

        let chain = [PatternId::new(), PatternId::new(), PatternId::new()];
        for &id in chain.iter() {
            graph.add_pattern(id, SubstrateTime::now());
        }
        let [first, middle, last] = chain;

        graph.add_edge(first, middle);
        graph.add_edge(middle, last);

        assert_eq!(graph.out_degree(first), 1);
        assert_eq!(graph.in_degree(middle), 1);
        assert_eq!(graph.distance(first, last), Some(2));

        let past = graph.causal_past(last);
        assert!(past.contains(&first));
        assert!(past.contains(&middle));
    }
}