use std::collections::{HashMap, HashSet, VecDeque};
use super::dependency::DependencyGraph;
use super::errors::LocyCompileError;
/// Outcome of [`stratify`]: the strongly connected components (SCCs) of a
/// rule dependency graph together with an order in which they can be
/// processed.
///
/// An SCC is identified by its index into [`sccs`](Self::sccs); the per-SCC
/// vectors `is_recursive` and `scc_depends_on` are indexed the same way.
#[derive(Debug, Clone)]
pub struct StratificationResult {
    /// Strongly connected components of the combined positive/negative
    /// dependency graph, each given as the set of rule names it contains.
    pub sccs: Vec<HashSet<String>>,
    /// SCC indices in dependency order: every SCC appears after all SCCs it
    /// depends on.
    pub scc_order: Vec<usize>,
    /// Maps each rule name to the index of the SCC that contains it.
    pub scc_map: HashMap<String, usize>,
    /// `true` for an SCC that contains more than one rule, or a single rule
    /// with a positive dependency edge on itself.
    pub is_recursive: Vec<bool>,
    /// For each SCC, the indices of the *other* SCCs it has a positive or
    /// negative dependency edge to.
    pub scc_depends_on: Vec<HashSet<usize>>,
}
pub fn stratify(graph: &DependencyGraph) -> Result<StratificationResult, LocyCompileError> {
let mut adj: HashMap<&str, HashSet<&str>> = HashMap::new();
for rule in &graph.all_rules {
adj.entry(rule.as_str()).or_default();
}
for (from, tos) in &graph.positive_edges {
for to in tos {
adj.entry(from.as_str()).or_default().insert(to.as_str());
}
}
for (from, tos) in &graph.negative_edges {
for to in tos {
adj.entry(from.as_str()).or_default().insert(to.as_str());
}
}
let mut rules: Vec<&str> = graph.all_rules.iter().map(|s| s.as_str()).collect();
rules.sort();
let sccs = tarjan(&rules, &adj);
let mut scc_map: HashMap<String, usize> = HashMap::new();
for (i, scc) in sccs.iter().enumerate() {
for rule in scc {
scc_map.insert(rule.clone(), i);
}
}
for (from, tos) in &graph.negative_edges {
for to in tos {
let from_scc = scc_map[from.as_str()];
let to_scc = scc_map[to.as_str()];
if from_scc == to_scc {
let mut rules: Vec<String> = sccs[from_scc].iter().cloned().collect();
rules.sort();
return Err(LocyCompileError::CyclicNegation { rules });
}
}
}
let mut is_recursive = vec![false; sccs.len()];
for (i, scc) in sccs.iter().enumerate() {
if scc.len() > 1 {
is_recursive[i] = true;
} else {
let rule = scc.iter().next().unwrap();
let has_self_edge = graph
.positive_edges
.get(rule.as_str())
.is_some_and(|deps| deps.contains(rule));
is_recursive[i] = has_self_edge;
}
}
let mut scc_depends_on: Vec<HashSet<usize>> = vec![HashSet::new(); sccs.len()];
for (from, tos) in graph
.positive_edges
.iter()
.chain(graph.negative_edges.iter())
{
let from_scc = scc_map[from.as_str()];
for to in tos {
let to_scc = scc_map[to.as_str()];
if from_scc != to_scc {
scc_depends_on[from_scc].insert(to_scc);
}
}
}
let n = sccs.len();
let mut in_degree = vec![0usize; n];
let mut reverse_deps: Vec<Vec<usize>> = vec![vec![]; n];
for (i, deps) in scc_depends_on.iter().enumerate() {
for &dep in deps {
reverse_deps[dep].push(i);
}
in_degree[i] = deps.len();
}
let mut queue: VecDeque<usize> = VecDeque::new();
for (i, °) in in_degree.iter().enumerate() {
if deg == 0 {
queue.push_back(i);
}
}
let mut order = Vec::with_capacity(n);
while let Some(node) = queue.pop_front() {
order.push(node);
for &dependent in &reverse_deps[node] {
in_degree[dependent] -= 1;
if in_degree[dependent] == 0 {
queue.push_back(dependent);
}
}
}
Ok(StratificationResult {
sccs,
scc_order: order,
scc_map,
is_recursive,
scc_depends_on,
})
}
/// Computes the strongly connected components (SCCs) of a directed graph with
/// Tarjan's algorithm.
///
/// `nodes` lists the DFS roots, tried in the given order. `adj` maps a node
/// to its successors. A node that appears only as a successor is still
/// visited once it is reached from a root.
///
/// Successors are explored in sorted order rather than in `HashSet`
/// iteration order, whose order varies from run to run. For a given input,
/// the components and the position of each in the returned vector are
/// therefore the same on every run.
///
/// Each component is pushed as soon as it is completed. That happens only
/// after every component reachable from it has already been pushed.
///
/// The DFS is recursive, so the recursion depth can grow with the number of
/// nodes along a dependency chain.
fn tarjan(nodes: &[&str], adj: &HashMap<&str, HashSet<&str>>) -> Vec<HashSet<String>> {
    /// Bookkeeping shared across the recursive DFS.
    struct State<'a> {
        /// Next DFS discovery index to hand out.
        index_counter: usize,
        /// Visited nodes whose component has not been completed yet.
        stack: Vec<&'a str>,
        /// Mirror of `stack` for O(1) membership checks.
        on_stack: HashSet<&'a str>,
        /// Discovery index of every visited node.
        index: HashMap<&'a str, usize>,
        /// Smallest discovery index known to be reachable from the node while
        /// staying inside the not-yet-completed part of the DFS.
        lowlink: HashMap<&'a str, usize>,
        /// Completed components, in completion order.
        sccs: Vec<HashSet<String>>,
    }

    fn strongconnect<'a>(
        v: &'a str,
        adj: &HashMap<&str, HashSet<&'a str>>,
        state: &mut State<'a>,
    ) {
        let v_index = state.index_counter;
        state.index_counter += 1;
        state.index.insert(v, v_index);
        state.lowlink.insert(v, v_index);
        state.stack.push(v);
        state.on_stack.insert(v);

        // Fixed (sorted) successor order makes the traversal reproducible.
        let mut successors: Vec<&'a str> = adj
            .get(v)
            .map(|succ| succ.iter().copied().collect())
            .unwrap_or_default();
        successors.sort_unstable();

        for w in successors {
            if !state.index.contains_key(w) {
                // Tree edge: explore `w`, then inherit its lowlink.
                strongconnect(w, adj, state);
                let low = state.lowlink[v].min(state.lowlink[w]);
                state.lowlink.insert(v, low);
            } else if state.on_stack.contains(w) {
                // Edge into the component currently being built.
                let low = state.lowlink[v].min(state.index[w]);
                state.lowlink.insert(v, low);
            }
            // Otherwise `w` belongs to an already completed component.
        }

        // `v` is the root of a component: pop everything above it, and `v`
        // itself, off the stack.
        if state.lowlink[v] == v_index {
            let mut scc = HashSet::new();
            loop {
                let w = state
                    .stack
                    .pop()
                    .expect("component root must still be on the stack");
                state.on_stack.remove(w);
                scc.insert(w.to_string());
                if w == v {
                    break;
                }
            }
            state.sccs.push(scc);
        }
    }

    let mut state = State {
        index_counter: 0,
        stack: Vec::new(),
        on_stack: HashSet::new(),
        index: HashMap::new(),
        lowlink: HashMap::new(),
        sccs: Vec::new(),
    };
    for &node in nodes {
        if !state.index.contains_key(node) {
            strongconnect(node, adj, &mut state);
        }
    }
    state.sccs
}