use disposition_ir_model::{
edge::EdgeGroups,
entity::{EntityType, EntityTypes},
node::{NodeHierarchy, NodeId, NodeRank, NodeRanks},
};
use disposition_model_common::{Id, Map};
/// Computes the `NodeRanks` of the nodes in a `NodeHierarchy` from the
/// dependency edges between those nodes.
///
/// This type carries no data; all functionality is exposed through
/// associated functions, with `calculate` as the entry point.
#[derive(Clone, Copy, Debug)]
pub struct NodeRanksCalculator;
/// Working state of the strongly connected component (SCC) search.
///
/// One instance is created per search and threaded through every traversal
/// started from a not-yet-visited node. All per-node vectors are indexed by
/// the dense node index.
struct TarjanState {
/// Discovery index to give to the next node that is visited for the first
/// time; incremented after each assignment.
index_counter: usize,
/// Nodes that have been visited but not yet assigned to a component, in
/// visiting order.
stack: Vec<usize>,
/// `on_stack[n]` is `true` while node `n` is in `stack`.
on_stack: Vec<bool>,
/// Discovery index of each node, or `None` if the node has not been visited.
index: Vec<Option<usize>>,
/// Smallest discovery index found so far that is reachable from each node
/// through nodes still on `stack`. Starts out as the node's own index.
lowlink: Vec<usize>,
/// Component ID assigned to each node once its component is complete.
scc_ids: Vec<usize>,
/// Component ID to give to the next completed component.
scc_counter: usize,
}
impl NodeRanksCalculator {
/// Computes the rank of every node in `node_hierarchy`.
///
/// Nodes at every nesting level of the hierarchy are ranked. Only edges
/// from edge groups that `entity_types` classifies as dependency edge
/// groups influence the ranks.
///
/// # Parameters
///
/// * `node_hierarchy`: Nested nodes to rank.
/// * `edge_groups`: Edge groups whose edges connect the nodes.
/// * `entity_types`: Entity types, looked up by edge group ID to decide
///   whether a group contributes dependency edges.
///
/// # Returns
///
/// * An empty `NodeRanks` when the hierarchy has no nodes.
/// * Rank `0` for every node when there are no dependency edges.
/// * Otherwise the ranks computed from the dependency graph.
pub fn calculate<'id>(
    node_hierarchy: &NodeHierarchy<'id>,
    edge_groups: &EdgeGroups<'id>,
    entity_types: &EntityTypes<'id>,
) -> NodeRanks<'id> {
    let mut node_ids: Vec<NodeId<'id>> = Vec::new();
    Self::node_ids_collect(node_hierarchy, &mut node_ids);

    if node_ids.is_empty() {
        return NodeRanks::new();
    }

    let edges = Self::dependency_edges_collect(edge_groups, entity_types);
    if edges.is_empty() {
        // Nothing constrains the ordering, so every node sits at the first rank.
        node_ids
            .into_iter()
            .map(|id| (id, NodeRank::new(0)))
            .collect()
    } else {
        Self::ranks_compute(&node_ids, &edges)
    }
}
/// Appends the ID of every node in `node_hierarchy` to `node_ids`.
///
/// The hierarchy is walked depth-first: each node's ID is pushed, followed
/// by the IDs of all of its descendants, before moving on to its next
/// sibling.
fn node_ids_collect<'id>(node_hierarchy: &NodeHierarchy<'id>, node_ids: &mut Vec<NodeId<'id>>) {
    node_hierarchy.iter().for_each(|(id, children)| {
        node_ids.push(id.clone());
        Self::node_ids_collect(children, node_ids);
    });
}
/// Gathers the `(from, to)` node pairs of all dependency edges.
///
/// An edge group contributes its edges only if `entity_types` has an entry
/// for the group's ID and at least one of those entity types is a
/// dependency edge group type. Groups without an entry are skipped.
///
/// Edges whose `from` and `to` are the same node are left out.
fn dependency_edges_collect<'id>(
    edge_groups: &EdgeGroups<'id>,
    entity_types: &EntityTypes<'id>,
) -> Vec<(NodeId<'id>, NodeId<'id>)> {
    let mut node_pairs: Vec<(NodeId<'id>, NodeId<'id>)> = Vec::new();

    for (group_id, group) in edge_groups.iter() {
        let group_id: &Id = group_id.as_ref();
        let group_is_dependency = entity_types
            .get(group_id)
            .is_some_and(|types| types.iter().any(Self::entity_type_is_dependency_edge_group));

        if group_is_dependency {
            node_pairs.extend(
                group
                    .iter()
                    // A node cannot depend on itself for ranking purposes.
                    .filter(|edge| edge.from != edge.to)
                    .map(|edge| (edge.from.clone(), edge.to.clone())),
            );
        }
    }

    node_pairs
}
/// Returns `true` if `entity_type` is one of the default dependency edge
/// group types -- cyclic, sequence, or symmetric -- and `false` for every
/// other entity type.
fn entity_type_is_dependency_edge_group(entity_type: &EntityType) -> bool {
matches!(
entity_type,
EntityType::DependencyEdgeCyclicDefault
| EntityType::DependencyEdgeSequenceDefault
| EntityType::DependencyEdgeSymmetricDefault
)
}
/// Computes the rank of every node from the dependency edges.
///
/// The computation has three stages:
///
/// 1. Nodes are mapped to dense indices and the edges are turned into an
///    adjacency list. Edges that mention a node not present in
///    `all_node_ids` are ignored.
/// 2. Nodes are grouped into strongly connected components, so that nodes
///    on a common cycle are treated as one unit.
/// 3. The components form an acyclic graph, which is layered: a component
///    with no incoming edges has rank `0`, and every other component ranks
///    one higher than its highest ranked predecessor.
///
/// Each node receives the rank of its component, so all nodes on a cycle
/// share one rank.
///
/// # Parameters
///
/// * `all_node_ids`: IDs of all nodes to rank.
/// * `dependency_edges`: `(from, to)` pairs; `to` is ranked after `from`
///   unless both are in the same component.
fn ranks_compute<'id>(
    all_node_ids: &[NodeId<'id>],
    dependency_edges: &[(NodeId<'id>, NodeId<'id>)],
) -> NodeRanks<'id> {
    // Dense indices let the graph passes below work on plain vectors.
    let mut node_to_index: Map<NodeId<'id>, usize> = Map::new();
    for (node_idx, node_id) in all_node_ids.iter().enumerate() {
        node_to_index.insert(node_id.clone(), node_idx);
    }

    let node_count = all_node_ids.len();
    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); node_count];
    for (from_id, to_id) in dependency_edges {
        // Edges with an endpoint outside `all_node_ids` cannot be ranked.
        if let (Some(&from_idx), Some(&to_idx)) =
            (node_to_index.get(from_id), node_to_index.get(to_id))
        {
            adjacency[from_idx].push(to_idx);
        }
    }

    let scc_ids = Self::tarjan_scc(&adjacency, node_count);
    let scc_count = scc_ids
        .iter()
        .copied()
        .max()
        .map_or(0, |scc_id_max| scc_id_max + 1);
    if scc_count == 0 {
        return all_node_ids
            .iter()
            .map(|node_id| (node_id.clone(), NodeRank::new(0)))
            .collect();
    }

    // Condense the node graph: keep only edges that cross between components.
    let mut scc_adjacency: Vec<Vec<usize>> = vec![Vec::new(); scc_count];
    for (neighbours, &from_scc) in adjacency.iter().zip(&scc_ids) {
        scc_adjacency[from_scc].extend(
            neighbours
                .iter()
                .map(|&to_idx| scc_ids[to_idx])
                .filter(|&to_scc| to_scc != from_scc),
        );
    }

    // Several node edges may collapse onto the same component edge; each
    // component edge must be counted once for the in-degrees to be right.
    for scc_neighbours in &mut scc_adjacency {
        scc_neighbours.sort_unstable();
        scc_neighbours.dedup();
    }

    let mut scc_in_degree: Vec<usize> = vec![0; scc_count];
    for &to_scc in scc_adjacency.iter().flatten() {
        scc_in_degree[to_scc] += 1;
    }

    let scc_ranks = Self::scc_dag_ranks_compute(&scc_adjacency, &scc_in_degree, scc_count);

    all_node_ids
        .iter()
        .zip(&scc_ids)
        .map(|(node_id, &scc_id)| (node_id.clone(), NodeRank::new(scc_ranks[scc_id])))
        .collect()
}
/// Assigns every node the ID of the strongly connected component it
/// belongs to.
///
/// # Parameters
///
/// * `adjacency`: `adjacency[n]` lists the nodes that node `n` has an edge
///   to.
/// * `node_count`: Number of nodes; nodes are `0..node_count`.
///
/// # Returns
///
/// A vector of length `node_count` whose entry `n` is the component ID of
/// node `n`. Component IDs are handed out consecutively starting at `0`.
///
/// # Panics
///
/// Panics if `adjacency` has fewer than `node_count` entries, or if it
/// lists a neighbour that is not below `node_count`.
fn tarjan_scc(adjacency: &[Vec<usize>], node_count: usize) -> Vec<usize> {
    let mut state = TarjanState {
        index_counter: 0,
        scc_counter: 0,
        stack: Vec::new(),
        on_stack: vec![false; node_count],
        index: vec![None; node_count],
        lowlink: vec![0; node_count],
        scc_ids: vec![0; node_count],
    };

    for start in 0..node_count {
        // Nodes reached by an earlier traversal already have their component.
        if state.index[start].is_some() {
            continue;
        }
        Self::tarjan_strongconnect_iterative(adjacency, start, &mut state);
    }

    let TarjanState { scc_ids, .. } = state;
    scc_ids
}
fn tarjan_strongconnect_iterative(
adjacency: &[Vec<usize>],
start: usize,
state: &mut TarjanState,
) {
let mut call_stack: Vec<(usize, usize)> = Vec::new();
state.index[start] = Some(state.index_counter);
state.lowlink[start] = state.index_counter;
state.index_counter += 1;
state.stack.push(start);
state.on_stack[start] = true;
call_stack.push((start, 0));
while let Some(&mut (v, ref mut ni)) = call_stack.last_mut() {
if *ni < adjacency[v].len() {
let w = adjacency[v][*ni];
*ni += 1;
if state.index[w].is_none() {
state.index[w] = Some(state.index_counter);
state.lowlink[w] = state.index_counter;
state.index_counter += 1;
state.stack.push(w);
state.on_stack[w] = true;
call_stack.push((w, 0));
} else if state.on_stack[w] {
let w_index = state.index[w].unwrap();
if w_index < state.lowlink[v] {
state.lowlink[v] = w_index;
}
}
} else {
if state.lowlink[v] == state.index[v].unwrap() {
let scc_id = state.scc_counter;
state.scc_counter += 1;
while let Some(w) = state.stack.pop() {
state.on_stack[w] = false;
state.scc_ids[w] = scc_id;
if w == v {
break;
}
}
}
call_stack.pop();
if let Some(&mut (caller, _)) = call_stack.last_mut()
&& state.lowlink[v] < state.lowlink[caller]
{
state.lowlink[caller] = state.lowlink[v];
}
}
}
}
/// Layers the acyclic graph of strongly connected components.
///
/// A component with no incoming edges has rank `0`. Every other component
/// ranks one higher than its highest ranked predecessor, i.e. its rank is
/// the length of the longest path leading to it.
///
/// # Parameters
///
/// * `scc_adjacency`: `scc_adjacency[c]` lists the components that
///   component `c` has an edge to.
/// * `scc_in_degree`: Number of incoming edges of each component,
///   consistent with `scc_adjacency`.
/// * `scc_count`: Number of components.
///
/// # Returns
///
/// A vector of length `scc_count` holding the rank of each component.
fn scc_dag_ranks_compute(
    scc_adjacency: &[Vec<usize>],
    scc_in_degree: &[usize],
    scc_count: usize,
) -> Vec<u32> {
    use std::collections::VecDeque;

    let mut scc_ranks: Vec<u32> = vec![0; scc_count];
    // Edges from not-yet-processed predecessors, per component.
    let mut pending_in_degree = scc_in_degree.to_vec();

    // Components whose predecessors have all been processed.
    let mut ready: VecDeque<usize> = VecDeque::new();
    for (scc_idx, &degree) in pending_in_degree.iter().enumerate().take(scc_count) {
        if degree == 0 {
            ready.push_back(scc_idx);
        }
    }

    while let Some(scc_idx) = ready.pop_front() {
        for &successor in &scc_adjacency[scc_idx] {
            let rank_via_current = scc_ranks[scc_idx] + 1;
            scc_ranks[successor] = scc_ranks[successor].max(rank_via_current);

            pending_in_degree[successor] -= 1;
            if pending_in_degree[successor] == 0 {
                ready.push_back(successor);
            }
        }
    }

    scc_ranks
}
}