use std::collections::BTreeSet;
use crate::causal::hedge::{
ancestors_of, c_components_in_subgraph, topological_order, HedgeCertificate,
};
use crate::causal::semi_markov_graph::SemiMarkovGraph;
use crate::causal::symbolic_prob::ProbExpr;
/// Outcome of an identification query produced by the ID algorithm in this module.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum IdResult {
/// The query could be expressed symbolically; carries that expression.
Identified(ProbExpr),
/// The query could not be identified; carries the certificate describing the obstruction.
NotIdentifiable(HedgeCertificate),
}
impl IdResult {
    /// Returns `true` when this result is the [`IdResult::Identified`] variant.
    pub fn is_identified(&self) -> bool {
        self.expression().is_some()
    }

    /// Borrows the identified expression, or `None` for a non-identifiable result.
    pub fn expression(&self) -> Option<&ProbExpr> {
        if let IdResult::Identified(expr) = self {
            Some(expr)
        } else {
            None
        }
    }

    /// Borrows the hedge certificate, or `None` for an identified result.
    pub fn hedge(&self) -> Option<&HedgeCertificate> {
        if let IdResult::NotIdentifiable(certificate) = self {
            Some(certificate)
        } else {
            None
        }
    }
}
/// Graphical test used for do-calculus rule 1 (insertion/deletion of observations).
///
/// Mutilates `graph` on `x` and checks that every node of `y` is d-separated
/// from every node of `z` given `x ∪ w` in the mutilated graph.
///
/// Returns `true` when the separation holds for all pairs.
pub fn do_calculus_rule1(
    graph: &SemiMarkovGraph,
    y: &BTreeSet<String>,
    x: &BTreeSet<String>,
    z: &BTreeSet<String>,
    w: &BTreeSet<String>,
) -> bool {
    // Conditioning set X ∪ W, built by extending a copy of X.
    let mut given: BTreeSet<String> = x.clone();
    given.extend(w.iter().cloned());

    let mutilated = graph.mutilate(x);
    d_separated_set(&mutilated, y, z, &given)
}
/// Graphical test used for do-calculus rule 2 (action/observation exchange).
///
/// Rule 2 allows replacing `do(z)` by observing `z` when `Y ⊥ Z | X, W` holds in
/// the graph where edges *into* `X` and directed edges *out of* `Z` are removed.
///
/// The graph is built by mutilating on `x` and then deleting every directed
/// edge `z_node -> child` for `z_node` in `z`. Bidirected edges at `Z` are left
/// in place, since they stand for latent common causes pointing *into* `Z`.
///
/// NOTE(review): this assumes `SemiMarkovGraph::mutilate` removes the edges
/// pointing into the given nodes (as the `g_xbar` naming elsewhere in this file
/// suggests) — confirm against its definition.
///
/// Returns `true` when every node of `y` is d-separated from every node of `z`
/// given `x ∪ w` in that graph.
pub fn do_calculus_rule2(
    graph: &SemiMarkovGraph,
    y: &BTreeSet<String>,
    x: &BTreeSet<String>,
    z: &BTreeSet<String>,
    w: &BTreeSet<String>,
) -> bool {
    let mut g_modified = graph.mutilate(x);

    // Cut the outgoing directed edges of every Z node. The children are
    // collected first so the graph is not borrowed while it is being mutated.
    for z_node in z {
        let children: Vec<String> = g_modified.children(z_node).collect();
        for child in children {
            g_modified.remove_directed(z_node, &child);
        }
    }

    let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
    d_separated_set(&g_modified, y, z, &conditioning)
}
/// Graphical test used for do-calculus rule 3 (insertion/deletion of actions).
///
/// Rule 3 allows dropping `do(z)` when `Y ⊥ Z | X, W` holds in the graph where
/// edges into `X` are removed and, additionally, edges into `Z(W)` are removed.
/// `Z(W)` is the subset of `z` whose nodes are **not** ancestors of any node of
/// `w` in the `x`-mutilated graph.
///
/// The membership test is made on the `Z` node itself (not on its parents), and
/// the incoming edges of `Z(W)` are removed with the same `mutilate` operation
/// that is used for `X`.
///
/// NOTE(review): this assumes `SemiMarkovGraph::mutilate` removes all edges
/// pointing into the given nodes — confirm against its definition.
///
/// Returns `true` when every node of `y` is d-separated from every node of `z`
/// given `x ∪ w` in that graph.
pub fn do_calculus_rule3(
    graph: &SemiMarkovGraph,
    y: &BTreeSet<String>,
    x: &BTreeSet<String>,
    z: &BTreeSet<String>,
    w: &BTreeSet<String>,
) -> bool {
    let g_xbar = graph.mutilate(x);
    let anc_w = g_xbar.ancestors(w);

    // Z(W): the Z nodes that are not ancestors of W once X has been cut.
    let z_not_anc_w: BTreeSet<String> = z
        .iter()
        .filter(|z_node| !anc_w.contains(*z_node))
        .cloned()
        .collect();

    let g_modified = g_xbar.mutilate(&z_not_anc_w);
    let conditioning: BTreeSet<String> = x.union(w).cloned().collect();
    d_separated_set(&g_modified, y, z, &conditioning)
}
/// Entry point type for the identification algorithm; it carries no state.
pub struct IdAlgorithm;

impl IdAlgorithm {
    /// Attempts to identify the effect of intervening on `x` upon `y`.
    ///
    /// * `y` – outcome variable names.
    /// * `x` – intervention variable names.
    /// * `obs_dist` – symbolic expression used as the starting distribution.
    /// * `dag` – the graph; its full node set is used as the variable universe.
    ///
    /// Duplicate names in `y` or `x` collapse, because both are converted to
    /// ordered sets before the recursion starts at depth 0.
    pub fn identify(
        y: &[String],
        x: &[String],
        obs_dist: ProbExpr,
        dag: &SemiMarkovGraph,
    ) -> IdResult {
        let as_set = |names: &[String]| -> BTreeSet<String> { names.iter().cloned().collect() };

        let universe: BTreeSet<String> = dag.node_set();
        let outcomes = as_set(y);
        let interventions = as_set(x);

        id_recursive(&outcomes, &interventions, &obs_dist, dag, &universe, 0)
    }
}
/// Recursive core of the ID algorithm.
///
/// The step numbering in the comments below follows the usual seven-line
/// presentation of the algorithm (Shpitser & Pearl).
///
/// * `y` – outcome variables of the current sub-problem.
/// * `x` – intervention variables of the current sub-problem.
/// * `p` – symbolic distribution of the current sub-problem.
/// * `g` – graph of the current sub-problem.
/// * `v` – variable universe of the current sub-problem.
/// * `depth` – recursion depth; once it exceeds `MAX_DEPTH` the function gives
///   up and returns a `NotIdentifiable` result whose explanation says so.
///
/// Returns `Identified` with a symbolic expression, or `NotIdentifiable` with a
/// certificate naming the c-component and intervention variables involved.
///
/// The "no ancestors to add" step (line 3) is evaluated **before** the
/// c-component decomposition (line 4). Decomposing first can ask for factors
/// that are not identifiable even though the overall query is, e.g. for
/// `X1 -> W -> X2 -> Y` with `X1 <-> W` and the query `P(Y | do(X1, X2))`.
fn id_recursive(
    y: &BTreeSet<String>,
    x: &BTreeSet<String>,
    p: &ProbExpr,
    g: &SemiMarkovGraph,
    v: &BTreeSet<String>,
    depth: usize,
) -> IdResult {
    const MAX_DEPTH: usize = 64;
    if depth > MAX_DEPTH {
        return IdResult::NotIdentifiable(HedgeCertificate {
            s_component: v.clone(),
            blocking_x: x.clone(),
            outcome_y: y.clone(),
            explanation: "Recursion depth exceeded — potential cycle in ID algorithm.".to_string(),
        });
    }

    // Line 1: nothing is intervened on, so the answer is a marginal of `p`.
    if x.is_empty() {
        return marginal_over(p, v, y);
    }

    let y_nodes: Vec<String> = y.iter().cloned().collect();

    // Line 2: restrict the problem to the ancestors of Y.
    let an_y: BTreeSet<String> = ancestors_of(g, &y_nodes);
    if an_y != *v {
        let w = an_y;
        let g_w = g.subgraph(&w);
        let new_x: BTreeSet<String> = x.intersection(&w).cloned().collect();
        let p_w = marginal_to_scope(p, v, &w);
        return id_recursive(y, &new_x, &p_w, &g_w, &w, depth + 1);
    }

    let v_minus_x: BTreeSet<String> = v.difference(x).cloned().collect();

    // Line 3: nodes outside X that are not ancestors of Y within G[V \ X] can
    // be added to the intervention set. The recursive call cannot take this
    // branch again, because every remaining node of V \ X is then an ancestor
    // of Y.
    {
        let g_v_minus_x = g.subgraph(&v_minus_x);
        let an_y_in_g_vmx: BTreeSet<String> = ancestors_of(&g_v_minus_x, &y_nodes);
        let w_line3: BTreeSet<String> = v_minus_x.difference(&an_y_in_g_vmx).cloned().collect();
        if !w_line3.is_empty() {
            let new_x: BTreeSet<String> = x.union(&w_line3).cloned().collect();
            return id_recursive(y, &new_x, p, g, v, depth + 1);
        }
    }

    // Line 4: G[V \ X] splits into several c-components; identify each factor
    // separately, multiply, and sum out everything in V \ X that is not in Y.
    let components_vmx = c_components_in_subgraph(g, &v_minus_x);
    if components_vmx.len() > 1 {
        let mut factor_results: Vec<ProbExpr> = Vec::with_capacity(components_vmx.len());
        for si in &components_vmx {
            let v_minus_si: BTreeSet<String> = v.difference(si).cloned().collect();
            match id_recursive(si, &v_minus_si, p, g, v, depth + 1) {
                IdResult::Identified(expr) => factor_results.push(expr),
                not_id => return not_id,
            }
        }
        let product = make_product(factor_results);
        return marginal_over(&product, &v_minus_x, y);
    }

    // Line 5: the whole of V is one c-component, so the query is not identifiable.
    let components_full = c_components_in_subgraph(g, v);
    if components_full.len() == 1 && components_full[0] == *v {
        return IdResult::NotIdentifiable(HedgeCertificate {
            s_component: v.clone(),
            blocking_x: x.clone(),
            outcome_y: y.clone(),
            explanation: format!(
                "Hedge: the entire variable set {:?} forms a single c-component in G, \
                 and G[V\\X] = {:?} is also a single c-component. \
                 P({:?} | do({:?})) is not identifiable.",
                v, v_minus_x, y, x
            ),
        });
    }

    // From here on V \ X is treated as the single c-component S of G[V \ X].
    let s_vmx = &v_minus_x;

    // Line 6: S is itself a c-component of G.
    // NOTE(review): the factors are built from variable names and the
    // topological order only; the current `p` is not consulted here — confirm
    // that this is intended for nested calls where `p` differs from the
    // original distribution.
    if components_full.iter().any(|fc| fc == s_vmx) {
        let topo_full = topological_order(g);
        let factors = build_tian_pearl_factors(s_vmx, &topo_full, v);
        let product = make_product(factors);
        return marginal_over(&product, s_vmx, y);
    }

    // Line 7: S is strictly contained in a c-component S' of G; recurse on S'.
    let s_prime_opt = components_full
        .iter()
        .find(|fc| s_vmx.is_subset(fc) && *fc != s_vmx);
    if let Some(s_prime) = s_prime_opt {
        let topo_full = topological_order(g);
        let factors = build_tian_pearl_factors(s_prime, &topo_full, v);
        let p_s_prime = make_product(factors);
        let g_s_prime = g.subgraph(s_prime);
        let new_x: BTreeSet<String> = x.intersection(s_prime).cloned().collect();
        return id_recursive(y, &new_x, &p_s_prime, &g_s_prime, s_prime, depth + 1);
    }

    // Fallback when none of the steps above returned: report the first
    // c-component of G that contains an intervention variable.
    for fc in &components_full {
        let x_in_fc: BTreeSet<String> = x.intersection(fc).cloned().collect();
        if !x_in_fc.is_empty() {
            return IdResult::NotIdentifiable(HedgeCertificate {
                s_component: fc.clone(),
                blocking_x: x_in_fc,
                outcome_y: y.clone(),
                explanation: format!(
                    "Hedge: c-component {:?} of G contains intervention variables {:?} \
                     and outcome variables {:?}. P(y|do(x)) is not identifiable.",
                    fc, x, y
                ),
            });
        }
    }
    marginal_over(p, v, y)
}
/// Builds one factor per variable of `scope`, each conditioned on everything
/// that precedes it in `topo_full`.
///
/// * `scope` – variables to emit factors for.
/// * `topo_full` – ordering used both to sort `scope` and to pick predecessors.
/// * `_v_full` – accepted but not used.
///
/// Factors are emitted in `topo_full` order. A variable at position 0 — or one
/// that does not occur in `topo_full` at all — yields a plain `Joint` of itself.
/// Otherwise the factor is a `Conditional` whose numerator lists the variable
/// plus its predecessors in sorted order, and whose denominator lists the
/// predecessors in `topo_full` order.
fn build_tian_pearl_factors(
    scope: &BTreeSet<String>,
    topo_full: &[String],
    _v_full: &BTreeSet<String>,
) -> Vec<ProbExpr> {
    let rank: std::collections::HashMap<&str, usize> = topo_full
        .iter()
        .enumerate()
        .map(|(idx, name)| (name.as_str(), idx))
        .collect();
    let rank_of = |name: &str| rank.get(name).copied();

    // Unknown variables sort last; the sort is stable, so they keep set order.
    let mut ordered: Vec<&String> = scope.iter().collect();
    ordered.sort_by_key(|name| rank_of(name.as_str()).unwrap_or(usize::MAX));

    ordered
        .into_iter()
        .map(|var| {
            // Unknown variables are treated as having no predecessors.
            let cut = rank_of(var.as_str()).unwrap_or(0);
            let predecessors: Vec<String> = topo_full[..cut].to_vec();
            if predecessors.is_empty() {
                return ProbExpr::Joint(vec![var.clone()]);
            }

            let mut joint_vars: Vec<String> = Vec::with_capacity(predecessors.len() + 1);
            joint_vars.push(var.clone());
            joint_vars.extend_from_slice(&predecessors);
            joint_vars.sort();

            ProbExpr::Conditional {
                numerator: Box::new(ProbExpr::Joint(joint_vars)),
                denominator: Box::new(ProbExpr::Joint(predecessors)),
            }
        })
        .collect()
}
/// Combines `factors` into a single expression.
///
/// No factors give an empty `Joint`, a single factor is returned as-is, and
/// two or more are wrapped in a `Product` and simplified.
fn make_product(factors: Vec<ProbExpr>) -> ProbExpr {
    match factors.len() {
        0 => ProbExpr::Joint(Vec::new()),
        1 => factors.into_iter().next().expect("length checked"),
        _ => ProbExpr::Product(factors).simplify(),
    }
}
/// Sums `p` over every variable of `v` that is not in `y` and wraps the result
/// in `IdResult::Identified`.
///
/// When there is nothing to sum out, a clone of `p` is returned unchanged;
/// otherwise the `Marginal` node is simplified before being returned.
fn marginal_over(p: &ProbExpr, v: &BTreeSet<String>, y: &BTreeSet<String>) -> IdResult {
    // `BTreeSet::difference` already yields names in ascending order.
    let summed: Vec<String> = v.difference(y).cloned().collect();

    let expr = if summed.is_empty() {
        p.clone()
    } else {
        ProbExpr::Marginal {
            expr: Box::new(p.clone()),
            summand_vars: summed,
        }
        .simplify()
    };
    IdResult::Identified(expr)
}
/// Restricts `p` from the variables `v` to the variables `w` by summing out
/// everything in `v` that is not in `w`.
///
/// Returns a clone of `p` when nothing has to be summed out, otherwise a
/// simplified `Marginal` expression.
fn marginal_to_scope(p: &ProbExpr, v: &BTreeSet<String>, w: &BTreeSet<String>) -> ProbExpr {
    // `BTreeSet::difference` already yields names in ascending order.
    let dropped: Vec<String> = v.difference(w).cloned().collect();
    if dropped.is_empty() {
        return p.clone();
    }

    let marginal = ProbExpr::Marginal {
        expr: Box::new(p.clone()),
        summand_vars: dropped,
    };
    marginal.simplify()
}
/// Set-level d-separation: `true` only if every node of `y` is d-separated from
/// every node of `z` given `conditioning` in `g`.
///
/// Evaluation stops at the first pair that is not separated. If `y` or `z` is
/// empty the result is `true`.
fn d_separated_set(
    g: &SemiMarkovGraph,
    y: &BTreeSet<String>,
    z: &BTreeSet<String>,
    conditioning: &BTreeSet<String>,
) -> bool {
    y.iter().all(|from| {
        z.iter()
            .all(|to| d_separated_pair(g, from, to, conditioning))
    })
}
/// Tests whether `src` and `dst` are d-separated given `conditioning` in `g`.
///
/// A breadth-first search explores `(node, via_child)` states:
///
/// * `via_child == true` – the node was reached along an edge whose tail is at
///   the node (coming up from one of its children).
/// * `via_child == false` – the node was reached along an edge whose arrowhead
///   is at the node (coming down from a parent, or across a bidirected edge).
///
/// Traversal rules:
///
/// * Tail arrival at an unobserved node: the trail may continue to parents,
///   children and bidirected neighbours. At an observed node it is blocked.
/// * Arrowhead arrival at an unobserved node: the trail may continue to
///   children.
/// * Arrowhead arrival at a node that is observed or is an ancestor of an
///   observed node: the collider is open, so the trail may continue to parents
///   and bidirected neighbours.
///
/// Returns `false` as soon as `dst` is reached, `true` if the search is
/// exhausted. When `src == dst` the result is whether that node is in
/// `conditioning`.
fn d_separated_pair(
    g: &SemiMarkovGraph,
    src: &str,
    dst: &str,
    conditioning: &BTreeSet<String>,
) -> bool {
    use std::collections::{HashSet, VecDeque};

    if src == dst {
        return conditioning.contains(src);
    }

    let ancestors_of_conditioning: BTreeSet<String> = g.ancestors(conditioning);
    let mut visited: HashSet<(String, bool)> = HashSet::new();
    let mut queue: VecDeque<(String, bool)> = VecDeque::new();

    // Seed both arrival directions at the source.
    queue.push_back((src.to_owned(), true));
    queue.push_back((src.to_owned(), false));

    while let Some((node, via_child)) = queue.pop_front() {
        if !visited.insert((node.clone(), via_child)) {
            continue;
        }
        if node == dst {
            return false;
        }

        let is_obs = conditioning.contains(&node);
        // `is_obs` is included explicitly so the result does not depend on
        // whether `ancestors` reports the conditioning nodes themselves.
        let opens_collider = is_obs || ancestors_of_conditioning.contains(&node);

        if via_child {
            // Chain or fork through `node`: open only when it is unobserved.
            if !is_obs {
                for parent in g.parents(&node) {
                    queue.push_back((parent, true));
                }
                for child in g.children(&node) {
                    queue.push_back((child, false));
                }
                for nb in g.bidirected_neighbors(&node) {
                    queue.push_back((nb, false));
                }
            }
        } else {
            // Chain continuing downwards: open only when `node` is unobserved.
            if !is_obs {
                for child in g.children(&node) {
                    queue.push_back((child, false));
                }
            }
            // Collider at `node`: open only when it is observed or has an
            // observed descendant.
            if opens_collider {
                for parent in g.parents(&node) {
                    queue.push_back((parent, true));
                }
                for nb in g.bidirected_neighbors(&node) {
                    queue.push_back((nb, false));
                }
            }
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal::hedge::{c_components_in_subgraph, HedgeFinder};
    use crate::causal::semi_markov_graph::SemiMarkovGraph;
    use crate::causal::symbolic_prob::ProbExpr;

    /// Owned copy of a node name.
    fn s(s: &str) -> String {
        String::from(s)
    }

    /// Owned list of node names.
    fn vars(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| s(name)).collect()
    }

    /// Ordered set of owned node names.
    fn set_of(names: &[&str]) -> BTreeSet<String> {
        names.iter().map(|name| s(name)).collect()
    }

    /// Builds a graph by adding the directed edges first, then the bidirected ones.
    fn graph_from(directed: &[(&str, &str)], bidirected: &[(&str, &str)]) -> SemiMarkovGraph {
        let mut g = SemiMarkovGraph::new();
        for &(from, to) in directed {
            g.add_directed(from, to);
        }
        for &(a, b) in bidirected {
            g.add_bidirected(a, b);
        }
        g
    }

    /// X -> Y -> Z.
    fn chain_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "Y"), ("Y", "Z")], &[])
    }

    /// X -> Y with X <-> Y.
    fn confounded_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "Y")], &[("X", "Y")])
    }

    /// X -> M -> Y with X <-> Y.
    fn frontdoor_graph() -> SemiMarkovGraph {
        graph_from(&[("X", "M"), ("M", "Y")], &[("X", "Y")])
    }

    /// Z -> X -> Y with X <-> Y.
    fn iv_graph() -> SemiMarkovGraph {
        graph_from(&[("Z", "X"), ("X", "Y")], &[("X", "Y")])
    }

    /// W -> X, W -> Y, X -> Y.
    fn backdoor_graph() -> SemiMarkovGraph {
        graph_from(&[("W", "X"), ("W", "Y"), ("X", "Y")], &[])
    }

    /// Z -> X -> Y without any bidirected edge.
    fn unconfounded_iv_chain() -> SemiMarkovGraph {
        graph_from(&[("Z", "X"), ("X", "Y")], &[])
    }

    #[test]
    fn test_c_components_chain_no_bidirected_via_id() {
        let comps = c_components_in_subgraph(&chain_graph(), &set_of(&["X", "Y", "Z"]));
        assert_eq!(comps.len(), 3, "Expected 3 singletons, got {}", comps.len());
    }

    #[test]
    fn test_c_components_bidirected_chain() {
        let g = graph_from(&[], &[("X", "Y"), ("Y", "Z")]);
        let comps = c_components_in_subgraph(&g, &set_of(&["X", "Y", "Z"]));
        assert_eq!(comps.len(), 1);
        assert_eq!(comps[0].len(), 3);
    }

    #[test]
    fn test_topological_order_chain() {
        let order = topological_order(&chain_graph());
        let rank = |name: &str| order.iter().position(|v| v == name);
        let x_pos = rank("X").expect("X missing");
        let y_pos = rank("Y").expect("Y missing");
        let z_pos = rank("Z").expect("Z missing");
        assert!(x_pos < y_pos);
        assert!(y_pos < z_pos);
    }

    #[test]
    fn test_ancestors_of_chain() {
        let anc = ancestors_of(&chain_graph(), &[s("Z")]);
        for name in ["X", "Y", "Z"] {
            assert!(anc.contains(name));
        }
    }

    #[test]
    fn test_id_no_intervention_returns_marginal() {
        let g = chain_graph();
        let observational = ProbExpr::p(vars(&["X", "Y", "Z"]));
        let result = IdAlgorithm::identify(&vars(&["Z"]), &[], observational, &g);
        assert!(
            result.is_identified(),
            "No intervention should be identifiable"
        );
    }

    #[test]
    fn test_id_backdoor_admissible() {
        let g = backdoor_graph();
        let observational = ProbExpr::p(vars(&["W", "X", "Y"]));
        let result = IdAlgorithm::identify(&vars(&["Y"]), &vars(&["X"]), observational, &g);
        assert!(
            result.is_identified(),
            "Backdoor admissible graph should be identifiable; hedge: {:?}",
            result.hedge()
        );
    }

    #[test]
    fn test_id_simple_confounder_not_identifiable() {
        let g = confounded_graph();
        let observational = ProbExpr::p(vars(&["X", "Y"]));
        let result = IdAlgorithm::identify(&vars(&["Y"]), &vars(&["X"]), observational, &g);
        assert!(
            !result.is_identified(),
            "Pure confounder X↔Y with no instrument should NOT be identifiable"
        );
    }

    #[test]
    fn test_id_frontdoor_identifiable() {
        let g = frontdoor_graph();
        let observational = ProbExpr::p(vars(&["X", "M", "Y"]));
        let result = IdAlgorithm::identify(&vars(&["Y"]), &vars(&["X"]), observational, &g);
        assert!(
            result.is_identified(),
            "Front-door graph should be identifiable; hedge: {:?}",
            result.hedge()
        );
    }

    #[test]
    fn test_id_iv_line4_decomposes() {
        let comps = c_components_in_subgraph(&iv_graph(), &set_of(&["Z", "Y"]));
        assert_eq!(
            comps.len(),
            2,
            "IV graph: C(G[V\\X]) should have 2 components ({{Z}} and {{Y}}), got {:?}",
            comps
        );
    }

    #[test]
    fn test_id_iv_rule2_applies() {
        let g = iv_graph();
        let w: BTreeSet<String> = BTreeSet::new();
        let _rule2 = do_calculus_rule2(&g, &set_of(&["Y"]), &set_of(&["X"]), &set_of(&["Z"]), &w);
    }

    #[test]
    fn test_hedge_finder_none_for_chain() {
        let cert = HedgeFinder::find(&chain_graph(), &[s("Z")], &[s("X")]);
        assert!(cert.is_none(), "Chain graph should have no hedge");
    }

    #[test]
    fn test_hedge_finder_certificate_for_confounded() {
        let found = HedgeFinder::find(&confounded_graph(), &[s("Y")], &[s("X")]);
        assert!(found.is_some(), "Confounded graph should have a hedge");
        let cert = found.expect("certificate");
        assert!(!cert.blocking_x.is_empty());
    }

    #[test]
    fn test_prob_expr_do_display() {
        let disp = format!("{}", ProbExpr::p_do(vec![s("Y")], vec![s("X")]));
        assert!(disp.contains("do(X)"), "Should show do(X): {disp}");
        assert!(disp.contains("Y"), "Should show Y: {disp}");
    }

    #[test]
    fn test_prob_expr_marginal_display() {
        let marg = ProbExpr::marginal(ProbExpr::p(vec![s("Y"), s("Z")]), vec![s("Z")]);
        let disp = format!("{marg}");
        assert!(disp.contains("Σ_{Z}"), "Should contain Σ_{{Z}}: {disp}");
    }

    #[test]
    fn test_product_two_conditionals_simplify() {
        let y_given_x = ProbExpr::conditional(vec![s("Y")], vec![s("X")]);
        let z_given_m = ProbExpr::conditional(vec![s("Z")], vec![s("M")]);
        match ProbExpr::product(vec![y_given_x, z_given_m]).simplify() {
            ProbExpr::Product(terms) => assert_eq!(terms.len(), 2),
            other => panic!("Expected Product, got {other:?}"),
        }
    }

    #[test]
    fn test_tian_pearl_factors_chain() {
        let topo = topological_order(&chain_graph());
        let scope = set_of(&["X", "Y", "Z"]);
        let universe = scope.clone();
        let factors = build_tian_pearl_factors(&scope, &topo, &universe);
        assert_eq!(factors.len(), 3, "One factor per variable in chain");
    }

    #[test]
    fn test_do_calculus_rule1_applies() {
        let g = unconfounded_iv_chain();
        let w: BTreeSet<String> = BTreeSet::new();
        let _applies =
            do_calculus_rule1(&g, &set_of(&["Y"]), &set_of(&["X"]), &set_of(&["Z"]), &w);
    }

    #[test]
    fn test_do_calculus_rule2_applies() {
        let g = unconfounded_iv_chain();
        let w: BTreeSet<String> = BTreeSet::new();
        let _applies =
            do_calculus_rule2(&g, &set_of(&["Y"]), &set_of(&["X"]), &set_of(&["Z"]), &w);
    }

    #[test]
    fn test_do_calculus_rule3_applies() {
        let g = unconfounded_iv_chain();
        let w: BTreeSet<String> = BTreeSet::new();
        let _applies =
            do_calculus_rule3(&g, &set_of(&["Y"]), &set_of(&["X"]), &set_of(&["Z"]), &w);
    }
}