use std::collections::{HashMap, HashSet, VecDeque};
use crate::error::{StatsError, StatsResult};
/// A directed acyclic graph over named variables for causal reasoning.
///
/// Nodes are identified both by name and by a dense `usize` index
/// (insertion order); edges are stored as directed `(parent, child)`
/// index pairs. Acyclicity is enforced at edge-insertion time.
#[derive(Debug, Clone)]
pub struct CausalDAG {
/// Node names, indexed by insertion order.
nodes: Vec<String>,
/// Directed edges as `(parent index, child index)` pairs.
edges: Vec<(usize, usize)>,
/// Reverse lookup from node name to its index in `nodes`.
node_map: HashMap<String, usize>,
}
impl Default for CausalDAG {
fn default() -> Self {
Self::new()
}
}
impl CausalDAG {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
edges: Vec::new(),
node_map: HashMap::new(),
}
}
/// Registers `name` as a node and returns its index.
///
/// Idempotent: if the name is already present, the existing index is
/// returned and the graph is unchanged.
pub fn add_node(&mut self, name: &str) -> usize {
    match self.node_map.get(name) {
        Some(&existing) => existing,
        None => {
            // Next free index is the current node count.
            let fresh = self.nodes.len();
            self.nodes.push(name.to_owned());
            self.node_map.insert(name.to_owned(), fresh);
            fresh
        }
    }
}
/// Adds the directed edge `parent → child`, creating either node on demand.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when the edge would be a
/// self-loop or would close a directed cycle. Note that both endpoints
/// are still registered as nodes even when the edge itself is rejected
/// (node creation happens before validation).
pub fn add_edge(&mut self, parent: &str, child: &str) -> StatsResult<()> {
    let p = self.add_node(parent);
    let c = self.add_node(child);
    if p == c {
        return Err(StatsError::InvalidArgument(
            "Self-loops are not allowed in a DAG".to_owned(),
        ));
    }
    // Insert tentatively, then roll back if the new edge closes a cycle.
    self.edges.push((p, c));
    if !self.has_cycle() {
        return Ok(());
    }
    self.edges.pop();
    Err(StatsError::InvalidArgument(format!(
        "Adding edge {parent} → {child} would create a cycle"
    )))
}
/// Number of nodes currently in the graph.
pub fn n_nodes(&self) -> usize {
self.nodes.len()
}
/// Number of directed edges currently in the graph.
pub fn n_edges(&self) -> usize {
self.edges.len()
}
/// Name of the node at `idx`, or `None` when the index is out of range.
pub fn node_name(&self, idx: usize) -> Option<&str> {
    self.nodes.get(idx).map(|name| name.as_str())
}
/// Index of the node called `name`, or `None` when no such node exists.
pub fn node_index(&self, name: &str) -> Option<usize> {
self.node_map.get(name).copied()
}
/// All node names, in insertion (index) order.
pub fn node_names(&self) -> Vec<&str> {
    let mut names = Vec::with_capacity(self.nodes.len());
    for name in &self.nodes {
        names.push(name.as_str());
    }
    names
}
/// All edges as `(parent name, child name)` pairs, in insertion order.
pub fn edge_list(&self) -> Vec<(&str, &str)> {
    let mut pairs = Vec::with_capacity(self.edges.len());
    for &(p, c) in &self.edges {
        pairs.push((self.nodes[p].as_str(), self.nodes[c].as_str()));
    }
    pairs
}
/// Indices of the direct parents of `idx` (linear scan over the edge list).
fn parent_indices(&self, idx: usize) -> Vec<usize> {
    self.edges
        .iter()
        .filter_map(|&(p, c)| if c == idx { Some(p) } else { None })
        .collect()
}
/// Indices of the direct children of `idx` (linear scan over the edge list).
fn child_indices(&self, idx: usize) -> Vec<usize> {
    let mut children = Vec::new();
    for &(p, c) in &self.edges {
        if p == idx {
            children.push(c);
        }
    }
    children
}
/// Names of the direct parents of `node`; empty when the node is unknown.
pub fn parents(&self, node: &str) -> Vec<&str> {
    let idx = if let Some(&i) = self.node_map.get(node) {
        i
    } else {
        return Vec::new();
    };
    self.parent_indices(idx)
        .into_iter()
        .map(|i| self.nodes[i].as_str())
        .collect()
}
/// Names of the direct children of `node`; empty when the node is unknown.
pub fn children(&self, node: &str) -> Vec<&str> {
    let idx = if let Some(&i) = self.node_map.get(node) {
        i
    } else {
        return Vec::new();
    };
    self.child_indices(idx)
        .into_iter()
        .map(|i| self.nodes[i].as_str())
        .collect()
}
pub fn ancestors(&self, node: &str) -> HashSet<usize> {
let start = match self.node_map.get(node) {
None => return HashSet::new(),
Some(&i) => i,
};
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
for p in self.parent_indices(start) {
queue.push_back(p);
}
while let Some(cur) = queue.pop_front() {
if visited.insert(cur) {
for p in self.parent_indices(cur) {
if !visited.contains(&p) {
queue.push_back(p);
}
}
}
}
visited
}
pub fn descendants(&self, node: &str) -> HashSet<usize> {
let start = match self.node_map.get(node) {
None => return HashSet::new(),
Some(&i) => i,
};
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
for c in self.child_indices(start) {
queue.push_back(c);
}
while let Some(cur) = queue.pop_front() {
if visited.insert(cur) {
for c in self.child_indices(cur) {
if !visited.contains(&c) {
queue.push_back(c);
}
}
}
}
visited
}
/// Tests whether `x` and `y` are d-separated given the conditioning set `z`.
///
/// Implements the "Reachable" procedure (Koller & Friedman, Algorithm 3.1):
/// a breadth-first search over `(node, arrived-from-child)` states in which
/// * an unobserved node passes the trail through chains and forks, and
/// * a collider passes the trail between its parents only when the collider
///   itself, or one of its descendants, is observed.
///
/// Names in `z` that are not nodes are ignored; if `x` or `y` is unknown,
/// `true` (separated) is returned.
///
/// Bug fix: the collider-activation rule was previously applied when
/// arriving *from a child*, which both let observed non-colliders leak the
/// trail to their parents and made the result asymmetric (for the chain
/// X→Y→Z, `is_d_separated("Z", "X", &["Y"])` wrongly returned `false`
/// while the mirrored query returned `true`). Collider activation must
/// fire only when arriving *from a parent*.
pub fn is_d_separated(&self, x: &str, y: &str, z: &[&str]) -> bool {
    let xi = match self.node_map.get(x) {
        None => return true,
        Some(&i) => i,
    };
    let yi = match self.node_map.get(y) {
        None => return true,
        Some(&i) => i,
    };
    let observed: HashSet<usize> = z
        .iter()
        .filter_map(|name| self.node_map.get(*name).copied())
        .collect();
    // A collider is active iff it or one of its descendants is observed,
    // i.e. iff it is observed or is an ancestor of an observed node.
    let observed_ancestors = self.ancestors_of_set(&observed);
    let mut visited: HashSet<(usize, bool)> = HashSet::new();
    let mut queue: VecDeque<(usize, bool)> = VecDeque::new();
    // Start at `x` travelling in both directions.
    queue.push_back((xi, true));
    queue.push_back((xi, false));
    while let Some((node, via_child)) = queue.pop_front() {
        if !visited.insert((node, via_child)) {
            continue;
        }
        if node == yi {
            // An active trail from x reached y: not d-separated.
            return false;
        }
        let is_observed = observed.contains(&node);
        if via_child {
            // Arrived from a child (travelling upward). An observed node
            // blocks the trail here (it sits in a chain or fork on the
            // path); an unobserved one passes the trail up to its parents
            // and bounces it down to its other children.
            if !is_observed {
                for p in self.parent_indices(node) {
                    queue.push_back((p, true));
                }
                for c in self.child_indices(node) {
                    queue.push_back((c, false));
                }
            }
        } else {
            // Arrived from a parent (travelling downward).
            if !is_observed {
                // Chain: continue downward through an unobserved node.
                for c in self.child_indices(node) {
                    queue.push_back((c, false));
                }
            }
            if is_observed || observed_ancestors.contains(&node) {
                // Head-to-head collider at `node` is active, so the trail
                // may turn back upward toward its other parents.
                for p in self.parent_indices(node) {
                    queue.push_back((p, true));
                }
            }
        }
    }
    true
}
/// Names in the Markov blanket of `node`: its parents, its children, and
/// the other parents of its children (co-parents).
///
/// Order of the returned vector is unspecified (hash-set iteration).
/// Returns an empty vector when the node is unknown.
pub fn markov_blanket(&self, node: &str) -> Vec<&str> {
    let idx = if let Some(&i) = self.node_map.get(node) {
        i
    } else {
        return Vec::new();
    };
    // Seed with the parents, then fold in each child and its co-parents.
    let mut blanket: HashSet<usize> = self.parent_indices(idx).into_iter().collect();
    for child in self.child_indices(idx) {
        blanket.insert(child);
        blanket.extend(
            self.parent_indices(child)
                .into_iter()
                .filter(|&co_parent| co_parent != idx),
        );
    }
    blanket.into_iter().map(|i| self.nodes[i].as_str()).collect()
}
/// Node names in a topological order (parents before children), via
/// Kahn's algorithm. Ties are broken by ascending node index.
pub fn topological_sort(&self) -> Vec<&str> {
    let n = self.nodes.len();
    // Count incoming edges per node.
    let mut remaining_parents = vec![0usize; n];
    for &(_, child) in &self.edges {
        remaining_parents[child] += 1;
    }
    // Seed the frontier with every root, in index order.
    let mut ready: VecDeque<usize> = VecDeque::new();
    for i in 0..n {
        if remaining_parents[i] == 0 {
            ready.push_back(i);
        }
    }
    let mut order = Vec::with_capacity(n);
    while let Some(u) = ready.pop_front() {
        order.push(u);
        for c in self.child_indices(u) {
            remaining_parents[c] -= 1;
            if remaining_parents[c] == 0 {
                ready.push_back(c);
            }
        }
    }
    order.into_iter().map(|i| self.nodes[i].as_str()).collect()
}
/// Returns the confounded components (c-components) of the graph.
///
/// This structure stores only directed edges — bidirected
/// (latent-confounder) edges are not represented — so every node forms
/// its own singleton c-component. The result is `n_nodes()` sets, where
/// set `i` contains exactly node index `i`.
///
/// The previous implementation carried a `usize::MAX` sentinel and a
/// `while comps.len() <= c` resize loop that could never take effect:
/// each node was unconditionally assigned a fresh component id equal to
/// its own index, so the output was always these singletons. Behavior is
/// unchanged; the dead machinery is removed.
pub fn c_components(&self) -> Vec<HashSet<usize>> {
    (0..self.nodes.len())
        .map(|i| {
            let mut comp = HashSet::new();
            comp.insert(i);
            comp
        })
        .collect()
}
/// True when the current edge set contains a directed cycle.
fn has_cycle(&self) -> bool {
    // 0 = white (unvisited), 1 = grey (on the active DFS path), 2 = black (done).
    let n = self.nodes.len();
    let mut colour = vec![0u8; n];
    (0..n).any(|start| colour[start] == 0 && self.dfs_cycle(start, &mut colour))
}
/// DFS step for cycle detection: returns `true` iff a back edge (an edge
/// into a grey node on the current path) is reachable from `node`.
fn dfs_cycle(&self, node: usize, colour: &mut Vec<u8>) -> bool {
    colour[node] = 1; // grey: on the current DFS path
    for child in self.child_indices(node) {
        match colour[child] {
            // Back edge to a node on the active path: cycle found.
            1 => return true,
            // White child: recurse; propagate a cycle found below.
            0 if self.dfs_cycle(child, colour) => return true,
            _ => {}
        }
    }
    colour[node] = 2; // black: fully explored, provably cycle-free from here
    false
}
fn ancestors_of_set(&self, nodes: &HashSet<usize>) -> HashSet<usize> {
let mut ancestors = HashSet::new();
let mut queue: VecDeque<usize> = nodes.iter().copied().collect();
while let Some(cur) = queue.pop_front() {
for p in self.parent_indices(cur) {
if ancestors.insert(p) {
queue.push_back(p);
}
}
}
ancestors
}
/// Returns the `(parents, children)` index lists of `idx` in one call.
pub(crate) fn adjacency(&self, idx: usize) -> (Vec<usize>, Vec<usize>) {
(self.parent_indices(idx), self.child_indices(idx))
}
/// Deletes every edge whose child lies in `target_indices`.
/// Nodes themselves are kept; only the incoming edges are removed.
/// NOTE(review): presumably used to build the mutilated graph of a
/// `do(·)` intervention — callers are outside this file, so confirm there.
pub(crate) fn remove_incoming_edges_for(
&mut self,
target_indices: &std::collections::HashSet<usize>,
) {
self.edges.retain(|&(_, c)| !target_indices.contains(&c));
}
/// Deletes every edge whose parent lies in `target_indices`.
/// Nodes themselves are kept; only the outgoing edges are removed.
pub(crate) fn remove_outgoing_edges_for(
&mut self,
target_indices: &std::collections::HashSet<usize>,
) {
self.edges.retain(|&(p, _)| !target_indices.contains(&p));
}
}
#[cfg(test)]
mod tests {
use super::*;
// Fixture: X → Y → Z (serial chain).
fn build_chain() -> CausalDAG {
let mut dag = CausalDAG::new();
dag.add_edge("X", "Y").unwrap();
dag.add_edge("Y", "Z").unwrap();
dag
}
// Fixture: Y ← X → Z (common cause / fork).
fn build_fork() -> CausalDAG {
let mut dag = CausalDAG::new();
dag.add_edge("X", "Y").unwrap();
dag.add_edge("X", "Z").unwrap();
dag
}
// Fixture: X → Z ← Y (common effect / collider).
fn build_collider() -> CausalDAG {
let mut dag = CausalDAG::new();
dag.add_edge("X", "Z").unwrap();
dag.add_edge("Y", "Z").unwrap();
dag
}
// Closing the loop A→B→C→A must be rejected and leave the edge set intact.
#[test]
fn test_cycle_detection() {
let mut dag = CausalDAG::new();
dag.add_edge("A", "B").unwrap();
dag.add_edge("B", "C").unwrap();
let res = dag.add_edge("C", "A");
assert!(res.is_err(), "Should reject cycle A→B→C→A");
}
// An edge from a node to itself is never a valid DAG edge.
#[test]
fn test_self_loop_rejected() {
let mut dag = CausalDAG::new();
assert!(dag.add_edge("A", "A").is_err());
}
// parents()/children() return direct neighbours only, by name.
#[test]
fn test_parents_children() {
let mut dag = CausalDAG::new();
dag.add_edge("A", "B").unwrap();
dag.add_edge("A", "C").unwrap();
dag.add_edge("B", "C").unwrap();
let mut pa_c = dag.parents("C");
pa_c.sort();
assert_eq!(pa_c, vec!["A", "B"]);
let mut ch_a = dag.children("A");
ch_a.sort();
assert_eq!(ch_a, vec!["B", "C"]);
}
// Transitive closure in both directions on the chain X→Y→Z.
#[test]
fn test_ancestors_descendants() {
let dag = build_chain();
let xi = dag.node_index("X").unwrap();
let yi = dag.node_index("Y").unwrap();
let zi = dag.node_index("Z").unwrap();
let anc_z = dag.ancestors("Z");
assert!(anc_z.contains(&xi));
assert!(anc_z.contains(&yi));
let desc_x = dag.descendants("X");
assert!(desc_x.contains(&yi));
assert!(desc_x.contains(&zi));
}
// Chain: conditioning on the middle node blocks the only trail.
#[test]
fn test_d_separation_chain() {
let dag = build_chain();
assert!(dag.is_d_separated("X", "Z", &["Y"]));
assert!(!dag.is_d_separated("X", "Z", &[]));
}
// Fork: conditioning on the common cause blocks the trail.
#[test]
fn test_d_separation_fork() {
let dag = build_fork();
assert!(dag.is_d_separated("Y", "Z", &["X"]));
assert!(!dag.is_d_separated("Y", "Z", &[]));
}
// Collider: marginally independent, but conditioning on the
// common effect opens the trail (explaining away).
#[test]
fn test_d_separation_collider() {
let dag = build_collider();
assert!(dag.is_d_separated("X", "Y", &[]));
assert!(!dag.is_d_separated("X", "Y", &["Z"]));
}
// Blanket of B in {A,C}→B→D←E: parents A,C; child D; co-parent E.
#[test]
fn test_markov_blanket() {
let mut dag = CausalDAG::new();
dag.add_edge("A", "B").unwrap();
dag.add_edge("C", "B").unwrap();
dag.add_edge("B", "D").unwrap();
dag.add_edge("E", "D").unwrap();
let mut mb = dag.markov_blanket("B");
mb.sort();
assert_eq!(mb, vec!["A", "C", "D", "E"]);
}
// Only the relative order is asserted — any topological order is valid.
#[test]
fn test_topological_sort() {
let dag = build_chain();
let order = dag.topological_sort();
let xi = order.iter().position(|&s| s == "X").unwrap();
let yi = order.iter().position(|&s| s == "Y").unwrap();
let zi = order.iter().position(|&s| s == "Z").unwrap();
assert!(xi < yi && yi < zi);
}
}