use im_rc::OrdSet;
use std::collections::HashSet;
/// A single node of a [`DAG`]: a payload plus ancestry information that
/// `DAG::add_node` computes once, when the node is inserted.
#[derive(Clone, Debug)]
pub struct DAGNode<T> {
/// Payload stored in the node.
value: T,
/// Indices of this node's direct inputs, exactly as passed to `DAG::add_node`.
inputs: Vec<usize>,
/// Indices of every node reachable from this node by following `inputs`
/// transitively (the node itself is not included).
ancestors: OrdSet<usize>,
/// Sorted indices of the ancestors that are neither the first input nor one
/// of the first input's ancestors. Empty when the node has no inputs.
remainder_ancestors: Vec<usize>,
}
impl<T> DAGNode<T> {
    /// Borrows the payload stored in this node.
    pub fn value(&self) -> &T {
        let Self { value, .. } = self;
        value
    }

    /// Indices of the nodes this node was given as direct inputs.
    pub fn inputs(&self) -> &[usize] {
        self.inputs.as_slice()
    }

    /// Indices of all nodes this node transitively depends on.
    pub fn ancestors(&self) -> &OrdSet<usize> {
        let Self { ancestors, .. } = self;
        ancestors
    }

    /// Sorted indices of the ancestors not already covered by the first input
    /// and that input's own ancestors.
    pub fn remainder_ancestors(&self) -> &[usize] {
        self.remainder_ancestors.as_slice()
    }
}
/// A directed acyclic graph stored as a flat list of nodes.
///
/// Nodes are addressed by their insertion index, and `add_node` only accepts
/// inputs that refer to nodes already present, so edges always point from a
/// newer node to older ones.
#[derive(Clone, Debug)]
pub struct DAG<T> {
/// Nodes in insertion order; a node's position in this vector is its index.
nodes: Vec<DAGNode<T>>,
}
impl<T> DAG<T> {
    /// Creates an empty graph.
    pub fn new() -> Self {
        DAG { nodes: Vec::new() }
    }

    /// All nodes in insertion order; a node's position is its index.
    pub fn nodes(&self) -> &[DAGNode<T>] {
        &self.nodes
    }

    /// Returns the node stored at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not the index of an existing node.
    pub fn node(&self, index: usize) -> &DAGNode<T> {
        &self.nodes[index]
    }

    /// Appends a node holding `value` whose direct inputs are the existing
    /// nodes at the indices in `inputs`.
    ///
    /// The new node's index is the number of nodes the graph held before the
    /// call. Its ancestry is computed here, once:
    ///
    /// * `ancestors` is the first input's ancestor set plus the first input
    ///   itself, extended with everything reachable from the other inputs.
    /// * `remainder_ancestors` is the sorted list of exactly those indices
    ///   that had to be added on behalf of the other inputs.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `inputs` is not the index of an existing node.
    pub fn add_node(&mut self, value: T, inputs: Vec<usize>) {
        let node_count = self.nodes.len();
        for &input in &inputs {
            assert!(
                input < node_count,
                "input index {} is out of range for a graph with {} nodes",
                input,
                node_count
            );
        }

        let (ancestors, remainder_ancestors) = match inputs.split_first() {
            None => (OrdSet::new(), Vec::new()),
            Some((&first, rest)) => {
                // Seed with the first input's lineage so that only the other
                // inputs need to be walked.
                let mut ancestors = self.nodes[first].ancestors.clone();
                ancestors.insert(first);

                let mut remainder = HashSet::new();
                let mut pending = rest.to_vec();
                while let Some(candidate) = pending.pop() {
                    // `insert` yields `None` only when the index was not yet
                    // known. Known indices are skipped because their own
                    // ancestors are already in the set.
                    if ancestors.insert(candidate).is_none() {
                        remainder.insert(candidate);
                        pending.extend(self.nodes[candidate].inputs.iter().copied());
                    }
                }

                let mut remainder: Vec<usize> = remainder.into_iter().collect();
                remainder.sort_unstable();
                (ancestors, remainder)
            }
        };

        self.nodes.push(DAGNode {
            value,
            inputs,
            ancestors,
            remainder_ancestors,
        });
    }
}

/// An empty graph, identical to [`DAG::new`]. Implemented by hand so that no
/// `T: Default` bound is required.
impl<T> Default for DAG<T> {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OrdSet<usize>` from a comma-separated list of indices.
    macro_rules! ord_set {
        ($($index:expr),*) => {{
            let indices: Vec<usize> = vec![$($index),*];
            indices.into_iter().collect::<OrdSet<usize>>()
        }};
    }

    /// In a linear chain, every node's ancestors are all the earlier nodes.
    #[test]
    fn ancestor_segments() {
        let mut graph = DAG::new();
        graph.add_node((), vec![]);
        for parent in 0..3 {
            graph.add_node((), vec![parent]);
        }

        let expected = [ord_set![], ord_set![0], ord_set![0, 1], ord_set![0, 1, 2]];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(graph.node(index).ancestors(), want);
        }
    }

    /// Joining two branches reports only the second branch's private nodes
    /// as remainder ancestors.
    #[test]
    fn remainder_ancestors() {
        let edges: [&[usize]; 6] = [&[], &[0], &[1], &[0], &[3], &[2, 4]];

        let mut graph = DAG::new();
        for inputs in edges.iter() {
            graph.add_node((), inputs.to_vec());
        }

        assert_eq!(graph.node(5).remainder_ancestors(), &[3, 4]);
    }
}