use std::collections::{HashMap, HashSet};
use crate::causal_graph::dag::CausalDAG;
use crate::error::{StatsError, StatsResult};
/// Selects which rule of Pearl's do-calculus [`check_do_calculus_rule`] tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DoCalculusRule {
    /// Rule 1: insertion/deletion of observations.
    Rule1,
    /// Rule 2: exchange of an action and an observation.
    Rule2,
    /// Rule 3: insertion/deletion of actions.
    Rule3,
    /// No rule; [`check_do_calculus_rule`] returns `false` for this variant.
    None,
}
/// Outcome of [`find_backdoor_sets`].
#[derive(Debug, Clone)]
pub struct BackdoorResult {
    /// `true` if at least one backdoor-admissible set was found within the size limit.
    pub is_admissible: bool,
    /// The first admissible set found (subsets are tried by increasing size).
    /// Empty when nothing was found, but also when the empty set itself is
    /// admissible — check `is_admissible` to tell the two apart.
    pub adjustment_set: Vec<String>,
    /// Admissible sets collected during the search, in discovery order (at most 20).
    pub all_minimal_sets: Vec<Vec<String>>,
}
/// Outcome of [`find_frontdoor_set`].
#[derive(Debug, Clone)]
pub struct FrontdoorResult {
    /// `true` if a mediator set satisfying the frontdoor criterion was found.
    pub is_applicable: bool,
    /// Names of the mediators found; empty when not applicable.
    pub mediator_set: Vec<String>,
    /// Human-readable adjustment formula, or `"Not identifiable via frontdoor"`
    /// when not applicable.
    pub formula: String,
}
/// Outcome of [`id_algorithm`] and [`tian_pearl_id`].
#[derive(Debug, Clone)]
pub struct IdResult {
    /// Whether the interventional distribution was identified.
    pub identifiable: bool,
    /// Human-readable identifying expression (empty when `id_algorithm` gives up).
    pub expression: String,
    /// Which strategy produced the result, or why identification failed.
    pub explanation: String,
}
/// A confounded component (c-component) as produced by [`c_components_with_hidden`].
#[derive(Debug, Clone)]
pub struct CComponent {
    /// Node indices (as used by `CausalDAG::node_index`) linked through bidirected edges.
    pub nodes: HashSet<usize>,
}
/// Checks whether the graphical precondition of one of Pearl's do-calculus
/// rules holds in `dag`.
///
/// Notation: `G_{X̄}` is `dag` with every edge *into* `X` removed, and `G_{Z̲}`
/// is `dag` with every edge *out of* `Z` removed. Each rule reduces to the
/// pairwise d-separation test `(Y ⊥ Z | X ∪ W)` in a mutilated graph:
///
/// * [`DoCalculusRule::Rule1`] (insert/delete observations):
///   `P(y | do(x), z, w) = P(y | do(x), w)` if it holds in `G_{X̄}`.
/// * [`DoCalculusRule::Rule2`] (exchange action and observation):
///   `P(y | do(x), do(z), w) = P(y | do(x), z, w)` if it holds in `G_{X̄ Z̲}`.
/// * [`DoCalculusRule::Rule3`] (insert/delete actions):
///   `P(y | do(x), do(z), w) = P(y | do(x), w)` if it holds in `G_{X̄ Z(W)̄}`,
///   where `Z(W)` are the `Z` nodes that are not ancestors of any `W` node in
///   `G_{X̄}`.
/// * [`DoCalculusRule::None`] always yields `false`.
///
/// Names that `dag` does not know are ignored by the edge-removal helpers.
pub fn check_do_calculus_rule(
    dag: &CausalDAG,
    y: &[&str],
    x: &[&str],
    z: &[&str],
    w: &[&str],
    rule: DoCalculusRule,
) -> bool {
    // All three rules condition on X ∪ W.
    let conditioning: Vec<&str> = x.iter().chain(w).copied().collect();
    match rule {
        DoCalculusRule::Rule1 => {
            let mut g_xbar = dag.clone();
            remove_incoming_edges(&mut g_xbar, x);
            check_d_separation_all(&g_xbar, y, z, &conditioning)
        }
        DoCalculusRule::Rule2 => {
            // G_{X̄ Z̲}: cut edges into X and edges *out of* Z, so the only
            // remaining Z–Y connections are backdoor paths from Z.
            let mut g = dag.clone();
            remove_incoming_edges(&mut g, x);
            remove_outgoing_edges(&mut g, z);
            check_d_separation_all(&g, y, z, &conditioning)
        }
        DoCalculusRule::Rule3 => {
            let mut g = dag.clone();
            remove_incoming_edges(&mut g, x);
            // Z(W) is computed in G_{X̄}, before any Z edges are touched.
            // NOTE(review): whether a node counts as its own ancestor depends on
            // `CausalDAG::ancestors` — confirm for Z ∩ W overlaps.
            let w_ancestors = ancestors_of_names(&g, w);
            let z_not_w_anc: Vec<&str> = z
                .iter()
                .copied()
                .filter(|&zz| !matches!(dag.node_index(zz), Some(i) if w_ancestors.contains(&i)))
                .collect();
            // G_{X̄ Z(W)̄}: cut edges *into* Z(W).
            remove_incoming_edges(&mut g, &z_not_w_anc);
            check_d_separation_all(&g, y, z, &conditioning)
        }
        DoCalculusRule::None => false,
    }
}
/// Returns `true` when `z_set` satisfies Pearl's backdoor criterion for the
/// ordered pair `(x, y)`:
///
/// 1. no member of `z_set` is a descendant of `x` (per `CausalDAG::descendants`), and
/// 2. `z_set` d-separates `x` from `y` in the graph where every edge leaving
///    `x` has been deleted, i.e. it blocks every path that enters `x`.
///
/// Names in `z_set` that `dag` does not know are skipped in step 1.
pub fn satisfies_backdoor(dag: &CausalDAG, x: &str, y: &str, z_set: &[&str]) -> bool {
    let descendants = dag.descendants(x);
    let hits_descendant = z_set
        .iter()
        .filter_map(|&name| dag.node_index(name))
        .any(|idx| descendants.contains(&idx));
    if hits_descendant {
        return false;
    }
    let mut x_underbar = dag.clone();
    remove_outgoing_edges(&mut x_underbar, &[x]);
    x_underbar.is_d_separated(x, y, z_set)
}
/// Enumerates adjustment sets satisfying the backdoor criterion for `x → y`.
///
/// Candidates are all nodes other than `x` and `y` that are not descendants of
/// `x`. Subsets are tried in order of increasing size (up to `max_set_size`),
/// so `adjustment_set` is a smallest admissible set.
///
/// A subset that contains an already-accepted set is skipped. Every proper
/// subset of a candidate is examined before the candidate itself, so
/// `all_minimal_sets` holds exactly the inclusion-minimal admissible sets found
/// within the size limit (at most 20 of them).
///
/// The number of [`satisfies_backdoor`] checks grows combinatorially with
/// `max_set_size`.
pub fn find_backdoor_sets(
    dag: &CausalDAG,
    x: &str,
    y: &str,
    max_set_size: usize,
) -> BackdoorResult {
    const MAX_REPORTED_SETS: usize = 20;

    let desc_x = dag.descendants(x);
    let xi = dag.node_index(x);
    let yi = dag.node_index(y);
    let candidates: Vec<usize> = (0..dag.n_nodes())
        .filter(|&i| Some(i) != xi && Some(i) != yi && !desc_x.contains(&i))
        .collect();

    // Index form of every accepted set, used for the superset check.
    let mut accepted_idx: Vec<Vec<usize>> = Vec::new();
    let mut all_minimal: Vec<Vec<String>> = Vec::new();
    'outer: for size in 0..=max_set_size.min(candidates.len()) {
        for subset in subsets(&candidates, size) {
            // A superset of an admissible set is not minimal.
            let is_superset = accepted_idx
                .iter()
                .any(|found| found.iter().all(|i| subset.contains(i)));
            if is_superset {
                continue;
            }
            let z_names: Vec<&str> = subset.iter().filter_map(|&i| dag.node_name(i)).collect();
            if satisfies_backdoor(dag, x, y, &z_names) {
                all_minimal.push(z_names.iter().map(|s| s.to_string()).collect());
                accepted_idx.push(subset);
                if all_minimal.len() >= MAX_REPORTED_SETS {
                    break 'outer;
                }
            }
        }
    }

    BackdoorResult {
        is_admissible: !all_minimal.is_empty(),
        adjustment_set: all_minimal.first().cloned().unwrap_or_default(),
        all_minimal_sets: all_minimal,
    }
}
/// Returns `true` when `m_set` satisfies Pearl's frontdoor criterion for `x → y`:
///
/// 1. `m_set` intercepts every directed path from `x` to `y`;
/// 2. there is no unblocked backdoor path from `x` to any mediator. Equivalently,
///    `x` is d-separated from each mediator given `∅` once the edges leaving
///    `x` are removed.
/// 3. every backdoor path from each mediator to `y` is blocked by `{x}`.
pub fn satisfies_frontdoor(dag: &CausalDAG, x: &str, y: &str, m_set: &[&str]) -> bool {
    if !intercepts_all_paths(dag, x, y, m_set) {
        return false;
    }
    // Only X's *outgoing* edges are cut. Also cutting its incoming edges would
    // isolate X and make this check vacuously true.
    let mut g_x_underbar = dag.clone();
    remove_outgoing_edges(&mut g_x_underbar, &[x]);
    if !m_set
        .iter()
        .all(|&m| g_x_underbar.is_d_separated(x, m, &[]))
    {
        return false;
    }
    m_set.iter().all(|&m| satisfies_backdoor(dag, m, y, &[x]))
}
/// Searches for a mediator set satisfying the frontdoor criterion for `x → y`.
///
/// Candidates are the descendants of `x` other than `x` and `y`. They are sorted
/// by node index before enumeration, so the reported set does not depend on the
/// iteration order of whatever collection `CausalDAG::descendants` returns.
/// Subsets are tried by increasing size. The first one accepted by
/// [`satisfies_frontdoor`] is returned together with its [`frontdoor_formula`].
///
/// Up to `2^k` subsets are checked for `k` candidates.
pub fn find_frontdoor_set(dag: &CausalDAG, x: &str, y: &str) -> FrontdoorResult {
    let xi = dag.node_index(x);
    let yi = dag.node_index(y);
    let mut candidates: Vec<usize> = dag
        .descendants(x)
        .iter()
        .copied()
        .filter(|&i| Some(i) != xi && Some(i) != yi)
        .collect();
    candidates.sort_unstable();

    for size in 1..=candidates.len() {
        for subset in subsets(&candidates, size) {
            let m_names: Vec<&str> = subset.iter().filter_map(|&i| dag.node_name(i)).collect();
            if satisfies_frontdoor(dag, x, y, &m_names) {
                return FrontdoorResult {
                    is_applicable: true,
                    formula: frontdoor_formula(x, y, &m_names),
                    mediator_set: m_names.iter().map(|s| s.to_string()).collect(),
                };
            }
        }
    }
    FrontdoorResult {
        is_applicable: false,
        mediator_set: Vec::new(),
        formula: "Not identifiable via frontdoor".to_owned(),
    }
}
/// Tries to identify the interventional distribution `P(y | do(x))` in `dag`.
///
/// Strategies, in order:
/// 1. With no intervention the answer is the observational `P(y)`.
/// 2. For exactly one treatment and one outcome, try these in turn:
///    - the empty backdoor set;
///    - a backdoor set of size ≤ 5 found by [`find_backdoor_sets`];
///    - a frontdoor mediator set found by [`find_frontdoor_set`].
/// 3. Otherwise, or if step 2 fails, fall back to [`tian_pearl_id`].
///
/// `identifiable` is `false` only when every strategy fails.
pub fn id_algorithm(dag: &CausalDAG, y: &[&str], x: &[&str]) -> IdResult {
    /// Single-treatment / single-outcome shortcuts: backdoor first, then frontdoor.
    fn identify_single(dag: &CausalDAG, xv: &str, yv: &str) -> Option<IdResult> {
        if satisfies_backdoor(dag, xv, yv, &[]) {
            return Some(IdResult {
                identifiable: true,
                expression: format!("P({yv} | {xv})"),
                explanation: "Identified via empty backdoor set (no confounding).".to_owned(),
            });
        }

        let backdoor = find_backdoor_sets(dag, xv, yv, 5);
        if backdoor.is_admissible {
            let z_str = backdoor.adjustment_set.join(", ");
            return Some(IdResult {
                identifiable: true,
                expression: format!("Σ_{{{z_str}}} P({yv} | {xv}, {z_str}) P({z_str})"),
                explanation: format!("Identified via backdoor adjustment on {{{z_str}}}."),
            });
        }

        let frontdoor = find_frontdoor_set(dag, xv, yv);
        if !frontdoor.is_applicable {
            return None;
        }
        Some(IdResult {
            identifiable: true,
            explanation: format!(
                "Identified via frontdoor criterion through mediators: {:?}.",
                frontdoor.mediator_set
            ),
            expression: frontdoor.formula,
        })
    }

    if x.is_empty() {
        return IdResult {
            identifiable: true,
            expression: format!("P({})", y.join(", ")),
            explanation: "No intervention; trivially identified as the observational distribution."
                .to_owned(),
        };
    }

    if let (&[xv], &[yv]) = (x, y) {
        if let Some(found) = identify_single(dag, xv, yv) {
            return found;
        }
    }

    let tian = tian_pearl_id(dag, y, x);
    if tian.identifiable {
        return tian;
    }

    IdResult {
        identifiable: false,
        expression: String::new(),
        explanation: format!(
            "P({y} | do({x})) is not identifiable by the ID algorithm with the given DAG.",
            y = y.join(", "),
            x = x.join(", ")
        ),
    }
}
/// Builds `P(y | do(x))` for a fully observed DAG via the truncated factorization
///
/// `P(y | do(x)) = Σ_{V∖(X∪Y)} Π_{v ∉ X} P(v | pa(v))`
///
/// Intervened nodes contribute no factor. Nodes without parents contribute a
/// marginal `P(v)`. With no hidden variables every c-component is a singleton,
/// so this is what the Tian–Pearl factorization reduces to. The effect is
/// reported as `identifiable` whenever `dag` has at least one node.
///
/// Example: for `X → M → Y`, `P(Y | do(X))` renders as `Σ_{M} P(M | X) P(Y | M)`.
/// Parent order follows `CausalDAG::parents` and factor order follows
/// `CausalDAG::topological_sort`.
pub fn tian_pearl_id(dag: &CausalDAG, y: &[&str], x: &[&str]) -> IdResult {
    let topo = dag.topological_sort();
    let y_set: HashSet<&str> = y.iter().copied().collect();
    let x_set: HashSet<&str> = x.iter().copied().collect();

    // One factor P(v | pa(v)) per non-intervened node; no denominator is needed.
    let factors: Vec<String> = topo
        .iter()
        .copied()
        .filter(|&v| !x_set.contains(v))
        .map(|v| {
            let pa: Vec<&str> = dag.parents(v);
            if pa.is_empty() {
                format!("P({v})")
            } else {
                format!("P({v} | {})", pa.join(", "))
            }
        })
        .collect();

    // Marginalize every node that is neither intervened on nor an outcome.
    let sum_over: Vec<&str> = topo
        .iter()
        .copied()
        .filter(|&v| !y_set.contains(v) && !x_set.contains(v))
        .collect();
    let sum_str = if sum_over.is_empty() {
        String::new()
    } else {
        format!("Σ_{{{}}}", sum_over.join(","))
    };

    let expr = format!("{sum_str} {}", factors.join(" "));
    IdResult {
        identifiable: dag.n_nodes() > 0,
        expression: expr.trim().to_owned(),
        explanation: "Tian-Pearl c-component factorization (DAG, no hidden variables).".to_owned(),
    }
}
/// Partitions the nodes of `dag` into c-components (confounded components).
/// These are the maximal node sets connected by the bidirected edges listed in
/// `bidirected`.
///
/// - Pairs that name a node unknown to `dag` are ignored.
/// - Every node appears in exactly one component; nodes touched by no bidirected
///   edge form singletons.
/// - Components are returned in order of their smallest node index, so the
///   output is deterministic.
pub fn c_components_with_hidden(dag: &CausalDAG, bidirected: &[(&str, &str)]) -> Vec<CComponent> {
    /// Union-find root lookup with path halving.
    fn find(parent: &mut [usize], mut i: usize) -> usize {
        while parent[i] != i {
            parent[i] = parent[parent[i]];
            i = parent[i];
        }
        i
    }
    /// Merges the sets containing `a` and `b`.
    fn merge(parent: &mut [usize], a: usize, b: usize) {
        let ra = find(parent, a);
        let rb = find(parent, b);
        if ra != rb {
            parent[ra] = rb;
        }
    }

    let n = dag.n_nodes();
    let mut parent: Vec<usize> = (0..n).collect();
    for &(u, v) in bidirected {
        if let (Some(ui), Some(vi)) = (dag.node_index(u), dag.node_index(v)) {
            merge(&mut parent, ui, vi);
        }
    }

    // Assign component slots in first-seen order (ascending node index); a
    // HashMap iteration would give a different order on every run.
    let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
    let mut components: Vec<CComponent> = Vec::new();
    for i in 0..n {
        let root = find(&mut parent, i);
        let slot = *slot_of_root.entry(root).or_insert_with(|| {
            components.push(CComponent {
                nodes: HashSet::new(),
            });
            components.len() - 1
        });
        components[slot].nodes.insert(i);
    }
    components
}
/// Deletes every edge pointing into the named nodes (graph surgery `G_{X̄}`).
/// Names unknown to `dag` are skipped.
fn remove_incoming_edges(dag: &mut CausalDAG, targets: &[&str]) {
    let mut idxs: HashSet<usize> = HashSet::new();
    for &name in targets {
        if let Some(i) = dag.node_index(name) {
            idxs.insert(i);
        }
    }
    dag.remove_incoming_edges_for(&idxs);
}
/// Deletes every edge leaving the named nodes (graph surgery `G_{X̲}`).
/// Names unknown to `dag` are skipped.
fn remove_outgoing_edges(dag: &mut CausalDAG, targets: &[&str]) {
    let mut source_idxs: HashSet<usize> = HashSet::new();
    source_idxs.extend(targets.iter().filter_map(|&name| dag.node_index(name)));
    dag.remove_outgoing_edges_for(&source_idxs);
}
/// Returns `true` iff every node of `y` is d-separated from every node of `z`
/// given `conditioning`. Stops at the first d-connected pair; vacuously `true`
/// when either list is empty.
fn check_d_separation_all(dag: &CausalDAG, y: &[&str], z: &[&str], conditioning: &[&str]) -> bool {
    y.iter()
        .all(|&a| z.iter().all(|&b| dag.is_d_separated(a, b, conditioning)))
}
/// Union of `CausalDAG::ancestors` over all `names`, as node indices.
fn ancestors_of_names(dag: &CausalDAG, names: &[&str]) -> HashSet<usize> {
    names
        .iter()
        .flat_map(|&name| dag.ancestors(name))
        .collect()
}
/// Returns `true` when every directed path from `x` to `y` passes through a
/// node of `m_set`.
///
/// A depth-first search follows child edges from `x` but never enters a node of
/// `m_set`. Reaching `y` means an unintercepted directed path exists. If either
/// `x` or `y` is unknown to `dag`, the function returns `true`.
fn intercepts_all_paths(dag: &CausalDAG, x: &str, y: &str, m_set: &[&str]) -> bool {
    let (source, target) = match (dag.node_index(x), dag.node_index(y)) {
        (Some(s), Some(t)) => (s, t),
        _ => return true,
    };
    let blocked: HashSet<usize> = m_set.iter().filter_map(|&m| dag.node_index(m)).collect();

    let mut seen: HashSet<usize> = HashSet::new();
    let mut to_visit: Vec<usize> = vec![source];
    while let Some(node) = to_visit.pop() {
        if node == target {
            // A directed path to `y` that avoids every mediator.
            return false;
        }
        if seen.insert(node) {
            let name = dag.node_name(node).unwrap_or("");
            for child in dag.children(name) {
                match dag.node_index(child) {
                    Some(ci) if !blocked.contains(&ci) => to_visit.push(ci),
                    _ => {}
                }
            }
        }
    }
    true
}
/// Renders the frontdoor adjustment formula
/// `Σ_m P(m | x) Σ_{x'} P(y | x', m) P(x')` with the given node names.
/// The mediators are joined with `", "`.
fn frontdoor_formula(x: &str, y: &str, m_set: &[&str]) -> String {
    let mediators = m_set.join(", ");
    let outer = format!("Σ_{{{mediators}}} P({mediators} | {x})");
    let inner = format!("Σ_{{{x}'}} P({y} | {x}', {mediators}) P({x}')");
    format!("{outer} {inner}")
}
/// All `k`-element combinations of `items`, in lexicographic index order.
///
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    /// Appends every `need`-element combination of `pool` to `out`, each
    /// prefixed by the current contents of `prefix`.
    fn walk<T: Copy>(pool: &[T], need: usize, prefix: &mut Vec<T>, out: &mut Vec<Vec<T>>) {
        if need == 0 {
            out.push(prefix.clone());
            return;
        }
        // Leave at least `need - 1` elements after the chosen one.
        for (start, &item) in pool[..=pool.len() - need].iter().enumerate() {
            prefix.push(item);
            walk(&pool[start + 1..], need - 1, prefix, out);
            prefix.pop();
        }
    }

    if k > items.len() {
        return Vec::new();
    }
    let mut out = Vec::new();
    walk(items, k, &mut Vec::with_capacity(k), &mut out);
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::causal_graph::dag::CausalDAG;

    /// Chain `X → M → Y`: the textbook frontdoor example.
    fn smoke_dag() -> CausalDAG {
        let mut dag = CausalDAG::new();
        dag.add_edge("X", "M").unwrap();
        dag.add_edge("M", "Y").unwrap();
        dag
    }

    /// `Z` confounds `X → Y`: `Z → X`, `Z → Y`, `X → Y`.
    fn confounded_dag() -> CausalDAG {
        let mut dag = CausalDAG::new();
        dag.add_edge("Z", "X").unwrap();
        dag.add_edge("Z", "Y").unwrap();
        dag.add_edge("X", "Y").unwrap();
        dag
    }

    #[test]
    fn test_backdoor_with_z() {
        let dag = confounded_dag();
        assert!(satisfies_backdoor(&dag, "X", "Y", &["Z"]));
        assert!(!satisfies_backdoor(&dag, "X", "Y", &[]));
    }

    #[test]
    fn test_find_backdoor_set() {
        let dag = confounded_dag();
        let res = find_backdoor_sets(&dag, "X", "Y", 3);
        assert!(res.is_admissible);
        assert!(res.adjustment_set.contains(&"Z".to_string()));
        // Z is the only non-descendant candidate, so {Z} is the sole admissible set.
        assert_eq!(res.all_minimal_sets, vec![vec!["Z".to_string()]]);
    }

    #[test]
    fn test_frontdoor() {
        let dag = smoke_dag();
        assert!(satisfies_frontdoor(&dag, "X", "Y", &["M"]));
        let fd = find_frontdoor_set(&dag, "X", "Y");
        assert!(fd.is_applicable);
        // With a direct X → Y edge there is no mediator to route through.
        assert!(!find_frontdoor_set(&confounded_dag(), "X", "Y").is_applicable);
    }

    #[test]
    fn test_id_trivial() {
        let dag = smoke_dag();
        let res = id_algorithm(&dag, &["Y"], &[]);
        assert!(res.identifiable);
        assert_eq!(res.expression, "P(Y)");
    }

    #[test]
    fn test_tian_pearl() {
        let dag = smoke_dag();
        let res = tian_pearl_id(&dag, &["Y"], &["X"]);
        assert!(res.identifiable);
    }

    #[test]
    fn test_c_components_with_hidden() {
        let dag = smoke_dag();
        let comps = c_components_with_hidden(&dag, &[]);
        assert_eq!(comps.len(), dag.n_nodes());
        let comps2 = c_components_with_hidden(&dag, &[("X", "Y")]);
        assert!(comps2.len() < dag.n_nodes());
        assert_eq!(comps2.len(), 2);
        let xi = dag.node_index("X").unwrap();
        let yi = dag.node_index("Y").unwrap();
        assert!(comps2
            .iter()
            .any(|c| c.nodes.contains(&xi) && c.nodes.contains(&yi)));
    }

    #[test]
    fn test_do_calculus_rule1() {
        // Z → Y is a direct edge, so Y and Z stay d-connected in G_{X̄} given {X}.
        let dag = confounded_dag();
        assert!(!check_do_calculus_rule(
            &dag,
            &["Y"],
            &["X"],
            &["Z"],
            &[],
            DoCalculusRule::Rule1
        ));
        // X → M → Y: G_{M̄} drops X → M, isolating X, so Y ⊥ X | M.
        let chain = smoke_dag();
        assert!(check_do_calculus_rule(
            &chain,
            &["Y"],
            &["M"],
            &["X"],
            &[],
            DoCalculusRule::Rule1
        ));
    }

    #[test]
    fn test_do_calculus_rule_none() {
        let dag = confounded_dag();
        assert!(!check_do_calculus_rule(
            &dag,
            &["Y"],
            &["X"],
            &["Z"],
            &[],
            DoCalculusRule::None
        ));
    }

    #[test]
    fn test_subsets() {
        let items = vec![1, 2, 3];
        assert_eq!(subsets(&items, 0).len(), 1);
        assert_eq!(subsets(&items, 1).len(), 3);
        assert_eq!(subsets(&items, 2).len(), 3);
        assert_eq!(subsets(&items, 3).len(), 1);
        assert_eq!(subsets(&items, 2), vec![vec![1, 2], vec![1, 3], vec![2, 3]]);
    }
}