use crate::{Error, visit::Visit};
use wgsl_parse::syntax::*;
/// Lowers a WESL translation unit in place.
///
/// Steps, in order:
/// 1. The import list is cleared.
/// 2. Every custom attribute named `generic` is dropped from each attribute
///    list yielded by the `Attributes` visitor.
/// 3. With the `eval` feature enabled: `mark_functions_const` is called, the
///    unit is run through `exec` and then `lower` against a context built from
///    a clone of the unit, and finally the `Const` attribute is stripped from
///    every function declaration.
///    Without the `eval` feature: type aliases and global `const` declarations
///    are removed instead.
///
/// # Errors
///
/// With the `eval` feature enabled, a failure of the `exec` or `lower` pass is
/// wrapped in a `Diagnostic` (with the evaluation context attached) and
/// returned. No error path is visible here when the feature is disabled.
pub fn lower(wesl: &mut TranslationUnit) -> Result<(), Error> {
    wesl.imports.clear();

    for attributes in Visit::<Attributes>::visit_mut(wesl) {
        // Keep everything except custom attributes called `generic`.
        attributes.retain(|candidate| match candidate.node() {
            Attribute::Custom(CustomAttribute { name, .. }) => name != "generic",
            _ => true,
        });
    }

    #[cfg(feature = "eval")]
    {
        use crate::Diagnostic;
        use crate::eval::{Context, Exec, Lower, mark_functions_const};
        use wgsl_parse::SyntaxNode;

        mark_functions_const(wesl);

        {
            // The context is built from a clone rather than from `wesl` itself;
            // presumably so it can keep a reference while `wesl` is mutated.
            let snapshot = wesl.clone();
            let mut ctx = Context::new(&snapshot);

            wesl.exec(&mut ctx)
                .map_err(|err| Diagnostic::from(err).with_ctx(&ctx))?;
            wesl.lower(&mut ctx)
                .map_err(|err| Diagnostic::from(err).with_ctx(&ctx))?;
        }

        // Strip the `Const` attribute from every function declaration again.
        wesl.global_declarations.iter_mut().for_each(|global| {
            if let GlobalDeclaration::Function(func) = global.node_mut() {
                func.retain_attributes_mut(|attr| *attr != Attribute::Const);
            }
        });
    }

    #[cfg(not(feature = "eval"))]
    {
        remove_type_aliases(wesl);
        remove_global_consts(wesl);
    }

    Ok(())
}
/// Removes every type alias declaration from the translation unit.
///
/// Just before an alias declaration is dropped, its identifier is renamed to
/// the textual form of the aliased type.
/// NOTE(review): this presumably relies on identifiers being shared between a
/// declaration and its uses, so that the rename shows up at every use site —
/// confirm against the `Ident` implementation.
///
/// Aliases are handled in declaration order, in a single pass, and the
/// declarations that remain keep their relative order. Both properties matter:
/// removing entries by swapping the last declaration into the freed slot would
/// reorder the output and could process a later alias before the alias it
/// refers to, leaving it renamed to a name that is itself about to disappear.
///
/// NOTE(review): an alias that refers to another alias declared *after* it is
/// still rendered before that later alias is rewritten — verify whether such
/// input can reach this function.
// Only called from `lower` when the `eval` feature is disabled.
#[allow(unused)]
fn remove_type_aliases(wesl: &mut TranslationUnit) {
    wesl.global_declarations
        .retain_mut(|decl| match decl.node_mut() {
            GlobalDeclaration::TypeAlias(alias) => {
                // Render the target type first, then rename; the declaration
                // itself is dropped by returning `false`.
                let target = format!("{}", alias.ty);
                alias.ident.rename(target);
                false
            }
            _ => true,
        });
}
/// Removes every global `const` declaration that has an initializer.
///
/// Just before such a declaration is dropped, its identifier is renamed to the
/// parenthesized textual form of its initializer expression.
/// NOTE(review): this presumably relies on identifiers being shared between a
/// declaration and its uses, so that the rename shows up at every use site —
/// confirm against the `Ident` implementation.
///
/// Constants are handled in declaration order, in a single pass, and the
/// declarations that remain keep their relative order. Removing entries by
/// swapping the last declaration into the freed slot would reorder the output
/// and could render a later constant's initializer before the constants it
/// mentions have been rewritten.
///
/// A `const` declaration without an initializer is left untouched instead of
/// causing a panic, so that it stays visible to whatever consumes the output.
///
/// NOTE(review): a constant whose initializer mentions a constant declared
/// *after* it is still rendered before that later constant is rewritten —
/// verify whether such input can reach this function.
// Only called from `lower` when the `eval` feature is disabled.
#[allow(unused)]
fn remove_global_consts(wesl: &mut TranslationUnit) {
    wesl.global_declarations
        .retain_mut(|decl| match decl.node_mut() {
            GlobalDeclaration::Declaration(Declaration {
                kind: DeclarationKind::Const,
                ident,
                initializer: Some(init),
                ..
            }) => {
                // Render the initializer first, then rename; the declaration
                // itself is dropped by returning `false`.
                let inlined = format!("({})", init);
                ident.rename(inlined);
                false
            }
            _ => true,
        });
}