use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
use serde::{Deserialize, Serialize};
use crate::project::Project;
/// Directed causal graph over finding ids.
///
/// Adjacency is stored redundantly in both directions (`parents` and
/// `children`) for O(1) lookup either way. `BTreeSet` keeps neighbour and
/// node iteration deterministic, which downstream search code relies on.
#[derive(Debug, Clone)]
pub struct CausalGraph {
// Every known node id, sorted; used for membership tests and as the
// deterministic candidate pool during identification.
nodes: BTreeSet<String>,
// node -> its direct causal parents (incoming edges).
parents: HashMap<String, BTreeSet<String>>,
// node -> its direct causal children (outgoing edges); mirror of `parents`.
children: HashMap<String, BTreeSet<String>>,
}
impl CausalGraph {
#[must_use]
/// Build the causal graph for `project`.
///
/// Every finding becomes a node. A link of type `"depends"` or
/// `"supports"` whose target is a local finding id (no `'@'`) becomes an
/// edge with the link *target* as the causal parent of the linking finding.
pub fn from_project(project: &Project) -> Self {
    let mut nodes = BTreeSet::new();
    let mut parents: HashMap<String, BTreeSet<String>> = HashMap::new();
    let mut children: HashMap<String, BTreeSet<String>> = HashMap::new();
    // First pass: register every finding so isolated nodes still exist
    // and target-membership checks below are meaningful.
    for finding in &project.findings {
        nodes.insert(finding.id.clone());
        parents.entry(finding.id.clone()).or_default();
        children.entry(finding.id.clone()).or_default();
    }
    // Second pass: turn qualifying links into edges.
    for finding in &project.findings {
        for link in &finding.links {
            let causal_kind = link.link_type == "depends" || link.link_type == "supports";
            let external_target = link.target.contains('@');
            if !causal_kind || external_target || !nodes.contains(&link.target) {
                continue;
            }
            parents
                .entry(finding.id.clone())
                .or_default()
                .insert(link.target.clone());
            children
                .entry(link.target.clone())
                .or_default()
                .insert(finding.id.clone());
        }
    }
    Self { nodes, parents, children }
}
/// Number of nodes in the graph.
#[must_use]
pub fn node_count(&self) -> usize {
    self.nodes.len()
}
/// Number of directed edges (each stored once on the child's parent set).
#[must_use]
pub fn edge_count(&self) -> usize {
    self.parents.values().map(|set| set.len()).sum()
}
/// Whether `node` is a known node id.
#[must_use]
pub fn contains(&self, node: &str) -> bool {
    self.nodes.contains(node)
}
/// Direct parents of `node`, in sorted order; empty for unknown ids.
#[must_use]
pub fn parents_of(&self, node: &str) -> impl Iterator<Item = &str> {
    self.parents.get(node).into_iter().flatten().map(String::as_str)
}
/// Direct children of `node`, in sorted order; empty for unknown ids.
#[must_use]
pub fn children_of(&self, node: &str) -> impl Iterator<Item = &str> {
    self.children.get(node).into_iter().flatten().map(String::as_str)
}
#[must_use]
pub fn ancestors(&self, node: &str) -> HashSet<String> {
let mut seen = HashSet::new();
let mut queue: VecDeque<String> = VecDeque::new();
if let Some(ps) = self.parents.get(node) {
for p in ps {
queue.push_back(p.clone());
}
}
while let Some(n) = queue.pop_front() {
if !seen.insert(n.clone()) {
continue;
}
if let Some(ps) = self.parents.get(&n) {
for p in ps {
if !seen.contains(p) {
queue.push_back(p.clone());
}
}
}
}
seen
}
#[must_use]
/// All transitive descendants of `node` (nodes reachable along directed
/// edges). `node` itself is returned only if it lies on a cycle.
pub fn descendants(&self, node: &str) -> HashSet<String> {
    let mut reached: HashSet<String> = HashSet::new();
    // Seed the frontier with the direct children.
    let mut frontier: VecDeque<String> = self
        .children
        .get(node)
        .map(|cs| cs.iter().cloned().collect())
        .unwrap_or_default();
    while let Some(current) = frontier.pop_front() {
        // Newly discovered nodes contribute their children to the frontier.
        if reached.insert(current.clone()) {
            if let Some(cs) = self.children.get(&current) {
                frontier.extend(cs.iter().filter(|c| !reached.contains(*c)).cloned());
            }
        }
    }
    reached
}
/// True when `candidate` is reachable from `source` along directed edges.
#[must_use]
pub fn is_descendant_of(&self, candidate: &str, source: &str) -> bool {
    self.descendants(source).contains(candidate)
}
/// Enumerate simple paths between `start` and `end`, treating edges as
/// undirected (d-separation needs back-door paths too).
///
/// Returns at most `max_paths` paths, each at most `max_len` nodes long,
/// in the deterministic order produced by `dfs_paths`. Empty when either
/// endpoint is unknown or the endpoints coincide.
#[must_use] // was missing: every other pure method in this impl carries it
pub fn paths_between(
    &self,
    start: &str,
    end: &str,
    max_paths: usize,
    max_len: usize,
) -> Vec<Vec<String>> {
    if !self.contains(start) || !self.contains(end) || start == end {
        return Vec::new();
    }
    let mut all_paths = Vec::new();
    let mut current: Vec<String> = vec![start.to_string()];
    let mut visited: HashSet<String> = HashSet::new();
    visited.insert(start.to_string());
    self.dfs_paths(
        start,
        end,
        &mut current,
        &mut visited,
        &mut all_paths,
        max_paths,
        max_len,
    );
    all_paths
}
/// Recursive helper for `paths_between`: depth-first enumeration of
/// simple undirected paths from `node` to `end`.
///
/// `current` and `visited` carry the in-progress path and are restored
/// on backtrack. Enumeration stops once `max_paths` paths are collected
/// or the current path exceeds `max_len` nodes. Neighbours are merged
/// into a `BTreeSet` so the output order is deterministic.
fn dfs_paths(
&self,
node: &str,
end: &str,
current: &mut Vec<String>,
visited: &mut HashSet<String>,
all_paths: &mut Vec<Vec<String>>,
max_paths: usize,
max_len: usize,
) {
if all_paths.len() >= max_paths {
return;
}
if current.len() > max_len {
return;
}
// Treat the graph as undirected: step across both incoming and
// outgoing edges, in sorted order for reproducible results.
let mut neighbors: BTreeSet<String> = BTreeSet::new();
if let Some(ps) = self.parents.get(node) {
for p in ps {
neighbors.insert(p.clone());
}
}
if let Some(cs) = self.children.get(node) {
for c in cs {
neighbors.insert(c.clone());
}
}
for next in &neighbors {
// Simple paths only: never revisit a node on the current path.
if visited.contains(next) {
continue;
}
current.push(next.clone());
visited.insert(next.clone());
if next == end {
// Reached the target: record a copy; do not extend further,
// since only simple paths *ending* at `end` are wanted.
all_paths.push(current.clone());
} else {
self.dfs_paths(next, end, current, visited, all_paths, max_paths, max_len);
}
// Backtrack so sibling branches can reuse this node.
visited.remove(next);
current.pop();
if all_paths.len() >= max_paths {
return;
}
}
}
/// Classify the interior node of a path triple for d-separation.
///
/// `Collider` when both neighbours point *into* `node`; otherwise `Fork`
/// when both neighbours are pointed at *by* `node`; everything else is
/// treated as a chain. The collider check deliberately takes precedence.
fn node_role_in_path(&self, prev: &str, node: &str, next: &str) -> NodeRole {
    let points_into_node =
        |x: &str| self.parents.get(node).is_some_and(|ps| ps.contains(x));
    let pointed_at_by_node =
        |x: &str| self.children.get(node).is_some_and(|cs| cs.contains(x));
    if points_into_node(prev) && points_into_node(next) {
        NodeRole::Collider
    } else if pointed_at_by_node(prev) && pointed_at_by_node(next) {
        NodeRole::Fork
    } else {
        NodeRole::Chain
    }
}
/// D-separation test: is `path` blocked given conditioning set `z`?
///
/// A chain or fork node blocks the path when it is conditioned on; a
/// collider blocks unless it, or one of its descendants, is conditioned
/// on. Paths with fewer than three nodes have no interior node and are
/// never blocked.
#[must_use]
pub fn is_path_blocked(&self, path: &[String], z: &HashSet<String>) -> bool {
    if path.len() < 3 {
        return false;
    }
    for i in 1..path.len() - 1 {
        let prev = &path[i - 1];
        let node = &path[i];
        let next = &path[i + 1];
        match self.node_role_in_path(prev, node, next) {
            NodeRole::Chain | NodeRole::Fork => {
                if z.contains(node) {
                    return true;
                }
            }
            NodeRole::Collider => {
                // Short-circuit: only walk the (potentially large)
                // descendant set when the collider itself is not already
                // conditioned on. The original computed it eagerly.
                if !z.contains(node)
                    && !self.descendants(node).iter().any(|d| z.contains(d))
                {
                    return true;
                }
            }
        }
    }
    false
}
#[must_use]
/// Does `path`, starting at treatment `x`, leave `x` through an incoming
/// edge? That is, is the second node a causal parent of `x`?
pub fn is_back_door_path(&self, path: &[String], x: &str) -> bool {
    match path {
        [first, second, ..] if first.as_str() == x => {
            self.parents.get(x).is_some_and(|ps| ps.contains(second))
        }
        _ => false,
    }
}
#[must_use]
/// True when every consecutive pair of `path` is a parent -> child edge,
/// i.e. the whole path follows causal direction from its first node to
/// its last. Paths shorter than two nodes are not considered directed.
pub fn is_directed_path(&self, path: &[String]) -> bool {
    path.len() >= 2
        && path.windows(2).all(|pair| {
            self.parents
                .get(&pair[1])
                .is_some_and(|ps| ps.contains(&pair[0]))
        })
}
}
/// Role an interior node plays on a path, per the d-separation rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeRole {
/// prev -> node -> next (either direction): transmits unless conditioned on.
Chain,
/// prev <- node -> next: common cause; transmits unless conditioned on.
Fork,
/// prev -> node <- next: common effect; blocks unless it (or a
/// descendant) is conditioned on.
Collider,
}
/// Outcome of causal-effect identification.
///
/// Serialized as an internally tagged enum: the variant name appears
/// snake_cased under the `"kind"` key for downstream consumers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum CausalEffectVerdict {
/// Back-door criterion satisfied by `adjustment_set` (possibly empty).
Identified {
adjustment_set: Vec<String>,
back_door_paths_considered: usize,
},
/// Back-door search failed, but a front-door mediator set was found.
IdentifiedByFrontDoor {
mediator_set: Vec<String>,
},
/// No path connects source and target within the bounded search.
NoCausalPath { reason: String },
/// Open back-door paths remain after the bounded adjustment search.
Underidentified {
unblocked_back_door_paths: Vec<Vec<String>>,
candidates_tried: usize,
},
/// Source or target id is not a node of the graph.
UnknownNode { which: String },
}
impl CausalEffectVerdict {
pub fn as_str(&self) -> &'static str {
match self {
CausalEffectVerdict::Identified { .. } => "identified",
CausalEffectVerdict::IdentifiedByFrontDoor { .. } => "identified_by_front_door",
CausalEffectVerdict::NoCausalPath { .. } => "no_causal_path",
CausalEffectVerdict::Underidentified { .. } => "underidentified",
CausalEffectVerdict::UnknownNode { .. } => "unknown_node",
}
}
}
/// Build the causal graph for `project` and run identification on it.
pub fn identify_effect(project: &Project, source: &str, target: &str) -> CausalEffectVerdict {
    identify_effect_in_graph(&CausalGraph::from_project(project), source, target)
}
/// Decide whether the causal effect of `source` on `target` is
/// identifiable from graph structure alone.
///
/// Strategy: enumerate paths up to a bounded depth, then try the
/// back-door criterion with adjustment sets of size 0, 1 and 2
/// (candidates exclude `source`, `target` and descendants of `source`).
/// If that fails, fall back to a front-door search before declaring the
/// effect underidentified.
pub fn identify_effect_in_graph(
    graph: &CausalGraph,
    source: &str,
    target: &str,
) -> CausalEffectVerdict {
    if !graph.contains(source) {
        return CausalEffectVerdict::UnknownNode {
            which: format!("source not in frontier: {source}"),
        };
    }
    if !graph.contains(target) {
        return CausalEffectVerdict::UnknownNode {
            which: format!("target not in frontier: {target}"),
        };
    }
    if source == target {
        return CausalEffectVerdict::NoCausalPath {
            reason: "source equals target".into(),
        };
    }
    // Bounded search keeps worst-case cost predictable on dense graphs.
    const MAX_PATHS: usize = 200;
    const MAX_LEN: usize = 8;
    // Cap on total adjustment sets attempted (was an inline magic number).
    const MAX_TRIED: usize = 2_000;
    let all_paths = graph.paths_between(source, target, MAX_PATHS, MAX_LEN);
    // Early return moved before the back-door filter: the original
    // computed `back_door_paths` even when there were no paths at all.
    if all_paths.is_empty() {
        return CausalEffectVerdict::NoCausalPath {
            reason: format!("no path between {source} and {target} (search depth {MAX_LEN})"),
        };
    }
    let back_door_paths: Vec<Vec<String>> = all_paths
        .iter()
        .filter(|p| graph.is_back_door_path(p, source))
        .cloned()
        .collect();
    // Back-door criterion forbids adjusting on descendants of the
    // treatment, so they are excluded from the candidate pool.
    let descendants_of_source = graph.descendants(source);
    let candidates: Vec<String> = graph
        .nodes
        .iter()
        .filter(|n| n.as_str() != source && n.as_str() != target)
        .filter(|n| !descendants_of_source.contains(n.as_str()))
        .cloned()
        .collect();
    let blocks_all = |z: &HashSet<String>| -> bool {
        back_door_paths.iter().all(|p| graph.is_path_blocked(p, z))
    };
    // Size 0: no adjustment needed (also covers "no back-door paths").
    let empty: HashSet<String> = HashSet::new();
    if blocks_all(&empty) {
        return CausalEffectVerdict::Identified {
            adjustment_set: Vec::new(),
            back_door_paths_considered: back_door_paths.len(),
        };
    }
    let mut tried = 1usize; // the empty set counts as one attempt
    // Size 1: singleton adjustment sets in sorted node order.
    for c in &candidates {
        let mut z = HashSet::new();
        z.insert(c.clone());
        tried += 1;
        if blocks_all(&z) {
            return CausalEffectVerdict::Identified {
                adjustment_set: vec![c.clone()],
                back_door_paths_considered: back_door_paths.len(),
            };
        }
    }
    // Size 2: all unordered pairs, capped by MAX_TRIED. Labeled break
    // replaces the original's duplicated break condition.
    'pairs: for i in 0..candidates.len() {
        for j in (i + 1)..candidates.len() {
            let mut z = HashSet::new();
            z.insert(candidates[i].clone());
            z.insert(candidates[j].clone());
            tried += 1;
            if blocks_all(&z) {
                return CausalEffectVerdict::Identified {
                    adjustment_set: vec![candidates[i].clone(), candidates[j].clone()],
                    back_door_paths_considered: back_door_paths.len(),
                };
            }
            if tried > MAX_TRIED {
                break 'pairs;
            }
        }
    }
    // Back-door adjustment failed; try the front-door criterion.
    if let Some(mediators) = find_front_door_set(graph, source, target, &all_paths) {
        return CausalEffectVerdict::IdentifiedByFrontDoor {
            mediator_set: mediators,
        };
    }
    // Report up to five still-open back-door paths as evidence.
    let unblocked: Vec<Vec<String>> = back_door_paths
        .iter()
        .filter(|p| !graph.is_path_blocked(p, &empty))
        .take(5)
        .cloned()
        .collect();
    CausalEffectVerdict::Underidentified {
        unblocked_back_door_paths: unblocked,
        candidates_tried: tried,
    }
}
/// Simplified single-mediator front-door search.
///
/// Looks for one node `m` such that (1) every bounded directed path from
/// `source` to `target` passes through `m`, (2) there is no open
/// back-door path from `source` to `m`, and (3) every back-door path
/// from `m` to `target` is blocked by conditioning on `source` alone.
/// Returns the first such mediator — deterministic, because node
/// iteration over the `BTreeSet` is sorted — or `None`.
/// NOTE(review): this checks a single mediator only, not general
/// mediator sets; a multi-node front-door set would be missed.
fn find_front_door_set(
graph: &CausalGraph,
source: &str,
target: &str,
all_paths_source_target: &[Vec<String>],
) -> Option<Vec<String>> {
// Condition (1) needs the fully directed source->target paths.
let directed_st: Vec<Vec<String>> = all_paths_source_target
.iter()
.filter(|p| graph.is_directed_path(p))
.cloned()
.collect();
if directed_st.is_empty() {
return None;
}
// Any single mediator must lie strictly between source and target,
// i.e. downstream of source and upstream of target.
let descendants_of_source = graph.descendants(source);
let ancestors_of_target = graph.ancestors(target);
let mediator_candidates: Vec<&str> = graph
.nodes
.iter()
.filter(|n| {
n.as_str() != source
&& n.as_str() != target
&& descendants_of_source.contains(n.as_str())
&& ancestors_of_target.contains(n.as_str())
})
.map(String::as_str)
.collect();
let source_set: HashSet<String> = std::iter::once(source.to_string()).collect();
for m in mediator_candidates {
// Condition (1): m intercepts every directed source->target path.
let intercepts_all = directed_st
.iter()
.all(|p| p.iter().any(|n| n.as_str() == m));
if !intercepts_all {
continue;
}
// Tighter bounds than the outer search: these path enumerations run
// once per candidate mediator.
const MAX_PATHS: usize = 100;
const MAX_LEN: usize = 6;
// Condition (2): no unblocked back-door path from source to m.
let paths_sm = graph.paths_between(source, m, MAX_PATHS, MAX_LEN);
let back_door_sm: Vec<&Vec<String>> = paths_sm
.iter()
.filter(|p| graph.is_back_door_path(p, source))
.collect();
let empty: HashSet<String> = HashSet::new();
let any_open = back_door_sm
.iter()
.any(|p| !graph.is_path_blocked(p, &empty));
if any_open {
continue;
}
// Condition (3): all back-door paths from m to target are blocked by
// conditioning on the source alone.
let paths_mt = graph.paths_between(m, target, MAX_PATHS, MAX_LEN);
let back_door_mt: Vec<&Vec<String>> = paths_mt
.iter()
.filter(|p| graph.is_back_door_path(p, m))
.collect();
let all_blocked_by_source = back_door_mt
.iter()
.all(|p| graph.is_path_blocked(p, &source_set));
if !all_blocked_by_source {
continue;
}
return Some(vec![m.to_string()]);
}
None
}
// Unit tests exercise identification on tiny hand-built graphs. A link's
// target becomes the causal *parent* of the finding that declares the link,
// so `b.links.push(link("vf_a", "depends"))` creates the edge a -> b.
#[cfg(test)]
mod tests {
use super::*;
use crate::bundle::*;
use crate::project;
// Minimal finding with the given id; every other field is a neutral test
// default so only the link structure matters to the graph.
fn finding(id: &str) -> FindingBundle {
let mut b = FindingBundle::new(
Assertion {
text: format!("claim {id}"),
assertion_type: "mechanism".into(),
entities: vec![],
relation: None,
direction: None,
causal_claim: None,
causal_evidence_grade: None,
},
Evidence {
evidence_type: "experimental".into(),
model_system: String::new(),
species: None,
method: String::new(),
sample_size: None,
effect_size: None,
p_value: None,
replicated: false,
replication_count: None,
evidence_spans: vec![],
},
Conditions::default_for_test(),
Confidence::raw(0.7, "test", 0.85),
Provenance::default_for_test(),
Flags::default(),
);
b.id = id.to_string();
b
}
// Link of the given type pointing at `target` (the causal parent).
fn link(target: &str, kind: &str) -> Link {
Link {
target: target.into(),
link_type: kind.into(),
note: String::new(),
inferred_by: "test".into(),
created_at: String::new(),
mechanism: None,
}
}
// Test-only inherent constructor: an all-empty Conditions value.
impl Conditions {
fn default_for_test() -> Self {
Self {
text: String::new(),
species_verified: vec![],
species_unverified: vec![],
in_vitro: false,
in_vivo: false,
human_data: false,
clinical_trial: false,
concentration_range: None,
duration: None,
age_group: None,
cell_type: None,
}
}
}
// Test-only inherent constructor: minimal plausible provenance.
impl Provenance {
fn default_for_test() -> Self {
Self {
source_type: "published_paper".into(),
doi: None,
pmid: None,
pmc: None,
openalex_id: None,
url: None,
title: "Test".into(),
authors: vec![],
year: Some(2025),
journal: None,
license: None,
publisher: None,
funders: vec![],
extraction: Extraction::default(),
review: None,
citation_count: None,
}
}
}
// Assemble a throwaway project from the given findings.
fn proj(findings: Vec<FindingBundle>) -> Project {
project::assemble("test", findings, 1, 0, "test")
}
// A -> B -> C is a pure chain: no back-door paths, so the effect of A
// on C is identified with an empty adjustment set.
#[test]
fn chain_a_to_c_identifiable_empty() {
let a = finding("vf_a");
let mut b = finding("vf_b");
b.links.push(link("vf_a", "depends"));
let mut c = finding("vf_c");
c.links.push(link("vf_b", "depends"));
let p = proj(vec![a, b, c]);
let v = identify_effect(&p, "vf_a", "vf_c");
match v {
CausalEffectVerdict::Identified { adjustment_set, .. } => {
assert!(
adjustment_set.is_empty(),
"chain should need no adjustment, got {adjustment_set:?}"
);
}
other => panic!("expected Identified for A→B→C, got {other:?}"),
}
}
// A <- Z -> B: Z confounds A and B; adjusting on Z must block the
// back-door path through it.
#[test]
fn confounder_requires_z_in_adjustment_set() {
let z = finding("vf_z");
let mut a = finding("vf_a");
a.links.push(link("vf_z", "depends"));
let mut b = finding("vf_b");
b.links.push(link("vf_z", "depends"));
let p = proj(vec![z, a, b]);
let v = identify_effect(&p, "vf_a", "vf_b");
match v {
CausalEffectVerdict::Identified { adjustment_set, .. } => {
assert_eq!(adjustment_set, vec!["vf_z"], "expected Z in adjustment set");
}
CausalEffectVerdict::NoCausalPath { .. } => {
// Also acceptable: there is no *directed* A -> B path here.
}
other => panic!("expected Identified or NoCausalPath, got {other:?}"),
}
}
// A -> M -> B: M is a mediator (a descendant of A) and therefore must
// never be adjusted on.
#[test]
fn mediator_not_in_adjustment_set() {
let a = finding("vf_a");
let mut m = finding("vf_m");
m.links.push(link("vf_a", "depends"));
let mut b = finding("vf_b");
b.links.push(link("vf_m", "depends"));
let p = proj(vec![a, m, b]);
let v = identify_effect(&p, "vf_a", "vf_b");
match v {
CausalEffectVerdict::Identified { adjustment_set, .. } => {
assert!(
!adjustment_set.contains(&"vf_m".to_string()),
"mediator must not be in adjustment set: {adjustment_set:?}"
);
}
other => panic!("expected Identified for A→M→B, got {other:?}"),
}
}
// A -> C <- B: C is a collider; conditioning on it would *open* the
// path, so it must never appear in an adjustment set.
#[test]
fn collider_not_used_as_confounder() {
let a = finding("vf_a");
let b = finding("vf_b");
let mut c = finding("vf_c");
c.links.push(link("vf_a", "depends"));
c.links.push(link("vf_b", "depends"));
let p = proj(vec![a, b, c]);
let v = identify_effect(&p, "vf_a", "vf_b");
match v {
CausalEffectVerdict::Identified { adjustment_set, .. } => {
assert!(
!adjustment_set.contains(&"vf_c".to_string()),
"collider must not be in adjustment set: {adjustment_set:?}"
);
}
CausalEffectVerdict::NoCausalPath { .. } => {
// Also acceptable: no directed A -> B path exists here.
}
other => panic!("expected Identified or NoCausalPath, got {other:?}"),
}
}
// A missing source id yields UnknownNode rather than a panic.
#[test]
fn unknown_node_reported() {
let a = finding("vf_a");
let p = proj(vec![a]);
let v = identify_effect(&p, "vf_missing", "vf_a");
assert!(matches!(v, CausalEffectVerdict::UnknownNode { .. }));
}
// Two findings and one "depends" link produce 2 nodes and a single
// a -> b edge, visible from both endpoints.
#[test]
fn graph_basic_construction() {
let a = finding("vf_a");
let mut b = finding("vf_b");
b.links.push(link("vf_a", "depends"));
let p = proj(vec![a, b]);
let g = CausalGraph::from_project(&p);
assert_eq!(g.node_count(), 2);
assert_eq!(g.edge_count(), 1);
assert!(g.parents_of("vf_b").any(|p| p == "vf_a"));
assert!(g.children_of("vf_a").any(|c| c == "vf_b"));
}
// X -> M -> Y: either the back-door criterion (trivially, no back-door
// paths) or the front-door mediator M identifies the effect.
#[test]
fn front_door_when_confounder_unobserved() {
let x = finding("vf_x");
let mut m = finding("vf_m");
m.links.push(link("vf_x", "depends"));
let mut y = finding("vf_y");
y.links.push(link("vf_m", "depends"));
let p = proj(vec![x, m, y]);
let v = identify_effect(&p, "vf_x", "vf_y");
match v {
CausalEffectVerdict::Identified { .. }
| CausalEffectVerdict::IdentifiedByFrontDoor { .. } => {}
other => panic!("expected identified or front-door, got {other:?}"),
}
}
// descendants() is transitive: from A in A -> B -> C it reaches both
// B and C.
#[test]
fn descendants_transitive() {
let a = finding("vf_a");
let mut b = finding("vf_b");
b.links.push(link("vf_a", "depends"));
let mut c = finding("vf_c");
c.links.push(link("vf_b", "depends"));
let p = proj(vec![a, b, c]);
let g = CausalGraph::from_project(&p);
let desc = g.descendants("vf_a");
assert!(desc.contains("vf_b"));
assert!(desc.contains("vf_c"));
}
}