use crate::StatsError;
use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
/// Directed graph over nodes `0..n_nodes`, stored as mirrored parent/child
/// adjacency lists.
///
/// `add_edge` refuses self-loops and edges that would close a cycle, so a graph
/// built only through that method stays acyclic. All fields are public, so code
/// that mutates them directly is responsible for keeping `parents` and
/// `children` consistent with each other.
#[derive(Debug, Clone)]
pub struct DAG {
/// Number of nodes; `add_edge` accepts indices strictly below this value.
pub n_nodes: usize,
/// `parents[v]` lists every node `u` with an edge `u -> v`.
pub parents: Vec<Vec<usize>>,
/// `children[u]` lists every node `v` with an edge `u -> v`.
pub children: Vec<Vec<usize>>,
/// Display name of each node, indexed by node id.
pub node_names: Vec<String>,
}
impl DAG {
    /// Creates a graph with `n` edgeless nodes named `X_0`, `X_1`, ... in index order.
    pub fn new(n: usize) -> Self {
        Self::with_names((0..n).map(|i| format!("X_{i}")).collect())
    }

    /// Creates an edgeless graph with one node per entry of `names`.
    ///
    /// Names are not checked for uniqueness.
    pub fn with_names(names: Vec<String>) -> Self {
        let n = names.len();
        Self {
            n_nodes: n,
            parents: vec![Vec::new(); n],
            children: vec![Vec::new(); n],
            node_names: names,
        }
    }

    /// Adds the directed edge `from -> to`.
    ///
    /// Adding an edge that is already present is a no-op that returns `Ok(())`.
    /// New entries are inserted at the position that keeps an already-ascending
    /// adjacency list ascending.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidInput` when
    /// * either index is `>= n_nodes`,
    /// * `from == to` (self-loop), or
    /// * `from` is already reachable from `to`, so the edge would close a cycle.
    pub fn add_edge(&mut self, from: usize, to: usize) -> Result<(), StatsError> {
        if from >= self.n_nodes || to >= self.n_nodes {
            return Err(StatsError::InvalidInput(format!(
                "Node index out of range: from={from}, to={to}, n={}",
                self.n_nodes
            )));
        }
        if from == to {
            return Err(StatsError::InvalidInput(
                "Self-loops are not allowed in a DAG".to_string(),
            ));
        }
        if self.parents[to].contains(&from) {
            // Edge already present: nothing to do.
            return Ok(());
        }
        // A path to -> ... -> from plus the new edge from -> to would form a cycle.
        if self.can_reach(to, from) {
            return Err(StatsError::InvalidInput(format!(
                "Adding edge {from}→{to} would create a cycle"
            )));
        }
        let pos = self.parents[to].partition_point(|&p| p < from);
        self.parents[to].insert(pos, from);
        let pos = self.children[from].partition_point(|&c| c < to);
        self.children[from].insert(pos, to);
        Ok(())
    }

    /// Removes the edge `from -> to` and reports whether it existed.
    ///
    /// Returns `false` (instead of panicking) when either index does not refer
    /// to an existing adjacency list, since no such edge can exist.
    pub fn remove_edge(&mut self, from: usize, to: usize) -> bool {
        // Checked against the actual list lengths so the indexing below cannot
        // panic even if `n_nodes` was modified independently of the lists.
        if from >= self.children.len() || to >= self.parents.len() {
            return false;
        }
        let before = self.parents[to].len();
        self.parents[to].retain(|&p| p != from);
        self.children[from].retain(|&c| c != to);
        self.parents[to].len() < before
    }

    /// Returns `true` when every node can be placed in a topological order,
    /// i.e. the stored edges contain no directed cycle.
    pub fn is_dag(&self) -> bool {
        self.topological_sort_full().is_some()
    }

    /// Returns the nodes in a topological order (every parent before its children).
    ///
    /// Returns an empty vector when the stored edges contain a cycle; note that
    /// a graph with zero nodes also yields an empty vector.
    pub fn topological_sort(&self) -> Vec<usize> {
        self.topological_sort_full().unwrap_or_default()
    }

    /// Kahn's algorithm: repeatedly emit a node whose remaining in-degree is zero.
    ///
    /// Returns `None` when some node is never emitted, which means a cycle exists.
    fn topological_sort_full(&self) -> Option<Vec<usize>> {
        let mut in_degree: Vec<usize> = self.parents.iter().map(Vec::len).collect();
        // Roots are queued in ascending index order and processed FIFO.
        let mut queue: VecDeque<usize> = (0..self.n_nodes)
            .filter(|&i| in_degree[i] == 0)
            .collect();
        let mut order = Vec::with_capacity(self.n_nodes);
        while let Some(node) = queue.pop_front() {
            order.push(node);
            for &child in &self.children[node] {
                in_degree[child] -= 1;
                if in_degree[child] == 0 {
                    queue.push_back(child);
                }
            }
        }
        if order.len() == self.n_nodes {
            Some(order)
        } else {
            None
        }
    }

    /// Depth-first search along child edges: is `target` reachable from `start`?
    ///
    /// A node is considered reachable from itself.
    fn can_reach(&self, start: usize, target: usize) -> bool {
        let mut visited = vec![false; self.n_nodes];
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            if node == target {
                return true;
            }
            if visited[node] {
                continue;
            }
            visited[node] = true;
            stack.extend(self.children[node].iter().copied());
        }
        false
    }

    /// Tests whether `x` and `y` are d-separated given the conditioning set `z`.
    ///
    /// Performs a breadth-first search over `(node, direction)` states starting
    /// at `x` and returns `false` as soon as `y` is reached, `true` if the
    /// search exhausts without reaching it. Consequently `x == y` yields `false`.
    ///
    /// Traversal rules, where a node is "observed" when it is in `z`:
    /// * entered from a child (moving upward): if unobserved, continue to its
    ///   parents and to its children;
    /// * entered from a parent (moving downward): if unobserved, continue to its
    ///   children; if it is in `z` or is an ancestor of a member of `z`,
    ///   continue to its parents (the collider is unblocked).
    ///
    /// # Panics
    ///
    /// Panics if `x` (when `x != y`) or any element of `z` is not a valid node
    /// index.
    pub fn d_separation(&self, x: usize, y: usize, z: &[usize]) -> bool {
        let z_set: HashSet<usize> = z.iter().copied().collect();
        let mut z_or_ancestor_of_z = z_set.clone();
        for &zn in &z_set {
            self.collect_ancestors_into(zn, &mut z_or_ancestor_of_z);
        }

        // The flag is `true` when the node was entered from one of its children.
        let mut visited: HashSet<(usize, bool)> = HashSet::new();
        let mut queue: VecDeque<(usize, bool)> = VecDeque::from([(x, true), (x, false)]);
        while let Some((node, going_up)) = queue.pop_front() {
            if node == y {
                return false;
            }
            // `insert` returns false when the state was already expanded.
            if !visited.insert((node, going_up)) {
                continue;
            }
            let observed = z_set.contains(&node);
            if going_up {
                if !observed {
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
            } else {
                if !observed {
                    queue.extend(self.children[node].iter().map(|&c| (c, false)));
                }
                if z_or_ancestor_of_z.contains(&node) {
                    queue.extend(self.parents[node].iter().map(|&p| (p, true)));
                }
            }
        }
        true
    }

    /// Adds every ancestor of `node` to `set`.
    ///
    /// Nodes already in `set` are not expanded again, so callers may share one
    /// set across several calls.
    fn collect_ancestors_into(&self, node: usize, set: &mut HashSet<usize>) {
        let mut stack = vec![node];
        while let Some(n) = stack.pop() {
            for &parent in &self.parents[n] {
                if set.insert(parent) {
                    stack.push(parent);
                }
            }
        }
    }

    /// Adds every descendant of `node` to `set`; nodes already in `set` are not
    /// expanded again.
    fn collect_descendants(&self, node: usize, set: &mut HashSet<usize>) {
        let mut stack = vec![node];
        while let Some(n) = stack.pop() {
            for &child in &self.children[n] {
                if set.insert(child) {
                    stack.push(child);
                }
            }
        }
    }

    /// Returns every node from which `node` can be reached via one or more edges.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node index.
    pub fn ancestors(&self, node: usize) -> HashSet<usize> {
        let mut set = HashSet::new();
        self.collect_ancestors_into(node, &mut set);
        set
    }

    /// Returns every node reachable from `node` via one or more edges.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node index.
    pub fn descendants(&self, node: usize) -> HashSet<usize> {
        let mut set = HashSet::new();
        self.collect_descendants(node, &mut set);
        set
    }

    /// Builds the moral graph as a symmetric `n_nodes x n_nodes` adjacency matrix.
    ///
    /// Every directed edge becomes an undirected one, and every pair of parents
    /// sharing a child is connected.
    pub fn moral_graph(&self) -> Vec<Vec<bool>> {
        let n = self.n_nodes;
        let mut adj = vec![vec![false; n]; n];
        for node in 0..n {
            let parents = &self.parents[node];
            // Drop the direction of each parent -> node edge.
            for &parent in parents {
                adj[node][parent] = true;
                adj[parent][node] = true;
            }
            // "Marry" each unordered pair of co-parents.
            for (i, &p1) in parents.iter().enumerate() {
                for &p2 in &parents[(i + 1)..] {
                    adj[p1][p2] = true;
                    adj[p2][p1] = true;
                }
            }
        }
        adj
    }

    /// Returns the Markov blanket of `node`: its parents, its children, and the
    /// other parents of its children.
    ///
    /// The result is sorted ascending and free of duplicates.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node index.
    pub fn markov_blanket(&self, node: usize) -> Vec<usize> {
        let mut blanket: BTreeSet<usize> = self.parents[node].iter().copied().collect();
        for &child in &self.children[node] {
            blanket.insert(child);
            blanket.extend(
                self.parents[child]
                    .iter()
                    .copied()
                    .filter(|&co_parent| co_parent != node),
            );
        }
        blanket.into_iter().collect()
    }

    /// Lists all v-structures `p1 -> node <- p2` whose two parents are not
    /// joined by an edge in either direction.
    ///
    /// Each triple is `(p1, node, p2)`, with `p1` appearing before `p2` in the
    /// parent list of `node`.
    pub fn v_structures(&self) -> Vec<(usize, usize, usize)> {
        let mut result = Vec::new();
        for node in 0..self.n_nodes {
            let parents = &self.parents[node];
            for (i, &p1) in parents.iter().enumerate() {
                for &p2 in &parents[(i + 1)..] {
                    let adjacent =
                        self.children[p1].contains(&p2) || self.children[p2].contains(&p1);
                    if !adjacent {
                        result.push((p1, node, p2));
                    }
                }
            }
        }
        result
    }

    /// Returns the index of the first node whose name equals `name`, if any.
    pub fn node_index(&self, name: &str) -> Option<usize> {
        self.node_names.iter().position(|n| n == name)
    }

    /// Returns the total number of directed edges.
    pub fn n_edges(&self) -> usize {
        self.children.iter().map(Vec::len).sum()
    }

    /// Returns `true` when the edge `from -> to` is present.
    ///
    /// Out-of-range indices yield `false` rather than a panic.
    pub fn has_edge(&self, from: usize, to: usize) -> bool {
        from < self.children.len() && self.children[from].contains(&to)
    }

    /// Builds a lookup table from node name to node index.
    ///
    /// If a name occurs more than once, the entry for the highest index wins
    /// (unlike `node_index`, which returns the first match).
    pub fn name_map(&self) -> HashMap<&str, usize> {
        self.node_names
            .iter()
            .enumerate()
            .map(|(i, n)| (n.as_str(), i))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-node graph from `(from, to)` pairs; panics if any edge is rejected.
    fn graph(n: usize, edges: &[(usize, usize)]) -> DAG {
        let mut dag = DAG::new(n);
        for &(from, to) in edges {
            dag.add_edge(from, to).unwrap();
        }
        dag
    }

    /// Chain: 0 -> 1 -> 2.
    fn chain_dag() -> DAG {
        graph(3, &[(0, 1), (1, 2)])
    }

    /// Fork: 1 <- 0 -> 2.
    fn fork_dag() -> DAG {
        graph(3, &[(0, 1), (0, 2)])
    }

    /// Collider: 0 -> 2 <- 1.
    fn collider_dag() -> DAG {
        graph(3, &[(0, 2), (1, 2)])
    }

    #[test]
    fn test_add_edge_cycle() {
        let mut dag = chain_dag();
        let closing_edge = dag.add_edge(2, 0);
        assert!(closing_edge.is_err(), "cycle should be rejected");
    }

    #[test]
    fn test_add_edge_self_loop() {
        let mut dag = DAG::new(2);
        let outcome = dag.add_edge(0, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_topological_sort_chain() {
        assert_eq!(chain_dag().topological_sort(), vec![0, 1, 2]);
    }

    #[test]
    fn test_is_dag() {
        assert!(chain_dag().is_dag());
    }

    #[test]
    fn test_d_separation_chain() {
        let dag = chain_dag();
        // Conditioning on the middle node blocks the chain.
        assert!(dag.d_separation(0, 2, &[1]));
        assert!(!dag.d_separation(0, 2, &[]));
    }

    #[test]
    fn test_d_separation_fork() {
        let dag = fork_dag();
        // Conditioning on the common cause blocks the fork.
        assert!(dag.d_separation(1, 2, &[0]));
        assert!(!dag.d_separation(1, 2, &[]));
    }

    #[test]
    fn test_d_separation_collider() {
        let dag = collider_dag();
        // A collider blocks by default and opens once it is conditioned on.
        assert!(dag.d_separation(0, 1, &[]));
        assert!(!dag.d_separation(0, 1, &[2]));
    }

    #[test]
    fn test_moral_graph() {
        let moral = collider_dag().moral_graph();
        assert!(moral[0][1], "0 and 1 should be connected in moral graph");
        assert!(moral[1][0]);
    }

    #[test]
    fn test_markov_blanket() {
        let dag = graph(4, &[(0, 2), (1, 2), (1, 3)]);

        let blanket_of_2 = dag.markov_blanket(2);
        for expected in [0, 1] {
            assert!(blanket_of_2.contains(&expected));
        }

        let blanket_of_1 = dag.markov_blanket(1);
        for expected in [2, 3, 0] {
            assert!(blanket_of_1.contains(&expected));
        }
    }

    #[test]
    fn test_v_structures() {
        let found = collider_dag().v_structures();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0], (0, 2, 1));
    }

    #[test]
    fn test_ancestors_descendants() {
        let dag = chain_dag();

        let above = dag.ancestors(2);
        assert!(above.contains(&0));
        assert!(above.contains(&1));

        let below = dag.descendants(0);
        assert!(below.contains(&1));
        assert!(below.contains(&2));
    }

    #[test]
    fn test_remove_edge() {
        let mut dag = chain_dag();
        let removed = dag.remove_edge(0, 1);
        assert!(removed);
        assert!(!dag.has_edge(0, 1));
        assert!(!dag.parents[1].contains(&0));
    }
}