use std::collections::{BTreeMap, BTreeSet, VecDeque};
/// A mixed graph over string-named nodes that stores two kinds of edges:
/// directed (`from -> to`) and bidirected (`a <-> b`).
///
/// Nodes are created explicitly with [`SemiMarkovGraph::add_node`] or
/// implicitly when an edge touching them is added. Self-loops of either kind
/// are silently ignored. All adjacency is kept in ordered collections, so every
/// query that returns a collection does so in a deterministic order.
#[derive(Debug, Clone, Default)]
pub struct SemiMarkovGraph {
    /// Node names in first-insertion order.
    node_list: Vec<String>,
    /// The same node names as `node_list`, kept for fast membership tests.
    node_set: BTreeSet<String>,
    /// For each node, the targets of its outgoing directed edges.
    directed_children: BTreeMap<String, BTreeSet<String>>,
    /// For each node, the sources of its incoming directed edges.
    directed_parents: BTreeMap<String, BTreeSet<String>>,
    /// For each node, the nodes it shares a bidirected edge with (stored on both endpoints).
    bidirected: BTreeMap<String, BTreeSet<String>>,
}

impl SemiMarkovGraph {
    /// Creates an empty graph with no nodes and no edges.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `name` as a node. Does nothing if the node already exists.
    pub fn add_node(&mut self, name: &str) {
        // Check with the borrowed name first so re-adding an existing node
        // (which every edge insertion does) costs no allocation.
        if self.node_set.contains(name) {
            return;
        }
        let owned = name.to_owned();
        self.directed_children.entry(owned.clone()).or_default();
        self.directed_parents.entry(owned.clone()).or_default();
        self.bidirected.entry(owned.clone()).or_default();
        self.node_set.insert(owned.clone());
        self.node_list.push(owned);
    }

    /// Returns `true` if `name` has been added to the graph.
    pub fn has_node(&self, name: &str) -> bool {
        self.node_set.contains(name)
    }

    /// Iterates over the node names in the order they were first added.
    pub fn nodes(&self) -> impl Iterator<Item = &String> {
        self.node_list.iter()
    }

    /// Returns the number of nodes in the graph.
    pub fn n_nodes(&self) -> usize {
        self.node_list.len()
    }

    /// Adds the directed edge `from -> to`, creating either endpoint if needed.
    ///
    /// A self-loop (`from == to`) is ignored and does not create the node.
    pub fn add_directed(&mut self, from: &str, to: &str) {
        if from == to {
            return;
        }
        self.add_node(from);
        self.add_node(to);
        Self::link(&mut self.directed_children, from, to);
        Self::link(&mut self.directed_parents, to, from);
    }

    /// Removes the directed edge `from -> to`.
    ///
    /// Returns `true` if the edge existed. Both endpoints stay in the graph.
    pub fn remove_directed(&mut self, from: &str, to: &str) -> bool {
        let removed = Self::unlink(&mut self.directed_children, from, to);
        if removed {
            // Keep the reverse index in step with the forward one.
            Self::unlink(&mut self.directed_parents, to, from);
        }
        removed
    }

    /// Returns `true` if the directed edge `from -> to` is present.
    pub fn has_directed(&self, from: &str, to: &str) -> bool {
        self.directed_children
            .get(from)
            .map_or(false, |targets| targets.contains(to))
    }

    /// Adds the bidirected edge `a <-> b`, creating either endpoint if needed.
    ///
    /// A self-loop (`a == b`) is ignored and does not create the node.
    pub fn add_bidirected(&mut self, a: &str, b: &str) {
        if a == b {
            return;
        }
        self.add_node(a);
        self.add_node(b);
        Self::link(&mut self.bidirected, a, b);
        Self::link(&mut self.bidirected, b, a);
    }

    /// Removes the bidirected edge `a <-> b`.
    ///
    /// Returns `true` if the edge existed. Both endpoints stay in the graph.
    pub fn remove_bidirected(&mut self, a: &str, b: &str) -> bool {
        let removed = Self::unlink(&mut self.bidirected, a, b);
        if removed {
            // The edge is stored on both endpoints; drop the mirror entry too.
            Self::unlink(&mut self.bidirected, b, a);
        }
        removed
    }

    /// Returns `true` if the bidirected edge `a <-> b` is present.
    pub fn has_bidirected(&self, a: &str, b: &str) -> bool {
        self.bidirected
            .get(a)
            .map_or(false, |others| others.contains(b))
    }

    /// Iterates over the targets of `node`'s outgoing directed edges, in sorted order.
    ///
    /// Yields nothing if `node` is not in the graph.
    pub fn children<'a>(&'a self, node: &str) -> impl Iterator<Item = String> + 'a {
        Self::adjacent(&self.directed_children, node)
    }

    /// Iterates over the sources of `node`'s incoming directed edges, in sorted order.
    ///
    /// Yields nothing if `node` is not in the graph.
    pub fn parents<'a>(&'a self, node: &str) -> impl Iterator<Item = String> + 'a {
        Self::adjacent(&self.directed_parents, node)
    }

    /// Iterates over the nodes joined to `node` by a bidirected edge, in sorted order.
    ///
    /// Yields nothing if `node` is not in the graph.
    pub fn bidirected_neighbors<'a>(&'a self, node: &str) -> impl Iterator<Item = String> + 'a {
        Self::adjacent(&self.bidirected, node)
    }

    /// Builds the subgraph induced by `vars`.
    ///
    /// The result contains every name in `vars` that is a node of this graph,
    /// plus every directed and bidirected edge whose two endpoints are both in
    /// `vars`. Names in `vars` that are not nodes of this graph are ignored.
    /// Nodes are inserted into the result in the sorted order of `vars`.
    pub fn subgraph(&self, vars: &BTreeSet<String>) -> Self {
        let mut sub = Self::new();
        // Register nodes first so isolated members of `vars` are kept.
        for v in vars.iter().filter(|v| self.has_node(v.as_str())) {
            sub.add_node(v);
        }
        for v in vars {
            for child in self.children(v).filter(|c| vars.contains(c)) {
                sub.add_directed(v, &child);
            }
            for other in self.bidirected_neighbors(v).filter(|o| vars.contains(o)) {
                sub.add_bidirected(v, &other);
            }
        }
        sub
    }

    /// Returns a copy of the graph with every directed edge pointing *into* a
    /// node of `x_vars` removed.
    ///
    /// Outgoing directed edges of those nodes and all bidirected edges
    /// (including ones touching `x_vars`) are left untouched, as is the node set.
    pub fn mutilate(&self, x_vars: &BTreeSet<String>) -> Self {
        let mut cut = self.clone();
        for x in x_vars {
            // Snapshot the parents: removing edges mutates the set being read.
            let incoming: Vec<String> = cut.parents(x).collect();
            for parent in &incoming {
                cut.remove_directed(parent, x);
            }
        }
        cut
    }

    /// Returns every node that can reach a member of `y` along directed edges,
    /// together with the members of `y` themselves.
    ///
    /// Members of `y` are included even if they are not nodes of the graph.
    pub fn ancestors(&self, y: &BTreeSet<String>) -> BTreeSet<String> {
        Self::reachable_via(&self.directed_parents, y)
    }

    /// Returns every node reachable from a member of `x` along directed edges,
    /// together with the members of `x` themselves.
    ///
    /// Members of `x` are included even if they are not nodes of the graph.
    pub fn descendants(&self, x: &BTreeSet<String>) -> BTreeSet<String> {
        Self::reachable_via(&self.directed_children, x)
    }

    /// Returns an owned copy of the set of node names.
    pub fn node_set(&self) -> BTreeSet<String> {
        self.node_set.clone()
    }

    /// Returns all directed edges as `(from, to)` pairs, sorted lexicographically.
    pub fn directed_edges(&self) -> Vec<(String, String)> {
        // Walking an ordered map of ordered sets already produces the pairs in
        // sorted order, so no explicit sort is needed.
        self.directed_children
            .iter()
            .flat_map(|(from, targets)| {
                targets.iter().map(move |to| (from.clone(), to.clone()))
            })
            .collect()
    }

    /// Returns every bidirected edge exactly once as a pair with the smaller
    /// name first, sorted lexicographically.
    pub fn bidirected_edges(&self) -> Vec<(String, String)> {
        // Each edge is stored on both endpoints; normalising the pair order and
        // collecting into a set collapses the two copies into one.
        let unique: BTreeSet<(String, String)> = self
            .bidirected
            .iter()
            .flat_map(|(a, others)| {
                others.iter().map(move |b| {
                    if a <= b {
                        (a.clone(), b.clone())
                    } else {
                        (b.clone(), a.clone())
                    }
                })
            })
            .collect();
        unique.into_iter().collect()
    }

    /// Inserts `value` into the set stored under `key`, creating the set if it is missing.
    fn link(map: &mut BTreeMap<String, BTreeSet<String>>, key: &str, value: &str) {
        if let Some(set) = map.get_mut(key) {
            set.insert(value.to_owned());
        } else {
            map.entry(key.to_owned())
                .or_default()
                .insert(value.to_owned());
        }
    }

    /// Removes `value` from the set stored under `key`; returns whether it was present.
    fn unlink(map: &mut BTreeMap<String, BTreeSet<String>>, key: &str, value: &str) -> bool {
        map.get_mut(key).map_or(false, |set| set.remove(value))
    }

    /// Yields owned copies of the names stored under `node`, or nothing if `node` has no entry.
    fn adjacent<'a>(
        map: &'a BTreeMap<String, BTreeSet<String>>,
        node: &str,
    ) -> impl Iterator<Item = String> + 'a {
        map.get(node).into_iter().flatten().cloned()
    }

    /// Breadth-first closure of `seeds` under `adjacency`; the seeds are part of the result.
    fn reachable_via(
        adjacency: &BTreeMap<String, BTreeSet<String>>,
        seeds: &BTreeSet<String>,
    ) -> BTreeSet<String> {
        let mut visited: BTreeSet<String> = BTreeSet::new();
        let mut frontier: VecDeque<String> = seeds.iter().cloned().collect();
        while let Some(node) = frontier.pop_front() {
            if visited.contains(&node) {
                continue;
            }
            if let Some(next) = adjacency.get(&node) {
                frontier.extend(next.iter().filter(|n| !visited.contains(*n)).cloned());
            }
            visited.insert(node);
        }
        visited
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into the owned, ordered name set the graph API takes.
    fn names(items: &[&str]) -> BTreeSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Adding edges registers their endpoints and records each edge kind separately.
    #[test]
    fn test_add_nodes_and_edges() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("X", "Y");
        graph.add_directed("Y", "Z");
        graph.add_bidirected("X", "Z");

        for name in &["X", "Y", "Z"] {
            assert!(graph.has_node(name));
        }

        assert!(graph.has_directed("X", "Y"));
        assert!(graph.has_directed("Y", "Z"));
        assert!(!graph.has_directed("X", "Z"));

        // A bidirected edge must be visible from both endpoints.
        assert!(graph.has_bidirected("X", "Z"));
        assert!(graph.has_bidirected("Z", "X"));
        assert!(!graph.has_bidirected("X", "Y"));
    }

    /// The induced subgraph drops nodes outside the selection and every edge touching them.
    #[test]
    fn test_subgraph() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("X", "Y");
        graph.add_directed("Y", "Z");
        graph.add_bidirected("X", "Z");

        let induced = graph.subgraph(&names(&["X", "Y"]));

        assert!(induced.has_node("X"));
        assert!(induced.has_node("Y"));
        assert!(!induced.has_node("Z"));
        assert!(induced.has_directed("X", "Y"));
        assert!(!induced.has_directed("Y", "Z"));
        assert!(!induced.has_bidirected("X", "Z"));
    }

    /// Mutilation removes only directed edges entering the selected nodes.
    #[test]
    fn test_mutilate() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("Z", "X");
        graph.add_directed("X", "Y");
        graph.add_bidirected("X", "Y");

        let mutilated = graph.mutilate(&names(&["X"]));

        assert!(!mutilated.has_directed("Z", "X"), "Z→X should be cut");
        assert!(mutilated.has_directed("X", "Y"), "X→Y should remain");
        assert!(mutilated.has_bidirected("X", "Y"), "X↔Y should remain");
    }

    /// Ancestors follow directed edges backwards and include the query node itself.
    #[test]
    fn test_ancestors() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("X", "M");
        graph.add_directed("M", "Y");

        let found = graph.ancestors(&names(&["Y"]));

        assert!(found.contains("X"), "X is an ancestor of Y");
        assert!(found.contains("M"), "M is an ancestor of Y");
        assert!(found.contains("Y"), "Y is included");
    }

    /// Descendants follow directed edges forwards and include the query node itself.
    #[test]
    fn test_descendants() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("X", "M");
        graph.add_directed("M", "Y");

        let found = graph.descendants(&names(&["X"]));

        assert!(found.contains("X"), "X is included");
        assert!(found.contains("M"), "M is a descendant");
        assert!(found.contains("Y"), "Y is a descendant");
    }

    /// `parents` and `children` report the direct directed neighbours of a node.
    #[test]
    fn test_parents_and_children() {
        let mut graph = SemiMarkovGraph::new();
        graph.add_directed("Z", "X");
        graph.add_directed("X", "Y");

        let incoming: Vec<String> = graph.parents("X").collect();
        assert!(incoming.contains(&"Z".to_string()));

        let outgoing: Vec<String> = graph.children("X").collect();
        assert!(outgoing.contains(&"Y".to_string()));
    }
}