use crate::parser::{FlowLogRule, Predicate};
use itertools::Itertools;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt;
/// Rule-level dependency graph built from a slice of `FlowLogRule`s by
/// [`DependencyGraph::from_rules`]; rules are identified by their index in that slice.
#[derive(Debug, Clone)]
pub(super) struct DependencyGraph {
    /// Maps each rule id to the ids of the rules whose head is referenced by an
    /// atom in that rule's body. Every rule id has an entry, possibly an empty set.
    dependency_map: HashMap<usize, HashSet<usize>>,
    /// The `(rule_id, dep_id)` edges that were contributed by a negated body atom.
    /// Kept in a `BTreeSet` so iteration order is deterministic.
    negative_edges: BTreeSet<(usize, usize)>,
}
impl DependencyGraph {
    /// Returns the dependency map: each rule id maps to the set of rule ids whose
    /// head is referenced by a positive or negative atom in that rule's body.
    #[must_use]
    pub(super) fn dependency_map(&self) -> &HashMap<usize, HashSet<usize>> {
        &self.dependency_map
    }

    /// Returns the `(rule_id, dep_id)` edges that arise from a negated body atom.
    /// Every such edge is also present in [`Self::dependency_map`].
    #[must_use]
    pub(super) fn negative_edges(&self) -> &BTreeSet<(usize, usize)> {
        &self.negative_edges
    }

    /// Builds the graph from `rules`, using each rule's index in the slice as its id.
    ///
    /// For every positive or negative atom in a rule's body, an edge is added from
    /// that rule to every rule whose head has the same name. Edges contributed by a
    /// negative atom are additionally recorded in [`Self::negative_edges`]. Other
    /// predicate kinds are ignored. Every id in `0..rules.len()` gets an entry in the
    /// dependency map, even if its dependency set is empty.
    #[must_use]
    pub(super) fn from_rules(rules: &[FlowLogRule]) -> Self {
        let head_to_rule_map = Self::build_head_to_rule_map(rules);
        let mut dependency_map: HashMap<usize, HashSet<usize>> =
            HashMap::with_capacity(rules.len());
        let mut negative_edges: BTreeSet<(usize, usize)> = BTreeSet::new();

        for (rule_id, rule) in rules.iter().enumerate() {
            // Accumulate this rule's dependencies locally and insert once, instead of
            // re-looking-up (and unwrapping) the map entry for every single edge.
            let mut deps: HashSet<usize> = HashSet::new();
            for predicate in rule.rhs() {
                let (atom_name, is_negative) = match predicate {
                    Predicate::PositiveAtom(atom) => (atom.name(), false),
                    Predicate::NegativeAtom(atom) => (atom.name(), true),
                    _ => continue,
                };
                // Atoms naming a relation that no rule derives contribute no edges.
                if let Some(dep_ids) = head_to_rule_map.get(atom_name) {
                    deps.extend(dep_ids.iter().copied());
                    if is_negative {
                        negative_edges.extend(dep_ids.iter().map(|&dep_id| (rule_id, dep_id)));
                    }
                }
            }
            dependency_map.insert(rule_id, deps);
        }

        Self {
            dependency_map,
            negative_edges,
        }
    }

    /// Maps each head-relation name to the ids (slice indices) of all rules that
    /// define it. A name defined by several rules maps to all of them, in index order.
    fn build_head_to_rule_map(rules: &[FlowLogRule]) -> HashMap<String, Vec<usize>> {
        let mut map: HashMap<String, Vec<usize>> = HashMap::new();
        for (id, rule) in rules.iter().enumerate() {
            map.entry(rule.head().name().to_string())
                .or_default()
                .push(id);
        }
        map
    }
}
impl fmt::Display for DependencyGraph {
    /// Writes a header, a 45-dash separator, and then one `Rule <id>: [<deps>]` line
    /// per rule. Rules appear in ascending id order, and each dependency list is
    /// sorted ascending and comma-separated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "\nDependency Graph:")?;
        writeln!(f, "{}", "-".repeat(45))?;

        // HashMap iteration order is unspecified; sort so the listing is stable.
        let mut rows: Vec<(usize, &HashSet<usize>)> = self
            .dependency_map()
            .iter()
            .map(|(&id, deps)| (id, deps))
            .collect();
        rows.sort_unstable_by_key(|&(id, _)| id);

        for (rule_id, deps) in rows {
            if deps.is_empty() {
                writeln!(f, "Rule {}: []", rule_id)?;
                continue;
            }
            let mut ordered: Vec<usize> = deps.iter().copied().collect();
            ordered.sort_unstable();
            let dep_str = ordered
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<String>>()
                .join(", ");
            writeln!(f, "Rule {}: [{}]", rule_id, dep_str)?;
        }
        Ok(())
    }
}