use super::Error;
use crate::front::wgsl::parse::ast;
use crate::{FastHashMap, Handle, Span};
/// Index over the global declarations of a WGSL translation unit.
///
/// Built by `Index::generate`, which rejects duplicate global names and
/// orders the declarations by their dependencies on one another.
pub struct Index<'a> {
/// Handles of the global declarations in the order computed by the
/// dependency solver: a declaration is listed only after every global
/// declaration it depends on.
dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>,
}
impl<'a> Index<'a> {
    /// Builds the index for the translation unit `tu`.
    ///
    /// Every global declaration is first registered under its name; the
    /// declarations are then ordered so that each one follows the globals it
    /// depends on.
    ///
    /// # Errors
    ///
    /// Returns `Error::Redefinition` if two global declarations share a name,
    /// and forwards any error reported while solving the dependency order
    /// (recursive or cyclic declarations).
    pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> {
        let decl_count = tu.decls.len();

        // Map each global name to the handle of the declaration that owns it.
        let mut globals = FastHashMap::with_capacity_and_hasher(decl_count, Default::default());
        for (handle, decl) in tu.decls.iter() {
            let current = decl_ident(decl);
            if let Some(earlier) = globals.insert(current.name, handle) {
                // The name was already taken by an earlier declaration.
                return Err(Error::Redefinition {
                    previous: decl_ident(&tu.decls[earlier]).span,
                    current: current.span,
                });
            }
        }

        // The solver is consumed by `solve`, which hands back the ordering.
        let dependency_order = DependencySolver {
            globals: &globals,
            module: tu,
            visited: vec![false; decl_count],
            temp_visited: vec![false; decl_count],
            path: Vec::new(),
            out: Vec::with_capacity(decl_count),
        }
        .solve()?;

        Ok(Self { dependency_order })
    }

    /// Iterates over the global declaration handles in dependency order,
    /// i.e. each declaration is yielded after the globals it depends on.
    pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ {
        let ordered = self.dependency_order.iter();
        ordered.copied()
    }
}
/// One dependency edge followed by the solver, after its identifier has been
/// resolved to a global declaration.
struct ResolvedDependency<'a> {
/// The global declaration being depended upon.
decl: Handle<ast::GlobalDecl<'a>>,
/// Span copied from the dependency's `usage` field; used when reporting
/// cyclic declarations (presumably the location of the reference).
usage: Span,
}
/// Working state for the depth-first search that orders global declarations
/// by their dependencies.
struct DependencySolver<'source, 'temp> {
/// Maps each global's name to the handle of its declaration.
globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>,
/// The translation unit whose declarations are being ordered.
module: &'temp ast::TranslationUnit<'source>,
/// Indexed by declaration handle index: `true` once the declaration has
/// been appended to `out`.
visited: Vec<bool>,
/// Indexed by declaration handle index: `true` while the declaration's
/// dependencies are still being explored by the search.
temp_visited: Vec<bool>,
/// Dependency edges followed from the search root down to the declaration
/// currently being explored.
path: Vec<ResolvedDependency<'source>>,
/// The ordering built so far; a declaration is pushed only after all of its
/// dependencies have been handled.
out: Vec<Handle<ast::GlobalDecl<'source>>>,
}
impl<'a> DependencySolver<'a, '_> {
    /// Consumes the solver and returns the handles of all global
    /// declarations, ordered so that every declaration comes after the
    /// globals it depends on.
    ///
    /// # Errors
    ///
    /// Forwards the error from `dfs` when a declaration depends on itself,
    /// either directly or through other declarations.
    fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> {
        for (id, _) in self.module.decls.iter() {
            // Declarations already emitted as someone's dependency are done.
            if !self.visited[id.index()] {
                self.dfs(id)?;
            }
        }
        Ok(self.out)
    }

    /// Depth-first visit of `id`: emits all of its not-yet-emitted
    /// dependencies into `out`, then `id` itself.
    ///
    /// Dependencies whose identifier is not a known global name are skipped.
    ///
    /// # Errors
    ///
    /// * `Error::RecursiveDeclaration` if `id` depends directly on itself.
    /// * `Error::CyclicDeclaration` if following dependencies leads back to a
    ///   declaration that is still being explored. The reported `path` holds
    ///   only the edges that form the cycle, ending with the edge that closes
    ///   it.
    fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> {
        let decl = &self.module.decls[id];
        let id_usize = id.index();

        // Mark this declaration as being on the current search stack.
        self.temp_visited[id_usize] = true;
        for dep in decl.dependencies.iter() {
            if let Some(&dep_id) = self.globals.get(dep.ident) {
                self.path.push(ResolvedDependency {
                    decl: dep_id,
                    usage: dep.usage,
                });
                let dep_id_usize = dep_id.index();

                if self.temp_visited[dep_id_usize] {
                    // The dependency is still being explored: found a cycle.
                    return if dep_id == id {
                        // The declaration refers to itself directly.
                        Err(Error::RecursiveDeclaration {
                            ident: decl_ident(decl).span,
                            usage: dep.usage,
                        })
                    } else {
                        // The edge just pushed closes the cycle. The cycle
                        // begins right after the edge through which `dep_id`
                        // was entered; if there is no such edge, `dep_id` is
                        // the root of this search and the whole path is the
                        // cycle. The closing edge itself is excluded from the
                        // search because it always targets `dep_id`.
                        let closing_edge = self.path.len() - 1;
                        let start_at = self.path[..closing_edge]
                            .iter()
                            .rposition(|edge| edge.decl == dep_id)
                            .map_or(0, |entered_at| entered_at + 1);

                        Err(Error::CyclicDeclaration {
                            ident: decl_ident(&self.module.decls[dep_id]).span,
                            path: self.path[start_at..]
                                .iter()
                                .map(|edge| {
                                    let target = &self.module.decls[edge.decl];
                                    (decl_ident(target).span, edge.usage)
                                })
                                .collect(),
                        })
                    };
                } else if !self.visited[dep_id_usize] {
                    self.dfs(dep_id)?;
                }

                // Done with this edge; remove it from the current path.
                self.path.pop();
            }
        }

        // This declaration is no longer on the search stack.
        self.temp_visited[id_usize] = false;

        // Everything this declaration uses is already in `out`, so it can be
        // appended to the ordering and marked as finished.
        self.out.push(id);
        self.visited[id_usize] = true;

        Ok(())
    }
}
/// Returns the identifier (name and span) that a global declaration
/// introduces, regardless of which kind of declaration it is.
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
    match &decl.kind {
        ast::GlobalDeclKind::Fn(function) => function.name,
        ast::GlobalDeclKind::Var(variable) => variable.name,
        ast::GlobalDeclKind::Const(constant) => constant.name,
        ast::GlobalDeclKind::Struct(structure) => structure.name,
        ast::GlobalDeclKind::Type(alias) => alias.name,
    }
}