use super::error::ParseError;
use crate::ast::Id;
use crate::{IndexMap, ParseErrorKind, ParseResult};
use alloc::collections::BinaryHeap;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp::Reverse;
use core::mem;
/// Per-item bookkeeping used by [`toposort`].
#[derive(Clone, Default)]
struct State {
    /// Indices of the items that list this item as a dependency. There is one
    /// entry per dependency edge, so the same index may appear more than once.
    reverse_deps: Vec<usize>,
    /// Number of this item's dependency edges whose target has not been emitted
    /// yet. The item is ready to be emitted once this drops to zero.
    outbound_remaining: usize,
}
pub fn toposort<'a>(
kind: &str,
deps: &IndexMap<&'a str, Vec<Id<'a>>>,
) -> ParseResult<Vec<&'a str>> {
let mut states = vec![State::default(); deps.len()];
for (i, (_, edges)) in deps.iter().enumerate() {
states[i].outbound_remaining = edges.len();
for edge in edges {
let (j, _, _) = deps.get_full(edge.name).ok_or_else(|| {
ParseError::from(ParseErrorKind::ItemNotFound {
span: edge.span,
name: edge.name.to_string(),
kind: kind.to_string(),
hint: None,
})
})?;
states[j].reverse_deps.push(i);
}
}
let mut order = Vec::new();
let mut heap = BinaryHeap::new();
for (i, dep) in deps.keys().enumerate() {
if states[i].outbound_remaining == 0 {
heap.push((deps.len() - i, *dep, i));
}
}
while let Some((_order, node, i)) = heap.pop() {
order.push(node);
for i in mem::take(&mut states[i].reverse_deps) {
states[i].outbound_remaining -= 1;
if states[i].outbound_remaining == 0 {
let (dep, _) = deps.get_index(i).unwrap();
heap.push((deps.len() - i, *dep, i));
}
}
}
if order.len() == deps.len() {
return Ok(order);
}
for (i, state) in states.iter().enumerate() {
if state.outbound_remaining == 0 {
continue;
}
let (_, edges) = deps.get_index(i).unwrap();
for dep in edges {
let (j, _, _) = deps.get_full(dep.name).unwrap();
if states[j].outbound_remaining == 0 {
continue;
}
return Err(ParseErrorKind::TypeCycle {
span: dep.span,
name: dep.name.to_string(),
kind: kind.to_string(),
}
.into());
}
}
unreachable!()
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Builds a dependency reference with a dummy span.
    fn id(name: &str) -> Id<'_> {
        Id {
            name,
            span: Default::default(),
        }
    }
    #[test]
    fn smoke() {
        let empty: Vec<&str> = Vec::new();
        assert_eq!(toposort("", &IndexMap::default()).unwrap(), empty);
        let mut nonexistent = IndexMap::default();
        nonexistent.insert("a", vec![id("b")]);
        assert!(matches!(
            toposort("", &nonexistent).unwrap_err().kind(),
            ParseErrorKind::ItemNotFound { .. }
        ));
        let mut one = IndexMap::default();
        one.insert("a", vec![]);
        assert_eq!(toposort("", &one).unwrap(), ["a"]);
        let mut two = IndexMap::default();
        two.insert("a", vec![]);
        two.insert("b", vec![id("a")]);
        assert_eq!(toposort("", &two).unwrap(), ["a", "b"]);
        let mut two = IndexMap::default();
        two.insert("a", vec![id("b")]);
        two.insert("b", vec![]);
        assert_eq!(toposort("", &two).unwrap(), ["b", "a"]);
    }
    #[test]
    fn cycles() {
        let mut cycle = IndexMap::default();
        cycle.insert("a", vec![id("a")]);
        assert!(matches!(
            toposort("", &cycle).unwrap_err().kind(),
            ParseErrorKind::TypeCycle { .. }
        ));
        let mut cycle = IndexMap::default();
        cycle.insert("a", vec![id("b")]);
        cycle.insert("b", vec![id("c")]);
        cycle.insert("c", vec![id("a")]);
        assert!(matches!(
            toposort("", &cycle).unwrap_err().kind(),
            ParseErrorKind::TypeCycle { .. }
        ));
    }
    /// An acyclic item that depends on a cycle must still surface a cycle error.
    #[test]
    fn depends_on_cycle() {
        let mut cycle = IndexMap::default();
        cycle.insert("d", vec![id("a")]);
        cycle.insert("a", vec![id("b")]);
        cycle.insert("b", vec![id("a")]);
        assert!(matches!(
            toposort("", &cycle).unwrap_err().kind(),
            ParseErrorKind::TypeCycle { .. }
        ));
    }
    #[test]
    fn depend_twice() {
        let mut two = IndexMap::default();
        two.insert("b", vec![id("a"), id("a")]);
        two.insert("a", vec![]);
        assert_eq!(toposort("", &two).unwrap(), ["a", "b"]);
    }
    #[test]
    fn preserve_order() {
        let mut order = IndexMap::default();
        order.insert("a", vec![]);
        order.insert("b", vec![]);
        assert_eq!(toposort("", &order).unwrap(), ["a", "b"]);
        let mut order = IndexMap::default();
        order.insert("b", vec![]);
        order.insert("a", vec![]);
        assert_eq!(toposort("", &order).unwrap(), ["b", "a"]);
    }
    /// Items that become ready later are interleaved by declaration index, not
    /// appended in discovery order.
    #[test]
    fn preserve_order_after_unblocking() {
        // `a` must wait for `c`; everything else keeps declaration order.
        let mut order = IndexMap::default();
        order.insert("a", vec![id("c")]);
        order.insert("b", vec![]);
        order.insert("c", vec![]);
        assert_eq!(toposort("", &order).unwrap(), ["b", "c", "a"]);
        let mut diamond = IndexMap::default();
        diamond.insert("a", vec![]);
        diamond.insert("b", vec![id("a")]);
        diamond.insert("c", vec![id("a")]);
        diamond.insert("d", vec![id("b"), id("c")]);
        assert_eq!(toposort("", &diamond).unwrap(), ["a", "b", "c", "d"]);
    }
}