use crate::{Diagnostic, Error};
use itertools::Itertools;
use wgsl_parse::syntax::{ModulePath, PathOrigin, TranslationUnit};
use std::{
borrow::Cow,
collections::HashMap,
fs,
path::{Path, PathBuf},
};
#[derive(Clone, Debug, thiserror::Error)]
pub enum ResolveError {
#[error("file not found: `{0}` ({1})")]
FileNotFound(PathBuf, String),
#[error("module not found: `{0}` ({1})")]
ModuleNotFound(ModulePath, String),
#[error("{0}")]
Error(#[from] Diagnostic<Error>),
}
type E = ResolveError;
/// Turns module paths into WESL source text and parsed translation units.
///
/// Only [`Resolver::resolve_source`] is required; the other methods have default
/// implementations built on top of it.
pub trait Resolver {
    /// Returns the source text of the module at `path`.
    ///
    /// The [`Cow`] lets implementors hand out borrowed text when they already own it.
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError>;

    /// Parses `source`, the text of module `path`, into a [`TranslationUnit`].
    ///
    /// A parse failure is wrapped in a [`Diagnostic`] annotated with `path`, this resolver's
    /// [`Resolver::display_name`] for it, and a copy of `source`.
    fn source_to_module(
        &self,
        source: &str,
        path: &ModulePath,
    ) -> Result<TranslationUnit, ResolveError> {
        let parsed = source.parse::<TranslationUnit>();
        let module = parsed.map_err(|err| {
            Diagnostic::from(err)
                .with_module_path(path.clone(), self.display_name(path))
                .with_source(source.to_string())
        })?;
        Ok(module)
    }

    /// Fetches the source for `path` and parses it with [`Resolver::source_to_module`].
    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        let text = self.resolve_source(path)?;
        self.source_to_module(&text, path)
    }

    /// Optional human-readable name for `path`, attached to diagnostics.
    ///
    /// Returns `None` unless overridden.
    fn display_name(&self, _path: &ModulePath) -> Option<String> {
        None
    }
}
/// A boxed resolver (including `Box<dyn Resolver>`) forwards every call to the boxed value.
///
/// Each method passes `self` (a `&Box<T>`) where a `&T` is expected; deref coercion unwraps
/// the box.
impl<T: Resolver + ?Sized> Resolver for Box<T> {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        T::resolve_source(self, path)
    }

    fn source_to_module(
        &self,
        source: &str,
        path: &ModulePath,
    ) -> Result<TranslationUnit, ResolveError> {
        T::source_to_module(self, source, path)
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        T::resolve_module(self, path)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        T::display_name(self, path)
    }
}
/// A shared reference to a resolver is itself a resolver; every call is forwarded unchanged.
///
/// `T` may be unsized, so `&dyn Resolver` implements [`Resolver`] as well. This is consistent
/// with the `Box<T>` implementation, which also accepts `?Sized`, and lets borrowed trait
/// objects be passed wherever an `impl Resolver` is expected.
impl<T: Resolver + ?Sized> Resolver for &T {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        (**self).resolve_source(path)
    }

    fn source_to_module(
        &self,
        source: &str,
        path: &ModulePath,
    ) -> Result<TranslationUnit, ResolveError> {
        (**self).source_to_module(source, path)
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        (**self).resolve_module(path)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        (**self).display_name(path)
    }
}
/// A resolver that rejects every module path, effectively disabling imports.
#[derive(Clone, Debug, Default)]
pub struct NoResolver;

impl Resolver for NoResolver {
    /// Always fails with [`ResolveError::ModuleNotFound`].
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        let reason = "no module resolver, imports are effectively disabled here";
        Err(E::ModuleNotFound(path.clone(), reason.to_string()))
    }
}
/// Resolves modules from files on disk, relative to a base directory.
///
/// A module path `a::b::c` maps to `<base>/a/b/c.<extension>`, falling back to
/// `<base>/a/b/c.wgsl` when that file does not exist.
pub struct FileResolver {
    base: PathBuf,
    extension: &'static str,
}

impl Default for FileResolver {
    /// Equivalent to `FileResolver::new("")`: an empty base directory (so paths are relative)
    /// and the `wesl` extension.
    ///
    /// A derived `Default` would leave `extension` empty. The resolver would then probe
    /// extension-less files and never find `.wesl` files.
    fn default() -> Self {
        Self {
            base: PathBuf::new(),
            extension: "wesl",
        }
    }
}
impl FileResolver {
    /// Creates a resolver rooted at `base` that looks for `.wesl` files first.
    pub fn new(base: impl AsRef<Path>) -> Self {
        let base = base.as_ref().to_path_buf();
        Self {
            base,
            extension: "wesl",
        }
    }

    /// Sets the extension tried first; `.wgsl` is always tried afterwards as a fallback.
    pub fn set_extension(&mut self, extension: &'static str) {
        self.extension = extension;
    }

    /// Maps a module path to an existing file under the base directory.
    ///
    /// # Errors
    /// - [`ResolveError::ModuleNotFound`] if `path` is a package path.
    /// - [`ResolveError::FileNotFound`] (carrying the `.wgsl` candidate) if neither the
    ///   configured extension nor `.wgsl` names an existing file.
    fn file_path(&self, path: &ModulePath) -> Result<PathBuf, ResolveError> {
        if path.origin.is_package() {
            return Err(E::ModuleNotFound(
                path.clone(),
                "this is an external package import, not a file import. Use `package::` or `super::` for file imports."
                    .to_string(),
            ));
        }

        let mut candidate = self.base.clone();
        candidate.extend(&path.components);

        // Probe the preferred extension first, then fall back to plain WGSL.
        for ext in [self.extension, "wgsl"] {
            candidate.set_extension(ext);
            if candidate.exists() {
                return Ok(candidate);
            }
        }

        Err(E::FileNotFound(candidate, "physical file".to_string()))
    }
}
impl Resolver for FileResolver {
    /// Reads the module's source file from disk.
    ///
    /// # Errors
    /// - [`ResolveError::ModuleNotFound`] for package paths.
    /// - [`ResolveError::FileNotFound`] if no matching file exists, or if it exists but cannot
    ///   be read as UTF-8 text. In the latter case the underlying I/O error is appended to the
    ///   message instead of being discarded.
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        let fs_path = self.file_path(path)?;
        // `file_path` only checked existence; the read itself can still fail (permissions,
        // a directory with a matching name, invalid UTF-8). Keep the cause so the error is
        // actionable.
        let source = fs::read_to_string(&fs_path)
            .map_err(|err| E::FileNotFound(fs_path, format!("physical file: {err}")))?;
        Ok(source.into())
    }

    /// The on-disk path the module resolves to, if such a file exists.
    fn display_name(&self, path: &ModulePath) -> Option<String> {
        self.file_path(path)
            .ok()
            .map(|fs_path| fs_path.display().to_string())
    }
}
/// An in-memory resolver that serves module sources from a map keyed by module path.
///
/// Keys are always stored with [`PathOrigin::Absolute`]; [`VirtualResolver::get_module`] looks
/// up the given path as-is.
#[derive(Default)]
pub struct VirtualResolver<'a> {
    files: HashMap<ModulePath, Cow<'a, str>>,
}

impl<'a> VirtualResolver<'a> {
    /// Creates an empty resolver.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `file` as the source of module `path`, replacing any previous entry.
    ///
    /// The stored key's origin is forced to [`PathOrigin::Absolute`].
    pub fn add_module(&mut self, path: impl Into<ModulePath>, file: Cow<'a, str>) {
        let mut key: ModulePath = path.into();
        key.origin = PathOrigin::Absolute;
        self.files.insert(key, file);
    }

    /// Returns the registered source for `path`.
    ///
    /// # Errors
    /// [`ResolveError::ModuleNotFound`] if nothing is registered under exactly `path`.
    pub fn get_module(&self, path: &ModulePath) -> Result<&str, ResolveError> {
        self.files
            .get(path)
            .map(|source| &**source)
            .ok_or_else(|| E::ModuleNotFound(path.clone(), "virtual module".to_string()))
    }

    /// Iterates over every registered `(path, source)` pair in unspecified order.
    pub fn modules(&self) -> impl Iterator<Item = (&ModulePath, &str)> {
        self.files.iter().map(|(key, source)| (key, &**source))
    }
}
impl Resolver for VirtualResolver<'_> {
    /// Lends out the registered source text without copying it.
    fn resolve_source<'b>(&'b self, path: &ModulePath) -> Result<Cow<'b, str>, ResolveError> {
        self.get_module(path).map(Cow::Borrowed)
    }
}
pub trait ResolveFn: Fn(&mut TranslationUnit) -> Result<(), Error> {}
impl<T: Fn(&mut TranslationUnit) -> Result<(), Error>> ResolveFn for T {}
/// Wraps a resolver and runs a [`ResolveFn`] on every module it parses.
pub struct Preprocessor<R: Resolver, F: ResolveFn> {
    /// Supplies source text and display names.
    pub resolver: R,
    /// Applied to each module right after parsing.
    pub preprocess: F,
}

impl<R: Resolver, F: ResolveFn> Preprocessor<R, F> {
    /// Pairs an inner resolver with a preprocessing pass.
    pub fn new(resolver: R, preprocess: F) -> Self {
        Preprocessor {
            resolver,
            preprocess,
        }
    }
}
impl<R: Resolver, F: ResolveFn> Resolver for Preprocessor<R, F> {
fn resolve_source<'b>(&'b self, path: &ModulePath) -> Result<Cow<'b, str>, ResolveError> {
let res = self.resolver.resolve_source(path)?;
Ok(res)
}
fn source_to_module(
&self,
source: &str,
path: &ModulePath,
) -> Result<TranslationUnit, ResolveError> {
let mut wesl: TranslationUnit = source.parse().map_err(|e| {
Diagnostic::from(e)
.with_module_path(path.clone(), self.display_name(path))
.with_source(source.to_string())
})?;
(self.preprocess)(&mut wesl).map_err(|e| {
Diagnostic::from(e)
.with_module_path(path.clone(), self.display_name(path))
.with_source(source.to_string())
})?;
Ok(wesl)
}
fn display_name(&self, path: &ModulePath) -> Option<String> {
self.resolver.display_name(path)
}
}
/// Dispatches module paths to resolvers mounted at path prefixes.
///
/// A path goes to the mounted resolver with the longest matching prefix. If no prefix
/// matches, it goes to the fallback resolver (mounted at the empty path). The chosen resolver
/// receives the rest of the path, re-rooted as an absolute path.
pub struct Router {
    mount_points: Vec<(ModulePath, Box<dyn Resolver>)>,
    fallback: Option<(ModulePath, Box<dyn Resolver>)>,
}

impl Router {
    /// Creates a router with no mount points and no fallback.
    pub fn new() -> Self {
        Router {
            mount_points: vec![],
            fallback: None,
        }
    }

    /// Mounts `resolver` at `path`.
    ///
    /// Mounting at the empty path sets (or replaces) the fallback resolver.
    pub fn mount_resolver(
        &mut self,
        path: impl Into<ModulePath>,
        resolver: impl Resolver + 'static,
    ) {
        let mount: ModulePath = path.into();
        let boxed: Box<dyn Resolver> = Box::new(resolver);
        if !mount.is_empty() {
            self.mount_points.push((mount, boxed));
        } else {
            self.fallback = Some((mount, boxed));
        }
    }

    /// Sets (or replaces) the resolver used when no mount point matches.
    pub fn mount_fallback_resolver(&mut self, resolver: impl Resolver + 'static) {
        self.mount_resolver("", resolver);
    }

    /// Picks the resolver responsible for `path`, plus `path` relative to its mount point.
    ///
    /// Among the mount points `path` starts with, the one with the most components wins; on a
    /// tie, `max_by_key` keeps the one mounted last. Otherwise the fallback is used.
    fn route(&self, path: &ModulePath) -> Result<(&dyn Resolver, ModulePath), ResolveError> {
        let best_mount = self
            .mount_points
            .iter()
            .filter(|(prefix, _)| path.starts_with(prefix))
            .max_by_key(|(prefix, _)| prefix.components.len());

        let (mount_path, resolver) = match best_mount.or(self.fallback.as_ref()) {
            Some(entry) => entry,
            None => return Err(E::ModuleNotFound(path.clone(), "no mount point".to_string())),
        };

        let remaining = path
            .components
            .iter()
            .skip(mount_path.components.len())
            .cloned()
            .collect_vec();
        let suffix = ModulePath::new(PathOrigin::Absolute, remaining);
        Ok((resolver, suffix))
    }
}

impl Default for Router {
    fn default() -> Self {
        Self::new()
    }
}
/// Every call is routed (see `Router::route`) and then forwarded to the selected resolver,
/// which receives the path relative to its mount point.
impl Resolver for Router {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        self.route(path)
            .and_then(|(resolver, suffix)| resolver.resolve_source(&suffix))
    }

    fn source_to_module(
        &self,
        source: &str,
        path: &ModulePath,
    ) -> Result<TranslationUnit, ResolveError> {
        self.route(path)
            .and_then(|(resolver, suffix)| resolver.source_to_module(source, &suffix))
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        self.route(path)
            .and_then(|(resolver, suffix)| resolver.resolve_module(&suffix))
    }

    /// `None` when routing fails or the selected resolver has no display name.
    fn display_name(&self, path: &ModulePath) -> Option<String> {
        self.route(path)
            .ok()
            .and_then(|(resolver, suffix)| resolver.display_name(&suffix))
    }
}
/// A statically known module tree, typically bundled in a package.
pub trait PkgModule: Send + Sync {
    /// The module's name, as matched against a path component.
    fn name(&self) -> &'static str;
    /// The module's source text.
    fn source(&self) -> &'static str;
    /// The direct child modules.
    fn submodules(&self) -> &[&dyn PkgModule];
    /// Finds a direct child by name; if several share the name, the first one wins.
    fn submodule(&self, name: &str) -> Option<&dyn PkgModule> {
        self.submodules()
            .iter()
            .copied()
            .find(|child| child.name() == name)
    }
}
/// Resolves package imports against registered [`PkgModule`] trees.
pub struct PkgResolver {
    packages: Vec<&'static dyn PkgModule>,
}

impl PkgResolver {
    /// Creates a resolver with no packages registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a package; earlier registrations take precedence on name clashes.
    pub fn add_package(&mut self, pkg: &'static dyn PkgModule) {
        self.packages.push(pkg);
    }
}

impl Default for PkgResolver {
    fn default() -> Self {
        PkgResolver {
            packages: Vec::new(),
        }
    }
}
impl Resolver for PkgResolver {
    /// Resolves a package path such as `pkg::a::b` by walking the registered module trees.
    ///
    /// The first component selects the first registered package with that name. Each later
    /// component descends one level into the submodules of the module reached so far.
    ///
    /// # Errors
    /// [`ResolveError::ModuleNotFound`] if `path` is not a package path, if no registered
    /// package has the requested name, or if a component names no submodule of its parent.
    fn resolve_source<'a>(
        &'a self,
        path: &ModulePath,
    ) -> Result<std::borrow::Cow<'a, str>, ResolveError> {
        // A non-package path can never match a package, so only search for package paths.
        let root = if path.origin.is_package() {
            let root_name = path.components.first().map(String::as_str);
            self.packages
                .iter()
                .copied()
                .find(|pkg| root_name == Some(pkg.name()))
        } else {
            None
        };

        let mut cur_mod = match root {
            Some(pkg) => pkg,
            None => {
                return Err(E::ModuleNotFound(
                    path.clone(),
                    "no package found".to_string(),
                ))
            }
        };

        for comp in path.components.iter().skip(1) {
            // Descend from the current module, not the package root. Otherwise `pkg::a::b`
            // would look up `b` at the root instead of inside `a`.
            if let Some(submod) = cur_mod.submodule(comp) {
                cur_mod = submod;
            } else {
                return Err(E::ModuleNotFound(
                    path.clone(),
                    format!(
                        "in module `{}`, no submodule named `{comp}`",
                        cur_mod.name()
                    ),
                ));
            }
        }

        Ok(cur_mod.source().into())
    }
}
/// Combines a [`PkgResolver`] for package imports with a [`FileResolver`] for everything else.
pub struct StandardResolver {
    pkg: PkgResolver,
    files: FileResolver,
}

impl StandardResolver {
    /// Creates a resolver that reads files relative to `base` and has no packages registered.
    pub fn new(base: impl AsRef<Path>) -> Self {
        let files = FileResolver::new(base);
        let pkg = PkgResolver::new();
        Self { pkg, files }
    }

    /// Registers a package so that imports naming it can be resolved.
    pub fn add_package(&mut self, pkg: &'static dyn PkgModule) {
        self.pkg.add_package(pkg);
    }
}
/// Package paths are delegated to the package resolver and all other paths to the file
/// resolver. `source_to_module` uses the trait default, so its diagnostics pick up the
/// delegated display name.
impl Resolver for StandardResolver {
    fn resolve_source<'a>(&'a self, path: &ModulePath) -> Result<Cow<'a, str>, ResolveError> {
        let inner: &dyn Resolver = if path.origin.is_package() {
            &self.pkg
        } else {
            &self.files
        };
        inner.resolve_source(path)
    }

    fn resolve_module(&self, path: &ModulePath) -> Result<TranslationUnit, ResolveError> {
        let inner: &dyn Resolver = if path.origin.is_package() {
            &self.pkg
        } else {
            &self.files
        };
        inner.resolve_module(path)
    }

    fn display_name(&self, path: &ModulePath) -> Option<String> {
        let inner: &dyn Resolver = if path.origin.is_package() {
            &self.pkg
        } else {
            &self.files
        };
        inner.display_name(path)
    }
}