use std::collections::HashMap;
/// The mark drawn at one endpoint of a PAG edge.
///
/// An edge is described by two of these, one per endpoint; e.g. `a o-> b`
/// has [`EdgeMark::Circle`] at `a` and [`EdgeMark::Arrowhead`] at `b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EdgeMark {
    /// A tail mark (`-`).
    Tail,
    /// An arrowhead mark (`>`).
    Arrowhead,
    /// A circle mark (`o`). Edges built from a skeleton start with circles
    /// at both ends.
    Circle,
}
/// A node of a time-lagged graph, identified by its variable and its lag.
///
/// Both fields are plain integers, so the type is `Copy` and can be passed
/// around by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LaggedNode {
    /// Index of the underlying variable.
    pub var_idx: usize,
    /// Time lag of this node. Signed; NOTE(review): the sign convention
    /// (e.g. whether past lags are negative) is not established here.
    pub lag: i32,
}
/// The pair of endpoint marks of a single PAG edge.
///
/// `from_mark` is the mark at the "from" endpoint and `to_mark` the mark at
/// the "to" endpoint; which node is which depends on how the edge was
/// obtained (see `PartialAncestralGraph::get_edge`). Both fields are `Copy`,
/// so the edge itself is `Copy` too.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PagEdge {
    /// Mark at the "from" endpoint.
    pub from_mark: EdgeMark,
    /// Mark at the "to" endpoint.
    pub to_mark: EdgeMark,
}
/// A partial ancestral graph (PAG) over integer-indexed nodes `0..n_nodes`.
///
/// Every adjacency is stored exactly once, under the canonical key
/// `(min, max)` of its two endpoints. In the stored [`PagEdge`], `from_mark`
/// is the mark at the smaller node index and `to_mark` the mark at the larger
/// one. All public accessors accept endpoints in either order and translate
/// to and from this canonical form.
///
/// Node indices passed to the edge methods are not checked against `n_nodes`.
#[derive(Debug, Clone, Default)]
pub struct PartialAncestralGraph {
    /// Number of distinct variables; equals `n_nodes` unless the graph was
    /// created with [`PartialAncestralGraph::with_vars_and_lags`].
    pub n_vars: usize,
    /// Largest time lag represented; `0` for graphs without lags.
    pub tau_max: usize,
    /// Total number of nodes in the graph.
    pub n_nodes: usize,
    /// Canonical `(min, max)` node pair -> edge marks (see the type docs).
    adjacency: HashMap<(usize, usize), PagEdge>,
    /// Separating sets keyed by node pair. No method of this type reads or
    /// writes them; NOTE(review): the key-ordering convention is up to the
    /// code that fills this map — confirm with callers.
    pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
}

impl PartialAncestralGraph {
    /// Creates an empty graph with `n_nodes` nodes, no lags and no edges.
    pub fn new(n_nodes: usize) -> Self {
        Self {
            n_vars: n_nodes,
            n_nodes,
            ..Self::default()
        }
    }

    /// Creates an empty graph for `n_vars` variables observed at lags
    /// `0..=tau_max`, i.e. with `n_vars * (tau_max + 1)` nodes.
    ///
    /// # Panics
    ///
    /// Panics if `n_vars * (tau_max + 1)` overflows `usize`.
    pub fn with_vars_and_lags(n_vars: usize, tau_max: usize) -> Self {
        // Checked arithmetic: in release builds a plain `*`/`+` would wrap
        // silently and yield a graph far smaller than requested.
        let n_nodes = tau_max
            .checked_add(1)
            .and_then(|n_slices| n_vars.checked_mul(n_slices))
            .expect("n_vars * (tau_max + 1) overflows usize");
        Self {
            n_vars,
            tau_max,
            n_nodes,
            ..Self::default()
        }
    }

    /// Builds a graph from a boolean skeleton adjacency matrix, turning every
    /// adjacency into a circle-circle edge (`i o-o j`).
    ///
    /// Only the strict upper triangle (`skeleton_adj[i][j]` with `i < j`) is
    /// read, so the matrix is expected to be symmetric. Rows or columns
    /// beyond `n_nodes` are ignored, and short or missing rows are tolerated
    /// (absent entries are treated as "not adjacent").
    pub fn initialize_from_skeleton(skeleton_adj: &[Vec<bool>], n_nodes: usize) -> Self {
        let mut pag = Self::new(n_nodes);
        for (i, row) in skeleton_adj.iter().enumerate().take(n_nodes) {
            let adjacent_above_diagonal = row
                .iter()
                .enumerate()
                .take(n_nodes)
                .skip(i + 1)
                .filter(|&(_, &is_adjacent)| is_adjacent);
            for (j, _) in adjacent_above_diagonal {
                pag.add_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
            }
        }
        pag
    }

    /// Inserts (or overwrites) the edge `i *-* j`, with `from_mark` at `i`
    /// and `to_mark` at `j`.
    pub fn add_edge(&mut self, i: usize, j: usize, from_mark: EdgeMark, to_mark: EdgeMark) {
        // Stored marks are ordered (mark at smaller index, mark at larger
        // index), so swap them when the caller's `i` is the larger endpoint.
        let edge = if i <= j {
            PagEdge { from_mark, to_mark }
        } else {
            PagEdge {
                from_mark: to_mark,
                to_mark: from_mark,
            }
        };
        self.adjacency.insert(Self::canonical_key(i, j), edge);
    }

    /// Removes the edge between `i` and `j`, if any.
    pub fn remove_edge(&mut self, i: usize, j: usize) {
        self.adjacency.remove(&Self::canonical_key(i, j));
    }

    /// Returns `true` if `i` and `j` are adjacent (in either order).
    pub fn has_edge(&self, i: usize, j: usize) -> bool {
        self.adjacency.contains_key(&Self::canonical_key(i, j))
    }

    /// Returns the edge between `i` and `j` oriented from `i` to `j`:
    /// `from_mark` is the mark at `i`, `to_mark` the mark at `j`.
    /// Returns `None` if they are not adjacent.
    pub fn get_edge(&self, i: usize, j: usize) -> Option<PagEdge> {
        let stored = self.adjacency.get(&Self::canonical_key(i, j))?;
        let oriented = if i <= j {
            PagEdge {
                from_mark: stored.from_mark,
                to_mark: stored.to_mark,
            }
        } else {
            PagEdge {
                from_mark: stored.to_mark,
                to_mark: stored.from_mark,
            }
        };
        Some(oriented)
    }

    /// Sets the mark at the `to` end of the edge between `from` and `to`.
    /// Does nothing if the two nodes are not adjacent.
    pub fn set_mark(&mut self, from: usize, to: usize, mark: EdgeMark) {
        if let Some(edge) = self.adjacency.get_mut(&Self::canonical_key(from, to)) {
            // `to` is the larger endpoint exactly when `from <= to`.
            if from <= to {
                edge.to_mark = mark;
            } else {
                edge.from_mark = mark;
            }
        }
    }

    /// Returns the mark at the `to` end of the edge between `from` and `to`,
    /// or `None` if they are not adjacent.
    pub fn get_mark_at(&self, from: usize, to: usize) -> Option<EdgeMark> {
        self.adjacency
            .get(&Self::canonical_key(from, to))
            .map(|e| if from <= to { e.to_mark } else { e.from_mark })
    }

    /// Returns the nodes adjacent to `node`, sorted ascending and without
    /// duplicates. Scans every stored edge, so each call costs O(edges).
    pub fn adjacent_nodes(&self, node: usize) -> Vec<usize> {
        let mut neighbors: Vec<usize> = self
            .adjacency
            .keys()
            .filter_map(|&(a, b)| {
                if a == node {
                    Some(b)
                } else if b == node {
                    Some(a)
                } else {
                    None
                }
            })
            .collect();
        neighbors.sort_unstable();
        neighbors.dedup();
        neighbors
    }

    /// Returns `true` if the graph contains `parent --> child`
    /// (tail at `parent`, arrowhead at `child`).
    pub fn is_parent(&self, parent: usize, child: usize) -> bool {
        matches!(
            self.get_edge(parent, child),
            Some(PagEdge {
                from_mark: EdgeMark::Tail,
                to_mark: EdgeMark::Arrowhead,
            })
        )
    }

    /// Counts the bidirected edges (`a <-> b`).
    pub fn n_bidirected_edges(&self) -> usize {
        self.adjacency
            .values()
            .filter(|e| e.from_mark == EdgeMark::Arrowhead && e.to_mark == EdgeMark::Arrowhead)
            .count()
    }

    /// Counts circle marks over all edge endpoints (an `o-o` edge counts twice).
    pub fn n_circle_marks(&self) -> usize {
        self.adjacency.values().fold(0, |acc, e| {
            let from_circle = usize::from(e.from_mark == EdgeMark::Circle);
            let to_circle = usize::from(e.to_mark == EdgeMark::Circle);
            acc + from_circle + to_circle
        })
    }

    /// Returns the possible ancestors of `node`, sorted ascending and
    /// excluding `node` itself.
    ///
    /// A node `a` is a possible ancestor of `node` when there is a possibly
    /// directed path `a = v0 *-* v1 *-* ... *-* vk = node`. On every edge
    /// `vi *-* v(i+1)` of that path, the mark at `vi` must not be an
    /// arrowhead (tail or circle), and the mark at `v(i+1)` must not be a
    /// tail (arrowhead or circle). Hence `a --> b`, `a o-> b`, `a o-o b` and
    /// `a --o b` can be followed from `b` back to `a`, while `a <-> b` and
    /// `a <-o b` cannot: the arrowhead at `a` rules it out as an ancestor.
    pub fn possible_ancestors(&self, node: usize) -> Vec<usize> {
        let mut visited = std::collections::HashSet::new();
        let mut stack = vec![node];
        while let Some(current) = stack.pop() {
            if !visited.insert(current) {
                continue;
            }
            for neighbor in self.adjacent_nodes(current) {
                if visited.contains(&neighbor) {
                    continue;
                }
                // Oriented neighbor -> current: `from_mark` sits at the
                // neighbor (candidate ancestor), `to_mark` at `current`.
                if let Some(edge) = self.get_edge(neighbor, current) {
                    let not_into_neighbor =
                        matches!(edge.from_mark, EdgeMark::Tail | EdgeMark::Circle);
                    let not_out_of_current =
                        matches!(edge.to_mark, EdgeMark::Arrowhead | EdgeMark::Circle);
                    if not_into_neighbor && not_out_of_current {
                        stack.push(neighbor);
                    }
                }
            }
        }
        visited.remove(&node);
        let mut ancestors: Vec<usize> = visited.into_iter().collect();
        ancestors.sort_unstable();
        ancestors
    }

    /// Iterates over every edge as `(a, b, edge)` in canonical orientation
    /// (`a <= b`, `edge.from_mark` at `a`). The order follows the internal
    /// `HashMap` and is therefore unspecified; use
    /// [`edge_node_pairs`](Self::edge_node_pairs) for a deterministic order.
    pub fn edges(&self) -> impl Iterator<Item = (usize, usize, &PagEdge)> {
        self.adjacency.iter().map(|(&(a, b), edge)| (a, b, edge))
    }

    /// Returns the canonical `(a, b)` node pair (`a <= b`) of every edge,
    /// sorted ascending.
    pub fn edge_node_pairs(&self) -> Vec<(usize, usize)> {
        let mut pairs: Vec<(usize, usize)> = self.adjacency.keys().copied().collect();
        pairs.sort_unstable();
        pairs
    }

    /// Orders a node pair as `(min, max)`, the key under which its edge is stored.
    fn canonical_key(i: usize, j: usize) -> (usize, usize) {
        if i <= j {
            (i, j)
        } else {
            (j, i)
        }
    }
}
#[cfg(test)]
mod pag_tests {
    use super::*;

    /// A freshly constructed graph has the requested size and no edges.
    #[test]
    fn test_pag_new_empty() {
        let graph = PartialAncestralGraph::new(4);
        assert_eq!(graph.n_nodes, 4);
        assert_eq!(graph.n_bidirected_edges(), 0);
        assert_eq!(graph.n_circle_marks(), 0);
    }

    /// Adjacency is symmetric and removal deletes it.
    #[test]
    fn test_pag_add_remove_edge() {
        let mut graph = PartialAncestralGraph::new(3);
        graph.add_edge(0, 1, EdgeMark::Circle, EdgeMark::Circle);
        assert!(graph.has_edge(0, 1));
        assert!(graph.has_edge(1, 0));

        graph.remove_edge(0, 1);
        assert!(!graph.has_edge(0, 1));
    }

    /// Every skeleton adjacency becomes an `o-o` edge.
    #[test]
    fn test_pag_initialize_from_skeleton_all_circles() {
        // Chain skeleton 0 - 1 - 2.
        let skeleton = vec![
            vec![false, true, false],
            vec![true, false, true],
            vec![false, true, false],
        ];
        let graph = PartialAncestralGraph::initialize_from_skeleton(&skeleton, 3);
        assert!(graph.has_edge(0, 1));
        assert!(graph.has_edge(1, 2));
        assert!(!graph.has_edge(0, 2));
        // Two edges, each with a circle at both ends.
        assert_eq!(graph.n_circle_marks(), 4);
    }

    /// `set_mark` changes only the mark at the `to` endpoint.
    #[test]
    fn test_pag_set_and_get_mark() {
        let mut graph = PartialAncestralGraph::new(3);
        graph.add_edge(0, 1, EdgeMark::Circle, EdgeMark::Circle);
        graph.set_mark(0, 1, EdgeMark::Arrowhead);

        assert_eq!(graph.get_mark_at(0, 1), Some(EdgeMark::Arrowhead));
        assert_eq!(graph.get_mark_at(1, 0), Some(EdgeMark::Circle));
    }

    /// `0 --> 1` makes 0 a parent of 1, not the reverse.
    #[test]
    fn test_pag_is_parent_true() {
        let mut graph = PartialAncestralGraph::new(3);
        graph.add_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrowhead);
        assert!(graph.is_parent(0, 1));
        assert!(!graph.is_parent(1, 0));
    }

    /// Only `<->` edges count as bidirected.
    #[test]
    fn test_pag_n_bidirected_edges_after_orient() {
        let mut graph = PartialAncestralGraph::new(4);
        graph.add_edge(0, 1, EdgeMark::Arrowhead, EdgeMark::Arrowhead); // 0 <-> 1
        graph.add_edge(1, 2, EdgeMark::Tail, EdgeMark::Arrowhead); // 1 --> 2
        graph.add_edge(2, 3, EdgeMark::Circle, EdgeMark::Circle); // 2 o-o 3
        assert_eq!(graph.n_bidirected_edges(), 1);
    }

    /// Circle marks are counted per endpoint.
    #[test]
    fn test_pag_n_circle_marks() {
        let mut graph = PartialAncestralGraph::new(4);
        graph.add_edge(0, 1, EdgeMark::Circle, EdgeMark::Circle); // 2 circles
        graph.add_edge(1, 2, EdgeMark::Circle, EdgeMark::Arrowhead); // 1 circle
        graph.add_edge(2, 3, EdgeMark::Tail, EdgeMark::Arrowhead); // none
        assert_eq!(graph.n_circle_marks(), 3);
    }

    /// Neighbor lists are sorted and include both edge orientations.
    #[test]
    fn test_pag_adjacent_nodes() {
        let mut graph = PartialAncestralGraph::new(4);
        for (a, b) in [(0, 1), (0, 2), (1, 3)] {
            graph.add_edge(a, b, EdgeMark::Circle, EdgeMark::Circle);
        }

        assert_eq!(graph.adjacent_nodes(0), vec![1, 2]);
        assert_eq!(graph.adjacent_nodes(1), vec![0, 3]);
    }
}