use std::collections::{BTreeMap, BTreeSet};
use super::generational_index::*;
pub struct Graph {
data: BTreeSet<NetIndex>,
op: BTreeSet<NetIndex>,
forward_dt_op: BTreeMap<NetIndex, BTreeSet<NetIndex>>,
forward_op_dt: BTreeMap<NetIndex, BTreeSet<NetIndex>>,
backward_dt_op: BTreeMap<NetIndex, BTreeSet<NetIndex>>,
backward_op_dt: BTreeMap<NetIndex, BTreeSet<NetIndex>>,
}
impl Graph {
pub fn new() -> Graph {
Graph{
data: BTreeSet::new(),
op: BTreeSet::new(),
forward_dt_op: BTreeMap::new(),
forward_op_dt: BTreeMap::new(),
backward_dt_op: BTreeMap::new(),
backward_op_dt: BTreeMap::new(),
}
}
pub fn list_data(&self) -> Vec<NetIndex> {
let mut ret = Vec::new();
for i in &self.data {
ret.push(i.clone());
}
ret
}
pub fn list_op(&self) -> Vec<NetIndex> {
let mut ret = Vec::new();
for i in &self.op {
ret.push(i.clone());
}
ret
}
pub fn add_data(&mut self, id: &NetIndex) -> Result<NetIndex, &str> {
if !self.data.contains(id) {
self.data.insert(*id);
self.forward_dt_op.insert(id.clone(), BTreeSet::new());
self.backward_dt_op.insert(id.clone(), BTreeSet::new());
Ok(id.clone())
} else {
Err("data is exits!")
}
}
pub fn del_data(&mut self, id: &NetIndex) -> Result<NetIndex, &str> {
if self.data.contains(id) {
self.data.remove(id);
for i in self.forward_dt_op.get_mut(id).expect("").iter() {
self.backward_op_dt.get_mut(i).expect("").remove(id);
}
self.forward_dt_op.remove(id);
for i in self.backward_dt_op.get_mut(id).expect("").iter() {
self.forward_op_dt.get_mut(i).expect("").remove(id);
}
self.backward_dt_op.remove(id);
Ok(id.clone())
} else {
Err("data id is not found!")
}
}
pub fn add_op(&mut self, id: &NetIndex) -> Result<NetIndex, &str> {
if !self.op.contains(id) {
self.op.insert(*id);
self.forward_op_dt.insert(id.clone(), BTreeSet::new());
self.backward_op_dt.insert(id.clone(), BTreeSet::new());
Ok(id.clone())
} else {
Err("op id exists.")
}
}
pub fn del_op(&mut self, id: &NetIndex) -> Result<NetIndex, &str> {
if self.op.contains(id) {
self.op.remove(id);
for i in self.forward_op_dt.get_mut(id).expect("").iter() {
self.backward_dt_op.get_mut(i).expect("").remove(id);
}
self.forward_op_dt.remove(id);
for i in self.backward_op_dt.get_mut(id).expect("").iter() {
self.forward_dt_op.get_mut(i).expect("").remove(id);
}
self.backward_op_dt.remove(id);
Ok(id.clone())
} else {
Err("op id is not found!")
}
}
pub fn get_input_cache(&self) -> BTreeSet<NetIndex> {
let mut jobs = BTreeSet::<NetIndex>::new();
for i in &self.data {
if self.backward_dt_op.get(i).expect("").len() <= 0 {
jobs.insert(i.clone());
}
}
jobs
}
pub fn get_output_cache(&self) -> BTreeSet<NetIndex> {
let mut jobs = BTreeSet::<NetIndex>::new();
for i in &self.data {
if self.forward_dt_op.get(i).expect("").len() <= 0 {
jobs.insert(i.clone());
}
}
jobs
}
pub fn connect(&mut self, dti: &[NetIndex],
dto: &[NetIndex],
op: &NetIndex) -> Result<NetIndex, &str> {
let mut valid_ids = true;
if !self.op.contains(op) {
valid_ids = false;
}
for i in dti {
if !self.data.contains(i) {
valid_ids = false;
}
}
for i in dto {
if !self.data.contains(i) {
valid_ids = false;
}
}
if valid_ids {
for i in dti {
self.forward_dt_op.get_mut(i).expect("").insert(op.clone());
self.backward_op_dt.get_mut(op).expect("").insert(i.clone());
}
for i in dto {
self.forward_op_dt.get_mut(op).expect("").insert(i.clone());
self.backward_dt_op.get_mut(i).expect("").insert(op.clone());
}
Ok(op.clone())
} else {
Err("Invalid id!")
}
}
pub fn walk<F>(&self, start_set: &[NetIndex],
forward: bool,
closure: F) -> Result<(), BTreeSet<NetIndex>>
where F: Fn(&[NetIndex], &[NetIndex], &NetIndex) {
let mut fdo = &self.forward_dt_op;
let mut fod = &self.forward_op_dt;
let mut bod = &self.backward_op_dt;
if !forward {
fdo = &self.backward_dt_op;
fod = &self.backward_op_dt;
bod = &self.forward_op_dt;
}
let mut jobs = BTreeSet::<NetIndex>::new();
let mut done = BTreeSet::<NetIndex>::new();
for index in start_set {
jobs.insert(*index);
}
loop {
let mut made_progress = false;
let mut edge_op = BTreeSet::<NetIndex>::new();
for dt in &jobs {
for op_candidate in &fdo[dt] {
edge_op.insert(op_candidate.clone());
}
}
for op_candidate in edge_op {
if bod[&op_candidate]
.iter()
.all(|dt| jobs.contains(dt)) {
let mut inputs = Vec::<NetIndex>::new();
for input in bod[&op_candidate].iter() {
inputs.push(input.clone());
}
let mut outputs = Vec::<NetIndex>::new();
for output in fod[&op_candidate].iter() {
outputs.push(output.clone());
}
closure(&inputs, &outputs, &op_candidate);
done.insert(op_candidate);
for input in bod[&op_candidate].iter() {
if fdo[input]
.iter()
.all(|op| done.contains(op)) {
jobs.remove(input);
}
}
for output in fod[&op_candidate].iter() {
if fdo[output].len() > 0 {
jobs.insert(*output);
}
}
made_progress = true;
}
}
if ! made_progress {
break;
}
}
if jobs.len() > 0 {
Err(jobs)
} else {
Ok(())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new() {
let _g = Graph::new();
}
fn setup_y(g: &mut Graph) {
let data_a = NetIndex::new(0,0);
let data_b = NetIndex::new(1,0);
let data_c = NetIndex::new(2,0);
g.add_data(&data_a).expect("");
g.add_data(&data_b).expect("");
g.add_data(&data_c).expect("");
let op_a = NetIndex::new(0,0);
g.add_op(&op_a).expect("");
g.connect(&[data_a, data_b], &[data_c,], &op_a).expect("");
}
fn setup_yy(g: &mut Graph) {
let data_a = NetIndex::new(0,0);
let data_b = NetIndex::new(1,0);
let data_c = NetIndex::new(2,0);
let data_d = NetIndex::new(3,0);
let data_e = NetIndex::new(4,0);
g.add_data(&data_a).expect("");
g.add_data(&data_b).expect("");
g.add_data(&data_c).expect("");
g.add_data(&data_d).expect("");
g.add_data(&data_e).expect("");
let op1 = NetIndex::new(0,0);
g.add_op(&op1).expect("");
let op2 = NetIndex::new(1,0);
g.add_op(&op2).expect("");
g.connect(&[data_a, data_b], &[data_c,], &op1).expect("");
g.connect(&[data_c, data_d], &[data_e,], &op2).expect("");
}
#[test]
fn test_get_input_cache() {
let mut g = Graph::new();
setup_y(&mut g);
assert_eq!(g.get_input_cache().len(), 2);
let mut g = Graph::new();
setup_yy(&mut g);
assert_eq!(g.get_input_cache().len(), 3);
}
#[test]
fn test_get_output_cache() {
let mut g = Graph::new();
setup_y(&mut g);
assert_eq!(g.get_output_cache().len(), 1);
let mut g = Graph::new();
setup_yy(&mut g);
assert_eq!(g.get_output_cache().len(), 1);
}
}