Skip to main content

dsfb_tmtr/
causal.rs

1use std::collections::{HashMap, VecDeque};
2
3use serde::Serialize;
4
5use crate::observer::ObserverSeries;
6use crate::tmtr::CorrectionEvent;
7
/// A vertex of the causal graph.
///
/// `build_causal_graph` emits three kinds of node: one `"state"` node per observer
/// level and time step, plus one `"correction"` and one `"commit"` node per
/// correction event.
#[derive(Debug, Clone, Serialize)]
pub struct CausalNode {
    /// Unique node identifier, e.g. `state:L{level}:t{step}`, `corr:{index}` or
    /// `commit:{index}` in graphs produced by `build_causal_graph`.
    pub id: String,
    /// Position on the time axis. State nodes sit on integer steps; correction and
    /// commit nodes are placed at small fractional offsets after their anchor step.
    pub time: f64,
    /// Observer level the node is attributed to.
    pub level: usize,
    /// Node category (`"state"`, `"correction"` or `"commit"` when built by
    /// `build_causal_graph`).
    pub kind: String,
}
15
/// A directed, trust-weighted edge of the causal graph.
///
/// Endpoint times and levels are stored next to the endpoint ids, so an edge record
/// carries them without a node lookup.
#[derive(Debug, Clone, Serialize)]
pub struct CausalEdge {
    /// Scenario label passed to `build_causal_graph`.
    pub scenario: String,
    /// Mode label passed to `build_causal_graph`.
    pub mode: String,
    /// Edge category: `state_propagation`, `trust_gate`, `correction_source`,
    /// `correction_context` or `correction_commit` when built by `build_causal_graph`.
    pub edge_type: String,
    /// Id of the node the edge starts from.
    pub source_node: String,
    /// Time of the source node.
    pub source_time: f64,
    /// Observer level of the source node.
    pub source_level: usize,
    /// Id of the node the edge points to.
    pub target_node: String,
    /// Time of the target node.
    pub target_time: f64,
    /// Observer level of the target node.
    pub target_level: usize,
    /// Trust value attached to the edge; where it comes from depends on `edge_type`
    /// (observer trust series or the correction event's weight).
    pub trust_weight: f64,
}
29
/// A causal graph as a flat node list plus a flat edge list.
///
/// Edge endpoints are referenced by id and are not validated against `nodes`:
/// `build_causal_graph` can emit correction edges whose state endpoints do not
/// exist in `nodes`.
#[derive(Debug, Clone, Serialize)]
pub struct CausalGraph {
    /// All declared nodes.
    pub nodes: Vec<CausalNode>,
    /// All directed edges.
    pub edges: Vec<CausalEdge>,
}
35
/// Scalar metrics computed by `summarize_causal_graph`.
#[derive(Debug, Clone, Serialize)]
pub struct CausalMetricsSummary {
    /// Total number of edges.
    pub edge_count: usize,
    /// Edges whose `target_time` is strictly earlier than their `source_time`.
    pub backward_edge_count: usize,
    /// `1` if at least one directed cycle exists, otherwise `0` (a presence flag,
    /// not the number of distinct cycles).
    pub cycle_count: usize,
    /// Nodes reachable (including the anchor itself) from the first `"correction"`
    /// node, or from the first node when no correction node exists.
    pub reachable_nodes_from_anchor: usize,
    /// Edges with `target_time - source_time <= delta + 1`, divided by
    /// `max(node_count, 1) * max(delta, 1)`.
    pub local_window_edge_density: f64,
    /// Largest in-degree over all nodes and edge endpoints.
    pub max_in_degree: usize,
    /// Largest out-degree over all nodes and edge endpoints.
    pub max_out_degree: usize,
    /// Longest time-non-decreasing path length (in edges) ending at any node.
    pub max_path_length: usize,
    /// Mean of the per-node longest path lengths.
    pub mean_path_length: f64,
}
48
49pub fn build_causal_graph(
50    scenario: &str,
51    mode: &str,
52    observers: &[ObserverSeries],
53    correction_events: &[CorrectionEvent],
54    min_trust_gap: f64,
55) -> CausalGraph {
56    let mut nodes = Vec::new();
57    let mut edges = Vec::new();
58
59    for observer in observers {
60        for step in 0..observer.estimate.len() {
61            nodes.push(CausalNode {
62                id: state_node_id(observer.level, step),
63                time: step as f64,
64                level: observer.level,
65                kind: "state".to_string(),
66            });
67            if step > 0 {
68                edges.push(CausalEdge {
69                    scenario: scenario.to_string(),
70                    mode: mode.to_string(),
71                    edge_type: "state_propagation".to_string(),
72                    source_node: state_node_id(observer.level, step - 1),
73                    source_time: (step - 1) as f64,
74                    source_level: observer.level,
75                    target_node: state_node_id(observer.level, step),
76                    target_time: step as f64,
77                    target_level: observer.level,
78                    trust_weight: observer.trust[step],
79                });
80            }
81        }
82    }
83
84    for step in 0..observers[0].estimate.len().saturating_sub(1) {
85        for source_index in (1..observers.len()).rev() {
86            let target_index = source_index - 1;
87            let source = &observers[source_index];
88            let target = &observers[target_index];
89            let trust = source.trust[step];
90            if trust > target.trust[step] + min_trust_gap {
91                edges.push(CausalEdge {
92                    scenario: scenario.to_string(),
93                    mode: mode.to_string(),
94                    edge_type: "trust_gate".to_string(),
95                    source_node: state_node_id(source.level, step),
96                    source_time: step as f64,
97                    source_level: source.level,
98                    target_node: state_node_id(target.level, step + 1),
99                    target_time: (step + 1) as f64,
100                    target_level: target.level,
101                    trust_weight: trust,
102                });
103            }
104        }
105    }
106
107    for (index, event) in correction_events.iter().enumerate() {
108        let correction_time = event.anchor_time as f64 + 0.1;
109        let correction_node = format!("corr:{index}");
110        let commit_time = correction_time + 0.1;
111        let commit_node = format!("commit:{index}");
112        nodes.push(CausalNode {
113            id: correction_node.clone(),
114            time: correction_time,
115            level: event.target_level,
116            kind: "correction".to_string(),
117        });
118        nodes.push(CausalNode {
119            id: commit_node.clone(),
120            time: commit_time,
121            level: event.target_level,
122            kind: "commit".to_string(),
123        });
124        edges.push(CausalEdge {
125            scenario: scenario.to_string(),
126            mode: mode.to_string(),
127            edge_type: "correction_source".to_string(),
128            source_node: state_node_id(event.source_level, event.anchor_time),
129            source_time: event.anchor_time as f64,
130            source_level: event.source_level,
131            target_node: correction_node.clone(),
132            target_time: correction_time,
133            target_level: event.target_level,
134            trust_weight: event.trust_weight,
135        });
136        edges.push(CausalEdge {
137            scenario: scenario.to_string(),
138            mode: mode.to_string(),
139            edge_type: "correction_context".to_string(),
140            source_node: state_node_id(event.target_level, event.corrected_time),
141            source_time: event.corrected_time as f64,
142            source_level: event.target_level,
143            target_node: correction_node.clone(),
144            target_time: correction_time,
145            target_level: event.target_level,
146            trust_weight: event.trust_weight,
147        });
148        edges.push(CausalEdge {
149            scenario: scenario.to_string(),
150            mode: mode.to_string(),
151            edge_type: "correction_commit".to_string(),
152            source_node: correction_node,
153            source_time: correction_time,
154            source_level: event.target_level,
155            target_node: commit_node,
156            target_time: commit_time,
157            target_level: event.target_level,
158            trust_weight: event.trust_weight,
159        });
160    }
161
162    CausalGraph { nodes, edges }
163}
164
165pub fn summarize_causal_graph(graph: &CausalGraph, delta: usize) -> CausalMetricsSummary {
166    let edge_count = graph.edges.len();
167    let backward_edge_count = graph
168        .edges
169        .iter()
170        .filter(|edge| edge.target_time < edge.source_time)
171        .count();
172    let cycle_count = if has_cycle(&graph.nodes, &graph.edges) {
173        1
174    } else {
175        0
176    };
177
178    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
179    let mut indegree: HashMap<&str, usize> = HashMap::new();
180    let mut outdegree: HashMap<&str, usize> = HashMap::new();
181    let node_times: HashMap<&str, f64> = graph
182        .nodes
183        .iter()
184        .map(|node| (node.id.as_str(), node.time))
185        .collect();
186    for node in &graph.nodes {
187        indegree.entry(node.id.as_str()).or_insert(0);
188        outdegree.entry(node.id.as_str()).or_insert(0);
189    }
190    for edge in &graph.edges {
191        adjacency
192            .entry(edge.source_node.as_str())
193            .or_default()
194            .push(edge.target_node.as_str());
195        *indegree.entry(edge.target_node.as_str()).or_insert(0) += 1;
196        *outdegree.entry(edge.source_node.as_str()).or_insert(0) += 1;
197    }
198
199    let max_in_degree = indegree.values().copied().max().unwrap_or(0);
200    let max_out_degree = outdegree.values().copied().max().unwrap_or(0);
201
202    let anchor_node = graph
203        .nodes
204        .iter()
205        .find(|node| node.kind == "correction")
206        .or_else(|| graph.nodes.first());
207    let reachable_nodes_from_anchor = anchor_node
208        .map(|node| reachable_count(node.id.as_str(), &adjacency))
209        .unwrap_or(0);
210
211    let local_window_edges = graph
212        .edges
213        .iter()
214        .filter(|edge| edge.target_time - edge.source_time <= delta as f64 + 1.0)
215        .count();
216    let node_count = graph.nodes.len().max(1);
217    let possible_local_edges = node_count * delta.max(1);
218    let local_window_edge_density = local_window_edges as f64 / possible_local_edges as f64;
219
220    let longest_paths = longest_path_lengths(&graph.nodes, &graph.edges, &node_times);
221    let max_path_length = longest_paths.values().copied().max().unwrap_or(0);
222    let mean_path_length = if longest_paths.is_empty() {
223        0.0
224    } else {
225        longest_paths
226            .values()
227            .copied()
228            .map(|value| value as f64)
229            .sum::<f64>()
230            / longest_paths.len() as f64
231    };
232
233    CausalMetricsSummary {
234        edge_count,
235        backward_edge_count,
236        cycle_count,
237        reachable_nodes_from_anchor,
238        local_window_edge_density,
239        max_in_degree,
240        max_out_degree,
241        max_path_length,
242        mean_path_length,
243    }
244}
245
/// Identifier of the state node for observer `level` at time index `step`,
/// rendered as `state:L{level}:t{step}`.
fn state_node_id(level: usize, step: usize) -> String {
    format!("state:L{}:t{}", level, step)
}
249
/// Counts the nodes reachable from `anchor` by following `adjacency`, including
/// `anchor` itself (so the result is at least 1, even if `anchor` has no entry).
///
/// Breadth-first search; each node is marked when first enqueued, so it is enqueued
/// at most once.
fn reachable_count<'a>(anchor: &'a str, adjacency: &HashMap<&'a str, Vec<&'a str>>) -> usize {
    use std::collections::HashSet;

    let mut seen: HashSet<&'a str> = HashSet::new();
    seen.insert(anchor);
    let mut queue = VecDeque::new();
    queue.push_back(anchor);
    while let Some(node) = queue.pop_front() {
        for &child in adjacency.get(node).into_iter().flatten() {
            // `insert` returns false for already-visited nodes, so they are not re-queued.
            if seen.insert(child) {
                queue.push_back(child);
            }
        }
    }
    seen.len()
}
266
267fn has_cycle(nodes: &[CausalNode], edges: &[CausalEdge]) -> bool {
268    let mut indegree = HashMap::<&str, usize>::new();
269    let mut adjacency = HashMap::<&str, Vec<&str>>::new();
270    for node in nodes {
271        indegree.insert(node.id.as_str(), 0);
272    }
273    for edge in edges {
274        *indegree.entry(edge.target_node.as_str()).or_insert(0) += 1;
275        adjacency
276            .entry(edge.source_node.as_str())
277            .or_default()
278            .push(edge.target_node.as_str());
279    }
280    let mut queue = VecDeque::new();
281    for (node, degree) in &indegree {
282        if *degree == 0 {
283            queue.push_back(*node);
284        }
285    }
286    let mut visited = 0usize;
287    while let Some(node) = queue.pop_front() {
288        visited += 1;
289        if let Some(children) = adjacency.get(node) {
290            for child in children {
291                if let Some(entry) = indegree.get_mut(child) {
292                    *entry = entry.saturating_sub(1);
293                    if *entry == 0 {
294                        queue.push_back(child);
295                    }
296                }
297            }
298        }
299    }
300    visited != indegree.len()
301}
302
303fn longest_path_lengths<'a>(
304    nodes: &'a [CausalNode],
305    edges: &'a [CausalEdge],
306    node_times: &HashMap<&'a str, f64>,
307) -> HashMap<&'a str, usize> {
308    let mut adjacency = HashMap::<&str, Vec<&str>>::new();
309    let mut indegree = HashMap::<&str, usize>::new();
310    for node in nodes {
311        indegree.insert(node.id.as_str(), 0);
312    }
313    for edge in edges {
314        adjacency
315            .entry(edge.source_node.as_str())
316            .or_default()
317            .push(edge.target_node.as_str());
318        *indegree.entry(edge.target_node.as_str()).or_insert(0) += 1;
319    }
320    let mut queue = VecDeque::new();
321    let mut distance = HashMap::<&str, usize>::new();
322    let mut ordered = nodes.iter().collect::<Vec<_>>();
323    ordered.sort_by(|left, right| left.time.total_cmp(&right.time));
324    for node in &ordered {
325        distance.insert(node.id.as_str(), 0);
326        if indegree.get(node.id.as_str()).copied().unwrap_or(0) == 0 {
327            queue.push_back(node.id.as_str());
328        }
329    }
330    while let Some(node) = queue.pop_front() {
331        let source_distance = distance.get(node).copied().unwrap_or(0);
332        if let Some(children) = adjacency.get(node) {
333            for child in children {
334                let source_time = node_times.get(node).copied().unwrap_or_default();
335                let target_time = node_times.get(child).copied().unwrap_or_default();
336                if target_time >= source_time {
337                    let candidate = source_distance + 1;
338                    if candidate > distance.get(child).copied().unwrap_or(0) {
339                        distance.insert(child, candidate);
340                    }
341                }
342                if let Some(entry) = indegree.get_mut(child) {
343                    *entry = entry.saturating_sub(1);
344                    if *entry == 0 {
345                        queue.push_back(child);
346                    }
347                }
348            }
349        }
350    }
351    distance
352}
353
#[cfg(test)]
mod tests {
    use super::{summarize_causal_graph, CausalEdge, CausalGraph, CausalNode};

    /// Builds a level-1 `"state"` node with the given id and time.
    fn state_node(id: &str, time: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            time,
            level: 1,
            kind: "state".to_string(),
        }
    }

    #[test]
    fn forward_edges_do_not_trigger_backward_detection() {
        // Two states one step apart, joined by a single time-increasing edge.
        let forward = CausalEdge {
            scenario: "test".to_string(),
            mode: "tmtr".to_string(),
            edge_type: "state".to_string(),
            source_node: "a".to_string(),
            source_time: 0.0,
            source_level: 1,
            target_node: "b".to_string(),
            target_time: 1.0,
            target_level: 1,
            trust_weight: 1.0,
        };
        let graph = CausalGraph {
            nodes: vec![state_node("a", 0.0), state_node("b", 1.0)],
            edges: vec![forward],
        };

        let summary = summarize_causal_graph(&graph, 4);

        assert_eq!(summary.backward_edge_count, 0);
        assert_eq!(summary.cycle_count, 0);
    }
}