use bit_set::BitSet;
use std::collections::{HashSet, VecDeque};
/// Errors produced when constructing a [`Poset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PosetError {
    /// The supplied edges contain a directed cycle (a self-loop counts), so
    /// they cannot describe a strict partial order.
    Cycle,
}

impl std::fmt::Display for PosetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PosetError::Cycle => f.write_str("edges contain a cycle; not a partial order"),
        }
    }
}

impl std::error::Error for PosetError {}
/// A finite partially ordered set over `nodes`, stored as two edge lists.
///
/// Every edge is a `(lower, upper)` pair of indices into `nodes`. Reflexive
/// pairs `(a, a)` are never stored; `is_leq` treats `a == b` as true on its own.
#[derive(Debug, Clone)]
pub struct Poset<T> {
    /// The elements of the poset; edge endpoints index into this vector.
    pub nodes: Vec<T>,
    /// The covering relation (Hasse-diagram edges).
    pub covering_edges: Vec<(u32, u32)>,
    /// The strict order relation; `is_leq` looks pairs up in this list.
    pub transitive_edges: Vec<(u32, u32)>,
}
impl<T: Clone> Poset<T> {
    /// Builds a poset from its covering relation (Hasse diagram).
    ///
    /// `edges` are `(lower, upper)` index pairs into `nodes`. They are stored
    /// verbatim as `covering_edges` (not checked for redundancy), and their
    /// strict transitive closure is derived into `transitive_edges`.
    ///
    /// # Errors
    /// Returns [`PosetError::Cycle`] if `edges` contain a directed cycle,
    /// including a self-loop.
    ///
    /// # Panics
    /// Panics if an edge endpoint is not a valid index into `nodes`.
    pub fn from_covering_relation(
        nodes: Vec<T>,
        edges: Vec<(u32, u32)>,
    ) -> Result<Self, PosetError> {
        let n = nodes.len();
        let topo = kahn_topo_sort(n, &edges).ok_or(PosetError::Cycle)?;
        let transitive_edges = transitive_closure(n, &edges, &topo);
        Ok(Poset {
            nodes,
            covering_edges: edges,
            transitive_edges,
        })
    }

    /// Builds a poset from its strict order relation.
    ///
    /// `edges` are `(lower, upper)` index pairs into `nodes` and are stored
    /// verbatim as `transitive_edges`. [`Poset::is_leq`] is therefore only exact
    /// when the input is already transitively closed. The covering relation is
    /// derived from the true reachability of `edges`, so it is correct even if
    /// the input omits implied pairs.
    ///
    /// # Errors
    /// Returns [`PosetError::Cycle`] if `edges` contain a directed cycle,
    /// including a self-loop.
    ///
    /// # Panics
    /// Panics if an edge endpoint is not a valid index into `nodes`.
    pub fn from_transitive_relation(
        nodes: Vec<T>,
        edges: Vec<(u32, u32)>,
    ) -> Result<Self, PosetError> {
        let n = nodes.len();
        let topo = kahn_topo_sort(n, &edges).ok_or(PosetError::Cycle)?;

        let mut direct: Vec<BitSet> = vec![BitSet::with_capacity(n); n];
        for &(u, v) in &edges {
            direct[u as usize].insert(v as usize);
        }

        // Strict reachability, filled sinks-first (reverse topological order) so
        // each successor's set is complete before it is merged into its
        // predecessors. Using the raw edge sets as "reachability" would keep
        // redundant edges whenever the input is not transitively closed.
        let mut reach: Vec<BitSet> = vec![BitSet::with_capacity(n); n];
        for &u in topo.iter().rev() {
            let mut acc = direct[u].clone();
            for v in direct[u].iter() {
                acc.union_with(&reach[v]);
            }
            reach[u] = acc;
        }

        let covering_edges = transitive_reduction(n, &edges, &reach);
        Ok(Poset {
            nodes,
            covering_edges,
            transitive_edges: edges,
        })
    }

    /// Returns `true` if `a == b` or `(a, b)` is listed in `transitive_edges`.
    ///
    /// This is a linear scan over `transitive_edges`.
    pub fn is_leq(&self, a: u32, b: u32) -> bool {
        a == b || self.transitive_edges.contains(&(a, b))
    }

    /// Returns `true` if `(a, b)` is listed in `covering_edges` (linear scan).
    pub fn covers(&self, a: u32, b: u32) -> bool {
        self.covering_edges.contains(&(a, b))
    }

    /// Returns a topological order of the node indices, derived from
    /// `covering_edges`.
    ///
    /// Returns `None` only if `covering_edges` contains a cycle. That cannot
    /// happen for a poset built by the constructors unless the public fields
    /// are modified afterwards.
    pub fn linear_extension(&self) -> Option<Vec<usize>> {
        kahn_topo_sort(self.nodes.len(), &self.covering_edges)
    }

    /// Topologically sorts nodes `0..n` under `edges` without building a `Poset`.
    ///
    /// Returns `None` if `edges` contain a cycle. Indices are narrowed to `u32`
    /// with `as`, so indices above `u32::MAX` are silently truncated.
    pub fn linear_extension_static(
        n: usize,
        edges: &HashSet<(usize, usize)>,
    ) -> Option<Vec<usize>> {
        let edges_u32: Vec<(u32, u32)> = edges.iter().map(|&(u, v)| (u as u32, v as u32)).collect();
        kahn_topo_sort(n, &edges_u32)
    }
}
/// Topologically sorts the nodes `0..n` with Kahn's algorithm.
///
/// Ready nodes are processed first-in, first-out. The initial sources are
/// enqueued in ascending index order, and each node's successors are visited in
/// ascending index order, so the result is deterministic for a given edge
/// multiset. Duplicate edges are allowed.
///
/// Returns `None` when not every node can be emitted, i.e. the edges contain a
/// cycle (a self-loop counts).
///
/// # Panics
/// Panics if an edge endpoint is `>= n`.
pub(crate) fn kahn_topo_sort(n: usize, edges: &[(u32, u32)]) -> Option<Vec<usize>> {
    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut pending = vec![0usize; n];
    edges.iter().for_each(|&(from, to)| {
        successors[from as usize].push(to as usize);
        pending[to as usize] += 1;
    });
    successors.iter_mut().for_each(|list| list.sort_unstable());

    let mut ready = VecDeque::new();
    for (node, &count) in pending.iter().enumerate() {
        if count == 0 {
            ready.push_back(node);
        }
    }

    let mut sorted = Vec::with_capacity(n);
    while let Some(node) = ready.pop_front() {
        sorted.push(node);
        for &next in &successors[node] {
            pending[next] -= 1;
            if pending[next] == 0 {
                ready.push_back(next);
            }
        }
    }

    // Nodes on (or downstream of) a cycle never reach zero pending edges.
    if sorted.len() != n {
        return None;
    }
    Some(sorted)
}
/// Computes the strict transitive closure of a DAG.
///
/// `edges` are directed pairs `(u, v)` over nodes `0..n`. `topo` must list the
/// nodes in a topological order of `edges`, as produced by `kahn_topo_sort`.
/// Returns every pair `(u, v)` such that `v` is reachable from `u` by a non-empty
/// path. Pairs are grouped by ascending `u` and sorted by ascending `v` within a
/// group.
fn transitive_closure(n: usize, edges: &[(u32, u32)], topo: &[usize]) -> Vec<(u32, u32)> {
    let mut direct: Vec<BitSet> = vec![BitSet::with_capacity(n); n];
    for &(u, v) in edges {
        direct[u as usize].insert(v as usize);
    }

    let mut reach: Vec<BitSet> = vec![BitSet::with_capacity(n); n];
    // Walk sinks-first so each successor's reachability set is final by the
    // time it is merged into its predecessors.
    for &u in topo.iter().rev() {
        // Accumulate into a local set so `reach[v]` can be read by shared
        // reference. This avoids cloning it once per edge (and collecting the
        // successor list) just to satisfy the borrow checker.
        let mut acc = direct[u].clone();
        for v in direct[u].iter() {
            acc.union_with(&reach[v]);
        }
        reach[u] = acc;
    }

    reach
        .iter()
        .enumerate()
        .flat_map(|(u, reach_u)| reach_u.iter().map(move |v| (u as u32, v as u32)))
        .collect()
}
/// Derives the covering (Hasse) edges from `edges` over nodes `0..n`.
///
/// `reach[w]` must be the set of nodes strictly reachable from `w`. An edge
/// `(u, v)` is kept only when `v` is not reachable through some other direct
/// successor of `u`. Output is grouped by ascending `u`, then ascending `v`.
fn transitive_reduction(n: usize, edges: &[(u32, u32)], reach: &[BitSet]) -> Vec<(u32, u32)> {
    let mut adjacency: Vec<BitSet> = (0..n).map(|_| BitSet::with_capacity(n)).collect();
    for &(from, to) in edges {
        adjacency[from as usize].insert(to as usize);
    }

    let mut covers = Vec::new();
    for (from, direct) in adjacency.iter().enumerate() {
        // Everything reachable via any direct successor is implied by
        // transitivity, so it cannot be a covering edge of `from`.
        let implied = direct.iter().fold(BitSet::with_capacity(n), |mut acc, mid| {
            acc.union_with(&reach[mid]);
            acc
        });
        let mut kept = direct.clone();
        kept.difference_with(&implied);
        covers.extend(kept.iter().map(|to| (from as u32, to as u32)));
    }
    covers
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_covering_chain() {
        let p = Poset::from_covering_relation(vec![0usize, 1, 2], vec![(0, 1), (1, 2)]).unwrap();
        assert_eq!(p.covering_edges.len(), 2);
        assert!(p.transitive_edges.contains(&(0, 1)));
        assert!(p.transitive_edges.contains(&(1, 2)));
        assert!(p.transitive_edges.contains(&(0, 2)));
        assert_eq!(p.transitive_edges.len(), 3);
    }

    #[test]
    fn test_from_covering_diamond() {
        let p = Poset::from_covering_relation(
            vec![0usize, 1, 2, 3],
            vec![(0, 1), (0, 2), (1, 3), (2, 3)],
        )
        .unwrap();
        assert!(p.transitive_edges.contains(&(0, 3)));
        assert!(p.transitive_edges.contains(&(0, 1)));
        assert!(p.transitive_edges.contains(&(0, 2)));
        assert!(p.transitive_edges.contains(&(1, 3)));
        assert!(p.transitive_edges.contains(&(2, 3)));
    }

    #[test]
    fn test_from_transitive_diamond() {
        let transitive = vec![(0u32, 1u32), (0, 2), (1, 3), (2, 3), (0, 3)];
        let p = Poset::from_transitive_relation(vec![0usize, 1, 2, 3], transitive).unwrap();
        let mut cov = p.covering_edges.clone();
        cov.sort();
        assert_eq!(cov, vec![(0, 1), (0, 2), (1, 3), (2, 3)]);
    }

    #[test]
    fn test_from_transitive_chain() {
        let transitive = vec![(0u32, 1u32), (1, 2), (0, 2)];
        let p = Poset::from_transitive_relation(vec![0usize, 1, 2], transitive).unwrap();
        let mut cov = p.covering_edges.clone();
        cov.sort();
        assert_eq!(cov, vec![(0, 1), (1, 2)]);
    }

    #[test]
    fn test_cycle_detection() {
        let result = Poset::from_covering_relation(vec![0usize, 1], vec![(0, 1), (1, 0)]);
        assert_eq!(result.unwrap_err(), PosetError::Cycle);
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let result = Poset::from_covering_relation(vec![0usize], vec![(0, 0)]);
        assert_eq!(result.unwrap_err(), PosetError::Cycle);
    }

    #[test]
    fn test_from_transitive_cycle_detection() {
        let result = Poset::from_transitive_relation(vec![0usize, 1], vec![(0, 1), (1, 0)]);
        assert_eq!(result.unwrap_err(), PosetError::Cycle);
    }

    #[test]
    fn test_empty_poset() {
        let p = Poset::from_covering_relation(Vec::<usize>::new(), vec![]).unwrap();
        assert!(p.covering_edges.is_empty());
        assert!(p.transitive_edges.is_empty());
        assert_eq!(p.linear_extension(), Some(vec![]));
    }

    #[test]
    fn test_is_leq() {
        let p = Poset::from_covering_relation(
            vec![0usize, 1, 2, 3],
            vec![(0, 1), (0, 2), (1, 3), (2, 3)],
        )
        .unwrap();
        assert!(p.is_leq(0, 3));
        assert!(p.is_leq(0, 0));
        assert!(!p.is_leq(1, 2));
        assert!(!p.is_leq(3, 0));
    }

    #[test]
    fn test_covers() {
        let p = Poset::from_covering_relation(
            vec![0usize, 1, 2, 3],
            vec![(0, 1), (0, 2), (1, 3), (2, 3)],
        )
        .unwrap();
        assert!(p.covers(0, 1));
        assert!(!p.covers(0, 3));
        assert!(!p.covers(1, 2));
    }

    #[test]
    fn test_linear_extension() {
        let p = Poset::from_covering_relation(vec![0usize, 1, 2], vec![(0, 1), (1, 2)]).unwrap();
        let le = p.linear_extension().unwrap();
        assert_eq!(le, vec![0, 1, 2]);
    }

    #[test]
    fn test_linear_extension_static() {
        let edges: std::collections::HashSet<(usize, usize)> =
            vec![(2, 0), (0, 1)].into_iter().collect();
        assert_eq!(
            Poset::<usize>::linear_extension_static(3, &edges),
            Some(vec![2, 0, 1])
        );

        let cyclic: std::collections::HashSet<(usize, usize)> =
            vec![(0, 1), (1, 0)].into_iter().collect();
        assert_eq!(Poset::<usize>::linear_extension_static(2, &cyclic), None);
    }
}