use std::collections::{BTreeSet, VecDeque};
use crate::causal::semi_markov_graph::SemiMarkovGraph;
/// Evidence produced by `HedgeFinder::find` when it judges a causal effect
/// to be non-identifiable.
#[derive(Debug, Clone)]
pub struct HedgeCertificate {
/// The c-component (set of variables linked through bidirected edges) that
/// triggered the report.
pub s_component: BTreeSet<String>,
/// The intervention variables that lie inside `s_component`.
pub blocking_x: BTreeSet<String>,
/// The outcome variables the query asked about.
pub outcome_y: BTreeSet<String>,
/// Human-readable account of why the component was reported.
pub explanation: String,
}
#[derive(Debug, Clone)]
pub struct HedgeError {
pub certificate: HedgeCertificate,
}
impl std::fmt::Display for HedgeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Non-identifiable: hedge found in c-component {:?} blocking effect of {:?} on {:?}. {}",
self.certificate.s_component,
self.certificate.blocking_x,
self.certificate.outcome_y,
self.certificate.explanation,
)
}
}
impl std::error::Error for HedgeError {}
/// Stateless namespace for the hedge search.
pub struct HedgeFinder;

impl HedgeFinder {
    /// Looks for a c-component that obstructs identifying the effect of `x` on `y`.
    ///
    /// The c-components of the whole graph are scanned in sorted order and the
    /// first one satisfying both conditions below is reported:
    ///
    /// 1. it contains at least one intervention variable from `x`;
    /// 2. it contains at least one member of `ancestors_of(graph, y)` (a set that
    ///    includes `y` itself) that is *not* in `x`.
    ///
    /// Together these imply the component has at least two members, so no
    /// separate size check is needed.
    ///
    /// Scanning in sorted order makes the returned certificate reproducible even
    /// when several components qualify, independent of the order in which
    /// `c_components_in_subgraph` yields them.
    ///
    /// Returns `None` when no component satisfies both conditions (including when
    /// `x` or `y` is empty in a way that makes a condition unsatisfiable).
    ///
    /// NOTE(review): this is a structural screen over the c-components of the full
    /// graph, not the complete ID-algorithm hedge search. For example a front-door
    /// graph (`X -> M -> Y` with `X <-> Y`) meets both conditions and is therefore
    /// reported, although that effect is identifiable. Confirm that callers treat
    /// `Some(_)` accordingly.
    pub fn find(graph: &SemiMarkovGraph, y: &[String], x: &[String]) -> Option<HedgeCertificate> {
        let x_set: BTreeSet<String> = x.iter().cloned().collect();
        let y_set: BTreeSet<String> = y.iter().cloned().collect();
        // `ancestors_of` seeds its result with `y`, so this set already covers Y.
        let anc_y = ancestors_of(graph, y);
        let all_vars: BTreeSet<String> = graph.nodes().cloned().collect();

        // Sort so that the choice among several qualifying components is stable.
        let mut components = c_components_in_subgraph(graph, &all_vars);
        components.sort();

        for comp in components {
            let x_intersects_s: BTreeSet<String> = comp.intersection(&x_set).cloned().collect();
            if x_intersects_s.is_empty() {
                continue;
            }
            let has_unblocked_anc = comp
                .iter()
                .any(|v| anc_y.contains(v) && !x_set.contains(v));
            if !has_unblocked_anc {
                continue;
            }
            let explanation = format!(
                "C-component {comp:?} (size {}) contains ancestors of Y and is intersected by \
                 intervention variables {x_intersects_s:?}. The bidirected confounders \
                 inside this component cannot be eliminated by do(X), hence P(y|do(x)) \
                 is not identifiable from the observational distribution.",
                comp.len()
            );
            return Some(HedgeCertificate {
                s_component: comp,
                blocking_x: x_intersects_s,
                outcome_y: y_set,
                explanation,
            });
        }
        None
    }
}
/// Collects every node in `y` together with all nodes reachable from them by
/// repeatedly following `graph.parents`.
///
/// The result is reflexive: each element of `y` is part of the returned set,
/// whether or not it has any parents. An empty `y` yields an empty set.
pub fn ancestors_of(graph: &SemiMarkovGraph, y: &[String]) -> BTreeSet<String> {
    let mut seen: BTreeSet<String> = BTreeSet::new();
    let mut frontier: VecDeque<String> = y.to_vec().into();

    while let Some(current) = frontier.pop_front() {
        // A node can be queued more than once; expand it only the first time.
        if seen.contains(&current) {
            continue;
        }
        let parents: Vec<String> = graph.parents(&current).into_iter().collect();
        seen.insert(current);
        frontier.extend(parents.into_iter().filter(|p| !seen.contains(p)));
    }

    seen
}
/// Produces a topological ordering of the graph's directed structure.
///
/// Nodes whose in-degree is zero are emitted first, in sorted order; whenever a
/// node is emitted its children are visited in sorted order and any child whose
/// remaining in-degree drops to zero is appended to the work queue.
///
/// Nodes that never reach in-degree zero (for instance members of a directed
/// cycle) are absent from the result, so the output may be shorter than the
/// number of nodes.
pub fn topological_order(graph: &SemiMarkovGraph) -> Vec<String> {
    let all_nodes: Vec<String> = graph.nodes().cloned().collect();

    // Remaining in-degree per node; every known node starts at zero.
    let mut remaining: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for name in &all_nodes {
        remaining.insert(name.clone(), 0);
    }
    for name in &all_nodes {
        for child in graph.children(name) {
            *remaining.entry(child).or_insert(0) += 1;
        }
    }

    // Seed the queue with the source nodes, sorted for a stable result.
    let mut roots: Vec<String> = all_nodes
        .iter()
        .filter(|name| remaining.get(*name).map_or(true, |&degree| degree == 0))
        .cloned()
        .collect();
    roots.sort();
    let mut ready: VecDeque<String> = VecDeque::from(roots);

    let mut order = Vec::with_capacity(all_nodes.len());
    while let Some(current) = ready.pop_front() {
        let mut successors: Vec<String> = graph.children(&current).collect();
        successors.sort();
        order.push(current);

        for successor in successors {
            let pending = remaining.entry(successor.clone()).or_insert(1);
            *pending = pending.saturating_sub(1);
            if *pending == 0 {
                ready.push_back(successor);
            }
        }
    }
    order
}
/// Partitions `vars` into c-components: maximal groups of variables connected
/// to one another through bidirected edges whose endpoints both lie in `vars`.
///
/// Bidirected neighbours that are not members of `vars` are ignored, so the
/// partition is that of the subgraph induced by `vars`. Variables without any
/// bidirected neighbour inside `vars` form singleton components.
///
/// The components are returned in a deterministic order: ascending by each
/// component's smallest member. An empty `vars` yields an empty vector.
pub fn c_components_in_subgraph(
    graph: &SemiMarkovGraph,
    vars: &BTreeSet<String>,
) -> Vec<BTreeSet<String>> {
    /// Returns the representative of `i`, compressing the walked path.
    /// Iterative so that long parent chains cannot exhaust the call stack.
    fn find(parent: &mut [usize], mut i: usize) -> usize {
        let mut root = i;
        while parent[root] != root {
            root = parent[root];
        }
        while parent[i] != root {
            let next = parent[i];
            parent[i] = root;
            i = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`.
    fn union(parent: &mut [usize], a: usize, b: usize) {
        let ra = find(parent, a);
        let rb = find(parent, b);
        if ra != rb {
            parent[ra] = rb;
        }
    }

    // `BTreeSet` iterates in sorted order, so index `i` is the rank of the
    // variable within `vars`.
    let var_index: std::collections::HashMap<&str, usize> = vars
        .iter()
        .enumerate()
        .map(|(i, v)| (v.as_str(), i))
        .collect();
    let mut parent: Vec<usize> = (0..vars.len()).collect();

    for (i, var) in vars.iter().enumerate() {
        for neighbor in graph.bidirected_neighbors(var) {
            // Presence in the index is equivalent to membership in `vars`.
            if let Some(&j) = var_index.get(neighbor.as_str()) {
                union(&mut parent, i, j);
            }
        }
    }

    // Assign each root an output slot the first time it is encountered. Because
    // variables are walked in sorted order, slots end up ordered by the smallest
    // member of each component.
    let mut slot_of_root: Vec<Option<usize>> = vec![None; vars.len()];
    let mut components: Vec<BTreeSet<String>> = Vec::new();
    for (i, var) in vars.iter().enumerate() {
        let root = find(&mut parent, i);
        let slot = *slot_of_root[root].get_or_insert_with(|| {
            components.push(BTreeSet::new());
            components.len() - 1
        });
        components[slot].insert(var.clone());
    }
    components
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::semi_markov_graph::SemiMarkovGraph;

    type Edge = (&'static str, &'static str);

    /// Assembles a fixture graph: nodes first, then directed edges, then
    /// bidirected edges.
    fn build_graph(
        nodes: &[&'static str],
        directed: &[Edge],
        bidirected: &[Edge],
    ) -> SemiMarkovGraph {
        let mut g = SemiMarkovGraph::new();
        for &name in nodes {
            g.add_node(name);
        }
        for &(from, to) in directed {
            g.add_directed(from, to);
        }
        for &(a, b) in bidirected {
            g.add_bidirected(a, b);
        }
        g
    }

    /// Builds an owned name set from string literals.
    fn names(items: &[&str]) -> BTreeSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// X -> Y -> Z with no confounding.
    fn chain_graph() -> SemiMarkovGraph {
        build_graph(&["X", "Y", "Z"], &[("X", "Y"), ("Y", "Z")], &[])
    }

    /// X -> Y with an additional X <-> Y confounder.
    fn confounded_graph() -> SemiMarkovGraph {
        build_graph(&["X", "Y"], &[("X", "Y")], &[("X", "Y")])
    }

    /// X <-> Y <-> Z with no directed edges.
    fn bidirected_chain_graph() -> SemiMarkovGraph {
        build_graph(&["X", "Y", "Z"], &[], &[("X", "Y"), ("Y", "Z")])
    }

    #[test]
    fn test_c_components_chain_no_bidirected() {
        let g = chain_graph();
        let vars = names(&["X", "Y", "Z"]);
        let comps = c_components_in_subgraph(&g, &vars);
        assert_eq!(
            comps.len(),
            3,
            "Expected 3 singleton components, got {}",
            comps.len()
        );
        for comp in &comps {
            assert_eq!(
                comp.len(),
                1,
                "Each component should be a singleton, got {:?}",
                comp
            );
        }
    }

    #[test]
    fn test_c_components_fully_bidirected() {
        let g = bidirected_chain_graph();
        let vars = names(&["X", "Y", "Z"]);
        let comps = c_components_in_subgraph(&g, &vars);
        assert_eq!(comps.len(), 1, "Expected 1 component, got {}", comps.len());
        let comp = &comps[0];
        assert_eq!(
            comp.len(),
            3,
            "Component should contain all 3 nodes, got {:?}",
            comp
        );
    }

    #[test]
    fn test_c_components_partial_bidirected() {
        let g = confounded_graph();
        let vars = names(&["X", "Y"]);
        let comps = c_components_in_subgraph(&g, &vars);
        assert_eq!(comps.len(), 1, "Expected 1 component, got {}", comps.len());
        assert!(comps[0].contains("X") && comps[0].contains("Y"));
    }

    #[test]
    fn test_ancestors_of_chain() {
        let g = chain_graph();
        let anc = ancestors_of(&g, &["Z".to_string()]);
        assert!(anc.contains("X"), "X should be an ancestor of Z");
        assert!(anc.contains("Y"), "Y should be an ancestor of Z");
        assert!(anc.contains("Z"), "Z should be included");
    }

    #[test]
    fn test_ancestors_of_root() {
        let g = chain_graph();
        let anc = ancestors_of(&g, &["X".to_string()]);
        assert_eq!(anc.len(), 1);
        assert!(anc.contains("X"));
    }

    #[test]
    fn test_topological_order_chain() {
        let g = chain_graph();
        let order = topological_order(&g);
        assert_eq!(order.len(), 3, "Order should have 3 elements");
        let x_pos = order.iter().position(|v| v == "X").expect("X not in order");
        let y_pos = order.iter().position(|v| v == "Y").expect("Y not in order");
        let z_pos = order.iter().position(|v| v == "Z").expect("Z not in order");
        assert!(x_pos < y_pos, "X must precede Y in topological order");
        assert!(y_pos < z_pos, "Y must precede Z in topological order");
    }

    #[test]
    fn test_hedge_finder_identifiable_chain() {
        let g = chain_graph();
        let cert = HedgeFinder::find(&g, &["Z".to_string()], &["X".to_string()]);
        assert!(
            cert.is_none(),
            "Chain with no bidirected edges should have no hedge"
        );
    }

    #[test]
    fn test_hedge_finder_confounded_not_identifiable() {
        let g = confounded_graph();
        let cert = HedgeFinder::find(&g, &["Y".to_string()], &["X".to_string()]);
        assert!(
            cert.is_some(),
            "Confounded graph X↔Y should produce a hedge certificate"
        );
        let cert = cert.expect("certificate");
        assert!(
            cert.s_component.contains("X") || cert.s_component.contains("Y"),
            "Hedge component should include X or Y: {:?}",
            cert.s_component
        );
        assert!(
            cert.blocking_x.contains("X"),
            "Blocking X should include X: {:?}",
            cert.blocking_x
        );
    }
}