use std::collections::{HashMap, HashSet, VecDeque};
use crate::constraint::{Constraint, Expression};
use crate::operation::AstEngine;
use crate::structure::{NodeId, NodeKind, Reference};
/// Collects the ids of every variable referenced inside `expr`.
///
/// Ids are returned in left-to-right traversal order and are not
/// deduplicated: a variable mentioned twice appears twice.
fn extract_var_refs(expr: &Expression) -> Vec<NodeId> {
    /// Appends the variable ids found in `node` (depth-first, left to right)
    /// to `found`.
    fn walk(node: &Expression, found: &mut Vec<NodeId>) {
        match node {
            Expression::Var(Reference::VariableRef(id)) => found.push(*id),
            // Literals and every other kind of reference contribute nothing.
            Expression::Lit(_) | Expression::Var(_) => {}
            Expression::BinOp { lhs, rhs, .. } => {
                walk(lhs, found);
                walk(rhs, found);
            }
            Expression::Pow { base, exp } => {
                walk(base, found);
                walk(exp, found);
            }
            Expression::FnCall { args, .. } => {
                for arg in args.iter() {
                    walk(arg, found);
                }
            }
        }
    }

    let mut found = Vec::new();
    walk(expr, &mut found);
    found
}
fn extract_constraint_refs(constraint: &Constraint) -> Vec<NodeId> {
match constraint {
Constraint::Range { lower, upper, .. } => {
let mut refs = extract_var_refs(lower);
refs.extend(extract_var_refs(upper));
refs
}
Constraint::SumBound { upper, .. } => extract_var_refs(upper),
Constraint::LengthRelation { length, .. } => extract_var_refs(length),
Constraint::Relation { lhs, rhs, .. } => {
let mut refs = extract_var_refs(lhs);
refs.extend(extract_var_refs(rhs));
refs
}
Constraint::StringLength { min, max, .. } => {
let mut refs = extract_var_refs(min);
refs.extend(extract_var_refs(max));
refs
}
Constraint::Guarantee {
predicate: Some(expr),
..
} => extract_var_refs(expr),
Constraint::TypeDecl { .. }
| Constraint::Distinct { .. }
| Constraint::Property { .. }
| Constraint::Sorted { .. }
| Constraint::CharSet { .. }
| Constraint::RenderHint { .. }
| Constraint::Guarantee {
predicate: None, ..
} => vec![],
}
}
/// Error returned by `DependencyGraph::topological_sort` when not every node
/// could be placed in the ordering.
#[derive(Debug, Clone)]
pub struct CycleError {
/// Ids of the nodes that were left out of the ordering.
pub involved: Vec<NodeId>,
}
impl std::fmt::Display for CycleError {
    /// Writes a one-line summary that reports how many nodes are involved.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let node_count = self.involved.len();
        f.write_fmt(format_args!(
            "dependency cycle detected involving {} nodes",
            node_count
        ))
    }
}
// Marker impl: lets `CycleError` be used wherever a `std::error::Error` is
// expected. All trait methods keep their default implementations.
impl std::error::Error for CycleError {}
/// Directed dependency graph over node ids, constructed by
/// `DependencyGraph::build`.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
// For each node, the ids it depends on. `topological_sort` places those ids
// before the node. A list may contain the same id more than once; the sort
// counts each distinct dependency only once.
deps: HashMap<NodeId, Vec<NodeId>>,
// Node ids in the order `build` encountered them while iterating
// `engine.structure`.
all_nodes: Vec<NodeId>,
}
impl DependencyGraph {
    /// Builds the dependency graph for every node yielded by
    /// `engine.structure.iter()`.
    ///
    /// "`a` depends on `b`" is recorded when:
    ///
    /// * `a` is an array or repeat whose length/count expression references
    ///   variable `b`;
    /// * `a` is a matrix whose `rows`/`cols`, or a choice whose `tag`, is a
    ///   variable reference to `b`;
    /// * `a` is listed as a child (or section header) of container `b`, so a
    ///   container is always ordered before its contents;
    /// * a constraint attached to `a` references variable `b`, provided
    ///   `b != a` and `b` is one of the nodes of the structure.
    ///
    /// Edges are stored as-is, so the same dependency may be listed several
    /// times. Unlike constraint references, structural references are not
    /// checked against the set of known nodes.
    #[must_use]
    pub fn build(engine: &AstEngine) -> Self {
        let mut deps: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        let mut all_nodes = Vec::new();

        // Pass 1: register every node so that each one owns a (possibly
        // empty) dependency list.
        for node in engine.structure.iter() {
            let id = node.id();
            all_nodes.push(id);
            deps.entry(id).or_default();
        }

        // Pass 2: structural edges derived from each node's kind.
        for node in engine.structure.iter() {
            let id = node.id();
            match node.kind() {
                NodeKind::Array { length, .. } => {
                    deps.entry(id).or_default().extend(extract_var_refs(length));
                }
                NodeKind::Matrix { rows, cols, .. } => {
                    if let Reference::VariableRef(ref_id) = rows {
                        deps.entry(id).or_default().push(*ref_id);
                    }
                    if let Reference::VariableRef(ref_id) = cols {
                        deps.entry(id).or_default().push(*ref_id);
                    }
                }
                NodeKind::Repeat { count, body, .. } => {
                    deps.entry(id).or_default().extend(extract_var_refs(count));
                    for &child in body {
                        deps.entry(child).or_default().push(id);
                    }
                }
                NodeKind::Sequence { children } => {
                    for &child in children {
                        deps.entry(child).or_default().push(id);
                    }
                }
                NodeKind::Section { header, body } => {
                    if let Some(h) = header {
                        deps.entry(*h).or_default().push(id);
                    }
                    for &child in body {
                        deps.entry(child).or_default().push(id);
                    }
                }
                NodeKind::Tuple { elements } => {
                    for &child in elements {
                        deps.entry(child).or_default().push(id);
                    }
                }
                NodeKind::Choice { variants, tag } => {
                    if let Reference::VariableRef(ref_id) = tag {
                        deps.entry(id).or_default().push(*ref_id);
                    }
                    for (_, children) in variants {
                        for &child in children {
                            deps.entry(child).or_default().push(id);
                        }
                    }
                }
                NodeKind::Scalar { .. } | NodeKind::Hole { .. } => {}
            }
        }

        // Pass 3: edges from constraints. Self-references and references to
        // ids that are not part of the structure are dropped.
        let all_nodes_set: HashSet<NodeId> = all_nodes.iter().copied().collect();
        for node in engine.structure.iter() {
            let id = node.id();
            let constraint_ids = engine.constraints.for_node(id);
            for cid in constraint_ids {
                if let Some(constraint) = engine.constraints.get(cid) {
                    let refs = extract_constraint_refs(constraint);
                    for ref_id in refs {
                        if ref_id != id && all_nodes_set.contains(&ref_id) {
                            deps.entry(id).or_default().push(ref_id);
                        }
                    }
                }
            }
        }

        Self { deps, all_nodes }
    }

    /// Returns the ids that `node` depends on, in insertion order and possibly
    /// with duplicates. Unknown nodes yield an empty slice.
    #[must_use]
    pub fn dependencies_of(&self, node: NodeId) -> &[NodeId] {
        self.deps.get(&node).map_or(&[], Vec::as_slice)
    }

    /// Returns every node id of the graph, in the order `build` collected them.
    #[must_use]
    pub fn all_nodes(&self) -> &[NodeId] {
        &self.all_nodes
    }

    /// Orders the nodes so that every node comes after all of its
    /// dependencies (Kahn's algorithm).
    ///
    /// The result contains exactly the ids of [`Self::all_nodes`], each once.
    /// Nodes that are ready at the same time are emitted in `all_nodes`
    /// order, so the output is reproducible across runs as long as every key
    /// of the dependency map is also in `all_nodes`. Keys outside that set
    /// still take part in the ordering (they can unblock declared nodes) but
    /// are visited in hash-map order and never appear in the result.
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when at least one declared node cannot be
    /// ordered. `involved` lists every such node: members of a cycle, nodes
    /// that transitively depend on one, and nodes that depend on an id which
    /// never becomes available.
    pub fn topological_sort(&self) -> Result<Vec<NodeId>, CycleError> {
        let declared: HashSet<NodeId> = self.all_nodes.iter().copied().collect();

        // Deterministic visit order: declared nodes first (each once, in
        // declaration order), then any other id that owns a dependency list.
        let mut visit_order: Vec<NodeId> =
            Vec::with_capacity(self.deps.len().max(declared.len()));
        let mut seen: HashSet<NodeId> = HashSet::with_capacity(visit_order.capacity());
        for &node in self.all_nodes.iter().chain(self.deps.keys()) {
            if seen.insert(node) {
                visit_order.push(node);
            }
        }

        // in_degree[n]  = number of distinct dependencies of n
        // dependents[d] = nodes waiting on d, in visit order
        let mut in_degree: HashMap<NodeId, usize> = HashMap::with_capacity(visit_order.len());
        let mut dependents: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        for &node in &visit_order {
            let mut unique: HashSet<NodeId> = HashSet::new();
            if let Some(dep_list) = self.deps.get(&node) {
                for &dep in dep_list {
                    // Count a repeated edge once, keeping first-seen order.
                    if unique.insert(dep) {
                        dependents.entry(dep).or_default().push(node);
                    }
                }
            }
            in_degree.insert(node, unique.len());
        }

        let mut queue: VecDeque<NodeId> = visit_order
            .iter()
            .copied()
            .filter(|node| in_degree.get(node) == Some(&0))
            .collect();

        let mut sorted = Vec::with_capacity(declared.len());
        while let Some(node) = queue.pop_front() {
            // Only declared nodes belong in the result; other ids merely
            // release whatever was waiting on them.
            if declared.contains(&node) {
                sorted.push(node);
            }
            if let Some(waiting) = dependents.get(&node) {
                for &dependent in waiting {
                    if let Some(deg) = in_degree.get_mut(&dependent) {
                        *deg = deg.saturating_sub(1);
                        if *deg == 0 {
                            queue.push_back(dependent);
                        }
                    }
                }
            }
        }

        // `sorted` holds distinct declared ids, so it is complete exactly
        // when its length matches the number of distinct declared nodes.
        if sorted.len() == declared.len() {
            Ok(sorted)
        } else {
            let emitted: HashSet<NodeId> = sorted.into_iter().collect();
            let involved = self
                .all_nodes
                .iter()
                .copied()
                .filter(|node| !emitted.contains(node))
                .collect();
            Err(CycleError { involved })
        }
    }
}