use crate::{Diagnostic, Error};
use itertools::Itertools;
use wgsl_parse::syntax::{ModulePath, PathOrigin, TranslationUnit};
use std::{
borrow::Cow,
collections::HashMap,
fs,
path::{Path, PathBuf},
};
/// Errors produced while locating, loading, or parsing a module.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ResolveError {
    /// A file could not be found (or read) at the given filesystem path.
    /// The string carries extra context, e.g. "physical file".
    #[error("file not found: `{0}` ({1})")]
    FileNotFound(PathBuf, String),
    /// The module path could not be mapped to any source; the string explains why.
    #[error("module not found: `{0}` ({1})")]
    ModuleNotFound(ModulePath, String),
    /// A diagnostic raised after the source was obtained, such as a parse
    /// failure or an error from a preprocessing pass.
    #[error("{0}")]
    Error(#[from] Diagnostic<Error>),
}
/// Short local alias used when constructing errors in this module.
type E = ResolveError;
pub trait Resolver {
fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError>;
fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
let source = self.resolve_source(path)?;
let wesl: TranslationUnit = source.parse().map_err(|e| {
Diagnostic::from(e)
.with_module_path(path.clone(), self.display_name(path))
.with_source(source.to_string())
})?;
Ok(wesl)
}
fn display_name(&self, _path: &ModulePath) -> Option<String> {
None
}
fn fs_path(&self, _path: &ModulePath) -> Option<PathBuf> {
None
}
}
/// Lets boxed resolvers, including `Box<dyn Resolver>`, act as resolvers.
///
/// Every method is forwarded explicitly to the inner value. As a result, an
/// inner resolver that overrides a default (such as `resolve_module`) keeps
/// its override when boxed.
impl<T: Resolver + ?Sized> Resolver for Box<T> {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        T::resolve_source(self, path)
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        T::resolve_module(self, path)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        T::display_name(self, path)
    }

    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        T::fs_path(self, path)
    }
}
/// Lets a shared reference to a resolver act as a resolver.
///
/// `T` may be unsized, matching the `Box<T>` implementation, so `&dyn Resolver`
/// is itself a `Resolver` and can be passed to generic code. Every method is
/// forwarded so that overrides on the referenced resolver are preserved.
impl<T: Resolver + ?Sized> Resolver for &T {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        (**self).resolve_source(path)
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        (**self).resolve_module(path)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        (**self).display_name(path)
    }

    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        (**self).fs_path(path)
    }
}
#[derive(Default, Clone, Debug)]
pub struct NoResolver;
impl Resolver for NoResolver {
fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
Err(E::ModuleNotFound(
path.clone(),
"no module resolver, imports are effectively disabled here".to_string(),
))
}
}
/// Resolves non-package modules to files under a base directory.
///
/// The module path's components are joined onto `base`. The file with the
/// preferred extension is tried first, then the same stem with `.wgsl`.
#[derive(Default)]
pub struct FileResolver {
    base: PathBuf,
    extension: &'static str,
}

impl FileResolver {
    /// Creates a resolver rooted at `base` that prefers `.wesl` files.
    pub fn new(base: impl AsRef<Path>) -> Self {
        let base = base.as_ref().to_owned();
        Self {
            base,
            extension: "wesl",
        }
    }

    /// Sets the preferred file extension (without the dot).
    /// `.wgsl` is still tried as a fallback.
    pub fn set_extension(&mut self, extension: &'static str) {
        self.extension = extension;
    }

    /// Computes the on-disk file for `path`, checking that it exists.
    ///
    /// # Errors
    /// - `ModuleNotFound` if `path` is a package import.
    /// - `FileNotFound` (naming the `.wgsl` candidate) if neither candidate exists.
    fn file_path(&self, path: &ModulePath) -> Result<PathBuf, ResolveError> {
        if path.origin.is_package() {
            return Err(E::ModuleNotFound(
                path.clone(),
                "this is an external package import, not a file import. Use `package::` or `super::` for file imports."
                    .to_string(),
            ));
        }

        let mut candidate = self.base.clone();
        for component in &path.components {
            candidate.push(component);
        }

        // The preferred extension wins; `.wgsl` is the fallback. After the loop,
        // `candidate` holds the `.wgsl` variant, which is what the error reports.
        for ext in [self.extension, "wgsl"] {
            candidate.set_extension(ext);
            if candidate.exists() {
                return Ok(candidate);
            }
        }
        Err(E::FileNotFound(candidate, "physical file".to_string()))
    }
}
impl Resolver for FileResolver {
    /// Reads the module's file from disk.
    ///
    /// # Errors
    /// - Errors from locating the file (package imports, missing files).
    /// - `FileNotFound` if reading fails. The message includes the underlying
    ///   I/O error, so failures such as permission problems or invalid UTF-8
    ///   are not misreported as a bare "not found".
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        let fs_path = self.file_path(path)?;
        // `file_path` has just checked that the file exists, so a failure here is
        // usually something other than absence; keep the I/O error's text.
        let source = fs::read_to_string(&fs_path)
            .map_err(|err| E::FileNotFound(fs_path, format!("physical file: {err}")))?;
        Ok(Cow::Owned(source))
    }

    /// The resolved file path, rendered for diagnostics; `None` if no file is found.
    fn display_name(&self, path: &ModulePath) -> Option<String> {
        self.file_path(path)
            .ok()
            .map(|fs_path| fs_path.display().to_string())
    }

    /// The resolved file path; `None` if the module is a package import or no file exists.
    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        self.file_path(path).ok()
    }
}
/// An in-memory resolver backed by a map from module path to source text.
#[derive(Default)]
pub struct VirtualResolver<'a> {
    files: HashMap<ModulePath, Cow<'a, str>>,
}

impl<'a> VirtualResolver<'a> {
    /// Creates an empty resolver.
    pub fn new() -> Self {
        Self {
            files: HashMap::default(),
        }
    }

    /// Registers `file` as the source of module `path`, replacing any previous entry.
    pub fn add_module(&mut self, path: ModulePath, file: Cow<'a, str>) {
        self.files.insert(path, file);
    }

    /// Registers a parsed translation unit, stored as its string rendering.
    pub fn add_translation_unit(&mut self, path: ModulePath, translation_unit: TranslationUnit) {
        let rendered = translation_unit.to_string();
        self.add_module(path, Cow::Owned(rendered));
    }

    /// Returns the source registered for `path`.
    ///
    /// # Errors
    /// `ModuleNotFound` if nothing is registered under `path`.
    pub fn get_module(&self, path: &ModulePath) -> Result<&str, ResolveError> {
        match self.files.get(path) {
            Some(source) => Ok(&**source),
            None => Err(E::ModuleNotFound(path.clone(), "virtual module".to_string())),
        }
    }

    /// Iterates over all registered `(path, source)` pairs, in unspecified order.
    pub fn modules(&self) -> impl Iterator<Item = (&ModulePath, &str)> {
        self.files.iter().map(|(path, source)| (path, &**source))
    }
}
impl Resolver for VirtualResolver<'_> {
    /// Borrows the registered source for `path` without copying it.
    fn resolve_source<'b>(&'b self, path: &ModulePath) -> Result<Cow<'b, str>, ResolveError> {
        self.get_module(path).map(Cow::Borrowed)
    }
}
pub trait ResolveFn: Fn(&mut TranslationUnit) -> Result<(), Error> {}
impl<T: Fn(&mut TranslationUnit) -> Result<(), Error>> ResolveFn for T {}
/// Wraps a resolver and applies `preprocess` to every module it parses.
pub struct Preprocessor<R: Resolver, F: ResolveFn> {
    /// The underlying resolver that supplies sources and parsed modules.
    pub resolver: R,
    /// Pass applied to each module returned by `resolve_module`.
    pub preprocess: F,
}

impl<R: Resolver, F: ResolveFn> Preprocessor<R, F> {
    /// Wraps `resolver`, running `preprocess` on each resolved module.
    pub fn new(resolver: R, preprocess: F) -> Self {
        Preprocessor {
            preprocess,
            resolver,
        }
    }
}
impl<R: Resolver, F: ResolveFn> Resolver for Preprocessor<R, F> {
    /// Returns the wrapped resolver's source unchanged.
    /// Preprocessing only applies to parsed modules.
    fn resolve_source<'b>(&'b self, path: &ModulePath) -> Result<Cow<'b, str>, ResolveError> {
        self.resolver.resolve_source(path)
    }

    /// Resolves the module through the wrapped resolver, then runs `preprocess` on it.
    ///
    /// # Errors
    /// Propagates resolution errors. A preprocessing error becomes a
    /// [`Diagnostic`] with the module path and display name. The original source
    /// is attached when it can still be fetched; failing to re-fetch it no
    /// longer panics.
    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        let mut wesl = self.resolver.resolve_module(path)?;
        (self.preprocess)(&mut wesl).map_err(|e| {
            let diagnostic =
                Diagnostic::from(e).with_module_path(path.clone(), self.display_name(path));
            // Best-effort: the source is re-fetched and may be gone by now (for
            // example, a file deleted between calls). The preprocessing error
            // itself must still be reported, so the source is just omitted.
            match self.resolve_source(path) {
                Ok(source) => diagnostic.with_source(source.into_owned()),
                Err(_) => diagnostic,
            }
        })?;
        Ok(wesl)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        self.resolver.display_name(path)
    }

    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        self.resolver.fs_path(path)
    }
}
/// Dispatches module paths to different resolvers based on path prefixes.
///
/// Among the mount points whose prefix the path starts with, the one with the
/// most components wins. Ties go to the most recently mounted resolver. If no
/// mount point matches, the fallback resolver (if set) is used. The matched
/// prefix's components are stripped, and the rest is handed on as an absolute
/// path.
pub struct Router {
    mount_points: Vec<(ModulePath, Box<dyn Resolver>)>,
    fallback: Option<(ModulePath, Box<dyn Resolver>)>,
}

impl Router {
    /// Creates a router with no mount points and no fallback.
    pub fn new() -> Self {
        Router {
            mount_points: vec![],
            fallback: None,
        }
    }

    /// Mounts `resolver` under `prefix`.
    pub fn mount_resolver(&mut self, prefix: ModulePath, resolver: impl Resolver + 'static) {
        let boxed: Box<dyn Resolver> = Box::new(resolver);
        self.mount_points.push((prefix, boxed));
    }

    /// Sets the resolver used when no mount point matches, replacing any previous one.
    pub fn mount_fallback_resolver(&mut self, resolver: impl Resolver + 'static) {
        let boxed: Box<dyn Resolver> = Box::new(resolver);
        self.fallback = Some((ModulePath::new_root(), boxed));
    }

    /// Picks the resolver for `path` and computes the path relative to its mount point.
    fn route(&self, path: &ModulePath) -> Result<(&dyn Resolver, ModulePath), ResolveError> {
        let longest_match = self
            .mount_points
            .iter()
            .filter(|(prefix, _)| path.starts_with(prefix))
            .max_by_key(|(prefix, _)| prefix.components.len());
        let Some((mount_path, resolver)) = longest_match.or(self.fallback.as_ref()) else {
            return Err(E::ModuleNotFound(path.clone(), "no mount point".to_string()));
        };
        let stripped = mount_path.components.len();
        let remainder = path.components.iter().skip(stripped).cloned().collect_vec();
        let suffix = ModulePath::new(PathOrigin::Absolute, remainder);
        Ok((resolver, suffix))
    }
}

impl Default for Router {
    fn default() -> Self {
        Router::new()
    }
}
/// Each method routes the path (see `Router::route`) and delegates to the
/// selected resolver with the mount-relative path.
impl Resolver for Router {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        self.route(path)
            .and_then(|(resolver, suffix)| resolver.resolve_source(&suffix))
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        self.route(path)
            .and_then(|(resolver, suffix)| resolver.resolve_module(&suffix))
    }

    /// `None` if no resolver is mounted for `path`.
    fn display_name(&self, path: &ModulePath) -> Option<String> {
        self.route(path)
            .ok()
            .and_then(|(resolver, suffix)| resolver.display_name(&suffix))
    }

    /// `None` if no resolver is mounted for `path`.
    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        self.route(path)
            .ok()
            .and_then(|(resolver, suffix)| resolver.fs_path(&suffix))
    }
}
/// A statically embedded package: its crate name, its root module tree, and the
/// packages it depends on. Looked up by `PkgResolver`.
#[derive(Debug, PartialEq, Eq)]
pub struct CodegenPkg {
    /// Crate name of the package.
    pub crate_name: &'static str,
    /// Root module. `PkgResolver` matches package names against `root.name`,
    /// not `crate_name`.
    pub root: &'static CodegenModule,
    /// Packages reachable through nested package paths (`pkg/dep`) from this one.
    pub dependencies: &'static [&'static CodegenPkg],
}
/// One module in an embedded package's module tree.
#[derive(Debug, PartialEq, Eq)]
pub struct CodegenModule {
    /// Module name, matched against module-path components.
    pub name: &'static str,
    /// Source text returned when this module is resolved.
    pub source: &'static str,
    /// Child modules, looked up by `name`.
    pub submodules: &'static [&'static CodegenModule],
}
/// Resolves package imports against a list of registered embedded packages.
#[derive(Default)]
pub struct PkgResolver {
    packages: Vec<&'static CodegenPkg>,
}

impl PkgResolver {
    /// Creates a resolver with no registered packages.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `pkg`. When several packages share a root name, the one
    /// registered first is used for lookups.
    pub fn add_package(&mut self, pkg: &'static CodegenPkg) {
        self.packages.push(pkg);
    }
}
impl Resolver for PkgResolver {
fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<std::borrow::Cow<'a, str>, E> {
let pkg_path = match &path.origin {
PathOrigin::Package(pkg) => pkg,
_ => {
return Err(E::ModuleNotFound(
path.clone(),
"resolver can only resolve package imports".to_string(),
));
}
};
let pkg_parts = pkg_path.split('/').collect_vec();
let root_pkg = pkg_parts
.first()
.and_then(|name| self.packages.iter().find(|p| p.root.name == *name))
.ok_or_else(|| {
E::ModuleNotFound(
path.clone(),
format!("dependency `{}` not found", pkg_parts.iter().format("/"),),
)
})?;
let pkg = pkg_parts.iter().skip(1).try_fold(root_pkg, |dep, name| {
dep.dependencies
.iter()
.find(|p| p.root.name == *name)
.ok_or_else(|| {
E::ModuleNotFound(
path.clone(),
format!(
"dependency `{}` not found in package path `{}`",
name,
pkg_parts.iter().format("/"),
),
)
})
})?;
let mut cur_mod = pkg.root;
for comp in &path.components {
if let Some(submod) = cur_mod.submodules.iter().find(|m| m.name == comp) {
cur_mod = submod;
} else {
return Err(E::ModuleNotFound(
path.clone(),
format!("in module `{}`, no submodule named `{comp}`", cur_mod.name),
));
}
}
Ok(cur_mod.source.into())
}
}
pub struct StandardResolver {
pkg: PkgResolver,
files: FileResolver,
constants: HashMap<String, f64>,
}
impl StandardResolver {
pub fn new(base: impl AsRef<Path>) -> Self {
Self {
pkg: PkgResolver::new(),
files: FileResolver::new(base),
constants: HashMap::new(),
}
}
pub fn add_package(&mut self, pkg: &'static CodegenPkg) {
self.pkg.add_package(pkg)
}
pub fn add_constant(&mut self, name: impl ToString, value: f64) {
self.constants.insert(name.to_string(), value);
}
fn generate_constant_module(&self) -> String {
self.constants
.iter()
.map(|(name, value)| format!("const {name} = {value};"))
.format("\n")
.to_string()
}
}
impl Resolver for StandardResolver {
    /// Resolves `path` as follows:
    /// - A package named `constants` (or ending in `/constants`) yields the
    ///   generated constants module.
    /// - Any other package import goes to the package resolver.
    /// - Everything else goes to the file resolver.
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        let is_constants = matches!(
            &path.origin,
            PathOrigin::Package(name) if name == "constants" || name.ends_with("/constants")
        );
        if is_constants {
            Ok(Cow::Owned(self.generate_constant_module()))
        } else if path.origin.is_package() {
            self.pkg.resolve_source(path)
        } else {
            self.files.resolve_source(path)
        }
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        let backend: &dyn Resolver = if path.origin.is_package() {
            &self.pkg
        } else {
            &self.files
        };
        backend.display_name(path)
    }

    fn fs_path(&self, path: &ModulePath) -> Option<PathBuf> {
        let backend: &dyn Resolver = if path.origin.is_package() {
            &self.pkg
        } else {
            &self.files
        };
        backend.fs_path(path)
    }
}
/// Prints `cargo::rerun-if-changed` directives for the files backing `modules`.
///
/// Intended for build scripts. Package modules are skipped. Modules for which
/// `resolver.fs_path` returns `None` (for example, missing files or virtual
/// modules) are skipped as well. For a resolved `.wgsl` file, the sibling
/// `.wesl` path is also watched.
///
/// # Panics
/// Panics if a non-package module in `modules` has a relative origin.
pub fn emit_rerun_if_changed(modules: &[ModulePath], resolver: &impl Resolver) {
    for module in modules {
        if module.origin.is_package() {
            continue;
        }
        assert!(
            !module.origin.is_relative(),
            "the modules passed to emit_rerun_if_changed must be absolute"
        );
        let Some(mut path) = resolver.fs_path(module) else {
            continue;
        };
        println!("cargo::rerun-if-changed={}", path.display());
        // `FileResolver` (with its default extension) tries `.wesl` before `.wgsl`,
        // so a `.wesl` file created later would take over; watch it too. A path
        // without an extension (possible with custom resolvers) is simply not
        // matched, instead of panicking.
        if path.extension().is_some_and(|ext| ext == "wgsl") {
            path.set_extension("wesl");
            println!("cargo::rerun-if-changed={}", path.display());
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a module-path literal; test inputs are known to be well-formed.
    fn mp(s: &str) -> ModulePath {
        s.parse().unwrap()
    }

    #[test]
    fn router_resolver() {
        let mut v1 = VirtualResolver::new();
        for (path, src) in [("package", "m1"), ("package::foo", "m2"), ("package::bar", "m3")] {
            v1.add_module(mp(path), src.into());
        }

        let mut v2 = VirtualResolver::new();
        for (path, src) in [("package", "m4"), ("package::baz", "m5")] {
            v2.add_module(mp(path), src.into());
        }

        let mut v3 = VirtualResolver::new();
        v3.add_module(mp("package::bar"), "m6".into());

        let mut r = Router::new();
        r.mount_resolver(mp("package"), v1);
        // The longer `package::bar` mount shadows v1's own `package::bar` ("m3").
        r.mount_resolver(mp("package::bar"), v2);
        r.mount_fallback_resolver(v3);

        let cases = [
            ("package", "m1"),
            ("package::foo", "m2"),
            ("package::bar", "m4"),
            ("package::bar::baz", "m5"),
            // Unmounted paths reach the fallback as absolute paths.
            ("foo::bar", "m6"),
        ];
        for (query, expected) in cases {
            assert_eq!(
                r.resolve_source(&mp(query)).unwrap(),
                expected,
                "resolving `{query}`"
            );
        }
    }
}