use std::collections::{HashSet, VecDeque};
use super::data::{BackdoorAdjustment, Intervention};
use super::error::CausalError;
use super::graph::CausalGraph;
/// Checks whether `adjustment_set` satisfies the back-door criterion relative
/// to the ordered pair (`treatment`, `outcome`) in `graph`.
///
/// The set is accepted only when:
/// 1. `treatment`, `outcome`, and every member of `adjustment_set` are nodes of
///    `graph` (unknown names make the set invalid rather than being ignored);
/// 2. the set contains neither `treatment` nor `outcome`;
/// 3. no member is among `graph.descendants_of(treatment)`;
/// 4. `graph.has_unblocked_backdoor_path` reports no back-door path from
///    `treatment` to `outcome` left open when conditioning on the set.
///
/// Returns `false` (never panics or errors) when any of these fail.
pub fn backdoor_criterion(
    graph: &CausalGraph,
    treatment: &str,
    outcome: &str,
    adjustment_set: &[&str],
) -> bool {
    let (treatment_idx, outcome_idx) =
        match (graph.node_index(treatment), graph.node_index(outcome)) {
            (Some(t), Some(o)) => (t, o),
            _ => return false,
        };

    // Resolve every adjustment variable. Silently dropping unknown names could
    // make a misspelled variable look like a valid adjustment, and the
    // endpoints themselves are never legitimate members of an adjustment set.
    let mut adj_idx_set: HashSet<usize> = HashSet::with_capacity(adjustment_set.len());
    for &z in adjustment_set {
        match graph.node_index(z) {
            Some(idx) if idx != treatment_idx && idx != outcome_idx => {
                adj_idx_set.insert(idx);
            }
            _ => return false,
        }
    }

    // Conditioning on a descendant of the treatment can open spurious paths
    // or block part of the causal effect, so such sets are rejected.
    let treatment_desc: HashSet<String> = graph.descendants_of(treatment).into_iter().collect();
    if adjustment_set.iter().any(|z| treatment_desc.contains(*z)) {
        return false;
    }

    !graph.has_unblocked_backdoor_path(treatment_idx, outcome_idx, &adj_idx_set)
}
/// Searches for a set of variables that satisfies [`backdoor_criterion`] for
/// (`treatment`, `outcome`).
///
/// Candidates are tried in this order, returning the first that passes:
/// 1. the empty set;
/// 2. the parents of `treatment` (excluding `treatment`/`outcome`);
/// 3. the parents grown greedily, one at a time, by ancestors of `treatment`
///    that are not its descendants and are not `treatment`/`outcome`.
///
/// If no candidate passes, the last (largest) candidate is returned with
/// `valid: false`.
///
/// # Errors
///
/// Returns [`CausalError::NodeNotFound`] if `treatment` or `outcome` is not a
/// node of `graph`.
pub fn find_backdoor_adjustment(
    graph: &CausalGraph,
    treatment: &str,
    outcome: &str,
) -> Result<BackdoorAdjustment, CausalError> {
    if graph.node_index(treatment).is_none() {
        return Err(CausalError::NodeNotFound(treatment.to_string()));
    }
    if graph.node_index(outcome).is_none() {
        return Err(CausalError::NodeNotFound(outcome.to_string()));
    }

    // An adjustment set must never contain the endpoints themselves; e.g. when
    // the outcome is a parent/ancestor of the treatment, "adjusting" for it is
    // meaningless.
    let is_endpoint = |name: &str| name == treatment || name == outcome;
    let satisfies = |set: &[String]| -> bool {
        let refs: Vec<&str> = set.iter().map(String::as_str).collect();
        backdoor_criterion(graph, treatment, outcome, &refs)
    };

    if backdoor_criterion(graph, treatment, outcome, &[]) {
        return Ok(BackdoorAdjustment {
            adjustment_set: Vec::new(),
            valid: true,
        });
    }

    let mut candidate: Vec<String> = graph
        .parents_of(treatment)
        .into_iter()
        .filter(|p| !is_endpoint(p.as_str()))
        .collect();
    if satisfies(candidate.as_slice()) {
        return Ok(BackdoorAdjustment {
            adjustment_set: candidate,
            valid: true,
        });
    }

    let ancestors = graph.ancestors_of(treatment);
    let treatment_desc: HashSet<String> = graph.descendants_of(treatment).into_iter().collect();
    for anc in &ancestors {
        if is_endpoint(anc.as_str()) || treatment_desc.contains(anc) || candidate.contains(anc) {
            continue;
        }
        candidate.push(anc.clone());
        if satisfies(candidate.as_slice()) {
            return Ok(BackdoorAdjustment {
                adjustment_set: candidate,
                valid: true,
            });
        }
    }

    // Every state `candidate` has been in was already checked above and
    // rejected, so re-running the criterion here would always yield `false`.
    Ok(BackdoorAdjustment {
        adjustment_set: candidate,
        valid: false,
    })
}
/// Checks whether `mediator_set` satisfies the front-door criterion relative to
/// (`treatment`, `outcome`) in `graph`.
///
/// Returns `true` only when all of the following hold:
/// 1. every directed path from `treatment` to `outcome` passes through some
///    mediator (checked by a breadth-first search over `graph.edges` that is
///    not allowed to step onto mediator nodes);
/// 2. for every mediator `m`, `graph.has_unblocked_backdoor_path` reports no
///    open back-door path from `treatment` to `m` with an empty conditioning set;
/// 3. for every mediator `m`, it reports no open back-door path from `m` to
///    `outcome` when conditioning on `treatment`.
///
/// Returns `false` if `mediator_set` is empty or if `treatment`, `outcome`, or
/// any mediator is not a node of `graph`.
pub fn frontdoor_criterion(
    graph: &CausalGraph,
    treatment: &str,
    outcome: &str,
    mediator_set: &[&str],
) -> bool {
    let (treatment_idx, outcome_idx) =
        match (graph.node_index(treatment), graph.node_index(outcome)) {
            (Some(t), Some(o)) => (t, o),
            _ => return false,
        };
    if mediator_set.is_empty() {
        return false;
    }

    // Resolve all mediators up front; a single unknown name invalidates the set.
    let mediator_idxs: Vec<usize> = match mediator_set
        .iter()
        .map(|&m| graph.node_index(m))
        .collect::<Option<Vec<usize>>>()
    {
        Some(idxs) => idxs,
        None => return false,
    };
    let mediator_lookup: HashSet<usize> = mediator_idxs.iter().copied().collect();

    // Condition 1: search for a directed treatment -> outcome path that avoids
    // every mediator. Finding one means the mediators do not intercept all
    // causal paths. Nodes are marked visited when enqueued, so each node is
    // expanded at most once.
    let bypasses_mediators = {
        let mut visited: HashSet<usize> = HashSet::new();
        let mut queue: VecDeque<usize> = VecDeque::new();
        visited.insert(treatment_idx);
        queue.push_back(treatment_idx);
        let mut found = false;
        while let Some(cur) = queue.pop_front() {
            if cur == outcome_idx {
                found = true;
                break;
            }
            for &(parent, child) in &graph.edges {
                if parent == cur
                    && (child == outcome_idx || !mediator_lookup.contains(&child))
                    && visited.insert(child)
                {
                    queue.push_back(child);
                }
            }
        }
        found
    };
    if bypasses_mediators {
        return false;
    }

    // Condition 2: treatment -> mediator must be unconfounded (nothing conditioned on).
    let no_conditioning: HashSet<usize> = HashSet::new();
    if mediator_idxs
        .iter()
        .any(|&m| graph.has_unblocked_backdoor_path(treatment_idx, m, &no_conditioning))
    {
        return false;
    }

    // Condition 3: every mediator -> outcome back-door path is blocked by the treatment.
    let condition_on_treatment: HashSet<usize> = std::iter::once(treatment_idx).collect();
    !mediator_idxs
        .iter()
        .any(|&m| graph.has_unblocked_backdoor_path(m, outcome_idx, &condition_on_treatment))
}
/// Applies the do-operator to `graph` for `intervention.variable`.
///
/// Returns a new graph with the same nodes in which every edge pointing *into*
/// the intervened variable has been removed. Outgoing edges are kept. If the
/// variable is not a node of `graph`, the returned graph is an unmodified copy.
pub fn do_intervention(graph: &CausalGraph, intervention: &Intervention) -> CausalGraph {
    let mut edges: Vec<(usize, usize)> = graph.edges.clone();
    if let Some(target) = graph.node_index(&intervention.variable) {
        // Cut incoming arrows: the intervened variable is now set externally.
        edges.retain(|&(_, child)| child != target);
    }
    CausalGraph {
        nodes: graph.nodes.clone(),
        edges,
    }
}