use super::EquationInfo;
use std::collections::HashMap;
/// Mutable bookkeeping for Tarjan's strongly-connected-components algorithm.
///
/// Nodes are `usize` indices in `0..n`, where `n` is the value passed to
/// [`TarjanState::new`]. The graph itself is supplied to each
/// [`TarjanState::strongconnect`] call as an adjacency list.
struct TarjanState {
    /// Next DFS discovery index to hand out.
    index: usize,
    /// Tarjan's node stack: visited nodes whose SCC has not been emitted yet.
    stack: Vec<usize>,
    /// Discovery index of each node, or `None` while the node is unvisited.
    indices: Vec<Option<usize>>,
    /// Per-node lowlink: the smallest discovery index found reachable from the
    /// node's DFS subtree through nodes still on `stack`. Only meaningful once
    /// the node has been visited.
    lowlinks: Vec<usize>,
    /// `on_stack[v]` is true exactly while `v` is in `stack`.
    on_stack: Vec<bool>,
    /// Completed SCCs in the order they were found. For Tarjan's algorithm this
    /// is reverse topological order with respect to the graph's edges.
    sccs: Vec<Vec<usize>>,
}

impl TarjanState {
    /// Creates empty state for a graph with `n` nodes, none of them visited.
    fn new(n: usize) -> Self {
        Self {
            index: 0,
            stack: Vec::new(),
            indices: vec![None; n],
            lowlinks: vec![0; n],
            on_stack: vec![false; n],
            sccs: Vec::new(),
        }
    }

    /// Runs Tarjan's DFS from the unvisited node `v`, appending every SCC
    /// completed during this traversal to `self.sccs`.
    ///
    /// `graph[u]` lists the successors of node `u`. Callers are expected to
    /// invoke this only for nodes whose `indices` entry is still `None`.
    ///
    /// The DFS keeps an explicit frame stack instead of recursing, so a long
    /// dependency chain cannot overflow the thread's call stack. Nodes, edges
    /// and SCCs are processed in exactly the same order as the textbook
    /// recursive formulation.
    ///
    /// # Panics
    /// Panics if `graph` has fewer entries than the state's node count, or if
    /// it mentions a node index outside `0..n`.
    fn strongconnect(&mut self, v: usize, graph: &[Vec<usize>]) {
        // Each frame is (node, position of the next out-edge to examine). It
        // stands in for one activation of the recursive algorithm.
        let mut frames: Vec<(usize, usize)> = vec![(v, 0)];
        self.visit(v);

        while let Some(frame) = frames.last_mut() {
            let node = frame.0;
            if let Some(&succ) = graph[node].get(frame.1) {
                frame.1 += 1;
                match (self.indices[succ], self.on_stack[succ]) {
                    // Tree edge: "recurse" into the unvisited successor.
                    (None, _) => {
                        self.visit(succ);
                        frames.push((succ, 0));
                    }
                    // Edge back into the current stack: tighten this node's lowlink.
                    (Some(succ_index), true) => {
                        self.lowlinks[node] = self.lowlinks[node].min(succ_index);
                    }
                    // Edge into an SCC that was already emitted: contributes nothing.
                    (Some(_), false) => {}
                }
            } else {
                // Every out-edge of `node` has been examined. This is the point
                // where the recursive version would return.
                frames.pop();
                if Some(self.lowlinks[node]) == self.indices[node] {
                    self.emit_scc(node);
                }
                // Propagate the finished child's lowlink to its DFS parent.
                if let Some(&(parent, _)) = frames.last() {
                    self.lowlinks[parent] = self.lowlinks[parent].min(self.lowlinks[node]);
                }
            }
        }
    }

    /// Assigns the next discovery index to `v` and pushes it on the Tarjan stack.
    fn visit(&mut self, v: usize) {
        self.indices[v] = Some(self.index);
        self.lowlinks[v] = self.index;
        self.index += 1;
        self.stack.push(v);
        self.on_stack[v] = true;
    }

    /// Pops `root` and every node above it off the Tarjan stack as one SCC.
    /// Nodes are recorded in pop order.
    fn emit_scc(&mut self, root: usize) {
        let mut scc = Vec::new();
        loop {
            let w = self
                .stack
                .pop()
                .expect("an SCC root is always still on the Tarjan stack");
            self.on_stack[w] = false;
            scc.push(w);
            if w == root {
                break;
            }
        }
        self.sccs.push(scc);
    }
}
/// Output of [`tarjan_scc`]: equation indices grouped into strongly connected
/// components and flattened into a single processing order.
pub(super) struct TarjanResult {
    /// Strongly connected components of the equation dependency graph. Each
    /// inner vector holds equation indices.
    pub sccs: Vec<Vec<usize>>,
    /// The same equation indices as `sccs`, concatenated in SCC order.
    pub ordered_indices: Vec<usize>,
}
/// Orders equations so that each one comes after the equations defining the
/// variables it uses. Mutually dependent equations are grouped into strongly
/// connected components (SCCs).
///
/// An equation's *defining variable* is its `matched_variable`, falling back to
/// its `lhs_variable`. Equation `i` depends on equation `j` (graph edge
/// `j -> i`) when `i`'s `all_variables` mentions the variable that `j` defines.
/// Occurrences of `i`'s own defining variable are ignored. If several equations
/// claim the same defining variable, the highest-indexed one is used as its
/// definition.
///
/// Returns the SCCs in topological order of that graph (definitions before
/// uses), plus `ordered_indices`, the concatenation of those SCCs.
pub(super) fn tarjan_scc(eq_infos: &[EquationInfo]) -> TarjanResult {
    let n = eq_infos.len();

    // Compute each equation's defining variable once; both passes below share it.
    let defining_vars: Vec<Option<&str>> = eq_infos.iter().map(defining_variable).collect();

    // Variable name -> defining equation. Borrowing the names avoids cloning every
    // String. Collecting inserts in index order, so later equations overwrite
    // earlier ones.
    let var_to_eq: HashMap<&str, usize> = defining_vars
        .iter()
        .enumerate()
        .filter_map(|(i, &var)| var.map(|name| (name, i)))
        .collect();

    // graph[j] lists the equations that use the variable defined by equation j.
    let mut graph: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (i, (info, &my_var)) in eq_infos.iter().zip(&defining_vars).enumerate() {
        let providers = info
            .all_variables
            .iter()
            .filter(|var| my_var != Some(var.as_str()))
            .filter_map(|var| var_to_eq.get(var.as_str()).copied())
            .filter(|&j| j != i);
        for j in providers {
            graph[j].push(i);
        }
    }

    let mut state = TarjanState::new(n);
    for v in 0..n {
        if state.indices[v].is_none() {
            state.strongconnect(v, &graph);
        }
    }

    // Tarjan emits SCCs in reverse topological order; flip to definitions-first.
    let mut sccs = state.sccs;
    sccs.reverse();
    let ordered_indices = sccs.concat();

    TarjanResult {
        ordered_indices,
        sccs,
    }
}

/// Returns the variable an equation is treated as defining: its matched
/// variable if present, otherwise its left-hand-side variable.
fn defining_variable(info: &EquationInfo) -> Option<&str> {
    info.matched_variable
        .as_deref()
        .or(info.lhs_variable.as_deref())
}