use std::collections::HashSet;
use wesl_macros::query;
use wgsl_parse::SyntaxNode;
use wgsl_parse::syntax::{
Expression, ExpressionNode, FunctionCall, GlobalDeclaration, Ident, ImportContent,
TranslationUnit, TypeExpression,
};
use wgsl_types::idents::{BUILTIN_CONSTRUCTOR_NAMES, BUILTIN_FUNCTION_NAMES};
use crate::idents::builtin_ident;
use crate::visit::Visit;
use crate::{Diagnostic, Error};
/// Errors produced by the validation passes defined in this file.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidateError {
/// An identifier is used but no declaration could be found for it.
/// The payload is the identifier as written.
#[error("cannot find declaration of `{0}`")]
UndefinedSymbol(String),
/// A call passes the wrong number of arguments.
/// Payload: callee name, expected argument count, actual argument count.
#[error("incorrect number of arguments to `{0}`, expected `{1}`, got `{2}`")]
ParamCount(String, usize, usize),
/// A call expression targets a name that is neither a function, a struct, a type
/// alias, nor a known builtin function or constructor.
#[error("`{0}` is not callable")]
NotCallable(String),
/// The same name is introduced more than once by imports and/or global declarations.
#[error("duplicate declaration of `{0}`")]
Duplicate(String),
/// A declaration refers back to itself, directly or through other declarations.
/// Payload: the cyclic declaration, then the declaration through which the cycle closes.
#[error("declaration of `{0}` is cyclic via `{1}`")]
Cycle(String, String),
}
/// Short file-local alias for [`ValidateError`], used when constructing errors below.
type E = ValidateError;
fn check_defined_symbols(wesl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
fn check_ty(ty: &TypeExpression) -> Result<(), Diagnostic<Error>> {
if ty.path.is_none()
&& ty.ident.use_count() == 1
&& builtin_ident(&ty.ident.name()).is_none()
&& *ty.ident.name() != "_"
{
Err(E::UndefinedSymbol(ty.ident.to_string()).into())
} else {
for arg in ty.template_args.iter().flatten() {
check_expr(&arg.expression)?;
}
Ok(())
}
}
fn check_expr(expr: &ExpressionNode) -> Result<(), Diagnostic<Error>> {
if let Expression::TypeOrIdentifier(ty) = expr.node() {
check_ty(ty).map_err(|d| d.with_span(expr.span()))
} else if let Expression::FunctionCall(call) = expr.node() {
check_ty(&call.ty).map_err(|d| d.with_span(expr.span()))?;
for expr in &call.arguments {
check_expr(expr)?;
}
Ok(())
} else {
for expr in Visit::<ExpressionNode>::visit(expr.node()) {
check_expr(expr)?;
}
Ok(())
}
}
fn check_decl(decl: &GlobalDeclaration) -> Result<(), Diagnostic<Error>> {
let decl_name = decl.ident().map(|ident| ident.name().to_string());
for expr in Visit::<ExpressionNode>::visit(decl) {
check_expr(expr).map_err(|mut d| {
d.detail.declaration = decl_name.clone();
d
})?;
}
for ty in query!(decl.{
GlobalDeclaration::Declaration.ty.[],
GlobalDeclaration::TypeAlias.ty,
GlobalDeclaration::Struct.members.[].ty,
GlobalDeclaration::Function.{ parameters.[].ty, return_type.[] }
}) {
check_ty(ty).map_err(|mut d| {
d.detail.declaration = decl_name.clone();
d
})?;
}
Ok(())
}
for decl in &wesl.global_declarations {
check_decl(decl)?;
}
Ok(())
}
fn check_function_calls(wesl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
fn check_call(call: &FunctionCall, ident: &Ident, wesl: &TranslationUnit) -> Result<(), E> {
let decl = wesl
.global_declarations
.iter()
.find(|decl| decl.ident().is_some_and(|id| &id == ident))
.map(|decl| decl.node());
match decl {
Some(GlobalDeclaration::Function(decl)) => {
if call.arguments.len() != decl.parameters.len() {
return Err(E::ParamCount(
ident.to_string(),
decl.parameters.len(),
call.arguments.len(),
));
}
}
Some(GlobalDeclaration::Struct(decl)) => {
if call.arguments.len() != decl.members.len() && !call.arguments.is_empty() {
return Err(E::ParamCount(
ident.to_string(),
decl.members.len(),
call.arguments.len(),
));
}
}
Some(GlobalDeclaration::TypeAlias(decl)) => {
if decl.ty.template_args.is_some() {
} else {
check_call(call, &decl.ty.ident, wesl)?;
}
}
Some(_) => return Err(E::NotCallable(ident.to_string())),
None => {
if BUILTIN_FUNCTION_NAMES
.iter()
.any(|name| name == &*ident.name())
{
} else if BUILTIN_CONSTRUCTOR_NAMES
.iter()
.any(|name| name == &*ident.name())
{
} else {
return Err(E::NotCallable(ident.to_string()));
}
}
};
Ok(())
}
fn check_expr(expr: &Expression, wesl: &TranslationUnit) -> Result<(), E> {
if let Expression::FunctionCall(call) = expr {
check_call(call, &call.ty.ident, wesl)?;
}
Ok(())
}
for decl in &wesl.global_declarations {
for expr in Visit::<ExpressionNode>::visit(decl.node()) {
check_expr(expr, wesl).map_err(|e| {
let mut err = Diagnostic::from(e);
err.detail.span = Some(expr.span());
err.detail.declaration = decl.ident().map(|id| id.name().to_string());
err
})?;
}
}
Ok(())
}
/// Ensures that no name is introduced twice by imports and global declarations.
///
/// Imported items are registered under their rename when one is given, otherwise under
/// their own identifier. Imports and declarations carrying an `@if`, `@elif` or `@else`
/// attribute are skipped entirely.
///
/// # Errors
///
/// Returns a `Duplicate` diagnostic for the first name seen a second time; the diagnostic
/// is tagged with that same name as its declaration.
fn check_duplicate_decl(wesl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
    /// Records `id` in `unique`, failing if its name is already present.
    fn check_ident(id: &Ident, unique: &mut HashSet<String>) -> Result<(), Diagnostic<Error>> {
        if unique.contains(id.name().as_str()) {
            return Err(
                Diagnostic::from(E::Duplicate(id.to_string())).with_declaration(id.to_string())
            );
        }
        unique.insert(id.to_string());
        Ok(())
    }

    /// Walks an import tree and records every leaf item.
    fn check_import_content(
        cont: &ImportContent,
        unique: &mut HashSet<String>,
    ) -> Result<(), Diagnostic<Error>> {
        match cont {
            // A renamed import introduces the new name, not the original one.
            ImportContent::Item(item) => {
                check_ident(item.rename.as_ref().unwrap_or(&item.ident), unique)
            }
            ImportContent::Collection(coll) => {
                for nested in coll {
                    check_import_content(&nested.content, unique)?;
                }
                Ok(())
            }
        }
    }

    let mut seen = HashSet::new();

    for import in &wesl.imports {
        let conditional = import
            .attributes()
            .iter()
            .any(|attr| attr.is_if() || attr.is_elif() || attr.is_else());
        if !conditional {
            check_import_content(&import.content, &mut seen)?;
        }
    }

    for decl in &wesl.global_declarations {
        let conditional = decl
            .attributes()
            .iter()
            .any(|attr| attr.is_if() || attr.is_elif() || attr.is_else());
        if !conditional {
            // Declarations without an identifier cannot clash with anything.
            if let Some(id) = decl.ident() {
                check_ident(&id, &mut seen)?;
            }
        }
    }
    Ok(())
}
/// Rejects global declarations that refer back to themselves, directly or transitively.
///
/// For each named global declaration, the type expressions reachable from it are walked.
/// Whenever one of them names another global declaration, that declaration is walked too.
/// The visited set is fresh for every starting declaration, and an identifier already in
/// it is not explored again.
///
/// # Errors
///
/// Returns a `Cycle` diagnostic naming the starting declaration and the declaration in
/// which the back-reference was found; it is tagged with the starting declaration's name.
fn check_cycles(wesl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
    /// Searches `decl` (and the declarations it references) for a use of `id`.
    fn check_decl(
        id: &Ident,
        decl: &GlobalDeclaration,
        visited: &mut HashSet<Ident>,
        wesl: &TranslationUnit,
    ) -> Result<(), E> {
        let mut outcome = Ok(());
        Visit::<TypeExpression>::visit_rec(decl, &mut |ty| {
            // The visitor cannot be interrupted, so later nodes are ignored once an
            // error has been recorded.
            if outcome.is_err() {
                return;
            }
            if ty.ident == *id {
                // `decl` always has an identifier here: it is either the starting
                // declaration or one that was found by its identifier.
                outcome = Err(E::Cycle(id.to_string(), decl.ident().unwrap().to_string()));
                return;
            }
            // Only explore identifiers that have not been seen on this walk.
            if !visited.insert(ty.ident.clone()) {
                return;
            }
            let referenced = wesl
                .global_declarations
                .iter()
                .find(|candidate| candidate.ident().as_ref() == Some(&ty.ident));
            if let Some(referenced) = referenced {
                outcome = check_decl(id, referenced, visited, wesl);
            }
        });
        outcome
    }

    for decl in &wesl.global_declarations {
        if let Some(id) = decl.ident() {
            let mut visited = HashSet::new();
            check_decl(&id, decl, &mut visited, wesl)
                .map_err(|e| Diagnostic::from(e).with_declaration(id.to_string()))?;
        }
    }
    Ok(())
}
/// Validates a WESL module.
///
/// Runs, in order, the undefined-symbol, duplicate-declaration and cycle checks, stopping
/// at the first failure. Function-call checking is not part of this entry point; see
/// [`validate_wgsl`].
///
/// # Errors
///
/// Returns the diagnostic of the first check that fails.
pub fn validate_wesl(wesl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
    check_defined_symbols(wesl)
        .and_then(|()| check_duplicate_decl(wesl))
        .and_then(|()| check_cycles(wesl))
}
/// Validates a WGSL module.
///
/// Runs, in order, the undefined-symbol, duplicate-declaration, cycle and function-call
/// checks, stopping at the first failure. The cycle check runs before the function-call
/// check.
///
/// # Errors
///
/// Returns the diagnostic of the first check that fails.
pub fn validate_wgsl(wgsl: &TranslationUnit) -> Result<(), Diagnostic<Error>> {
    check_defined_symbols(wgsl)
        .and_then(|()| check_duplicate_decl(wgsl))
        .and_then(|()| check_cycles(wgsl))
        .and_then(|()| check_function_calls(wgsl))
}