use std::{
path::{Path, PathBuf},
rc::Rc,
};
use ahash::{HashMap, HashSet, HashSetExt as _};
use anyhow::{Context as _, anyhow, bail, ensure};
use clean_path::Clean as _;
use crate::FileSystem;
/// Ordered list of directories that relative import clause paths are looked up in.
///
/// Every directory is cleaned (lexically normalized) when it is added.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SearchPath {
    dirs: Vec<PathBuf>,
}

impl SearchPath {
    /// Builds a search path from the `RERUN_SHADER_PATH` environment variable.
    ///
    /// Falls back to an empty search path when the variable is unset, is not valid Unicode,
    /// or cannot be parsed.
    pub fn from_env() -> Self {
        const RERUN_SHADER_PATH: &str = "RERUN_SHADER_PATH";

        match std::env::var(RERUN_SHADER_PATH) {
            Ok(value) => value.parse::<Self>().unwrap_or_default(),
            Err(_) => Self::default(),
        }
    }

    /// Appends `dir` (cleaned) at the end, i.e. it is tried after every existing entry.
    pub fn push(&mut self, dir: impl AsRef<Path>) {
        let dir = dir.as_ref().clean();
        self.dirs.push(dir);
    }

    /// Inserts `dir` (cleaned) at position `index`.
    ///
    /// # Panics
    /// Panics if `index` is greater than the current number of directories.
    pub fn insert(&mut self, index: usize, dir: impl AsRef<Path>) {
        let dir = dir.as_ref().clean();
        self.dirs.insert(index, dir);
    }

    /// Iterates over the directories in the order they are searched.
    pub fn iter(&self) -> impl Iterator<Item = &Path> {
        self.dirs.iter().map(PathBuf::as_path)
    }
}
impl std::str::FromStr for SearchPath {
    type Err = anyhow::Error;

    /// Parses a `:`-separated list of directories.
    ///
    /// Empty entries (from `::` or a leading/trailing `:`) are skipped, and every entry is
    /// cleaned. Because `:` is the separator, Windows paths with a drive letter (`C:\…`)
    /// cannot be expressed in this format.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut dirs = Vec::new();
        for s in s.split(':').filter(|part| !part.is_empty()) {
            let dir: PathBuf = s
                .parse()
                .with_context(|| format!("couldn't parse {s:?} as PathBuf"))?;
            dirs.push(dir.clean());
        }
        Ok(Self { dirs })
    }
}
impl std::fmt::Display for SearchPath {
    /// Writes the directories joined by `:` (the format accepted by `FromStr`),
    /// converting non-UTF-8 paths lossily.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dirs = self.dirs.iter();
        if let Some(first) = dirs.next() {
            f.write_str(&first.to_string_lossy())?;
            for dir in dirs {
                f.write_str(":")?;
                f.write_str(&dir.to_string_lossy())?;
            }
        }
        Ok(())
    }
}
/// A single `#import <path>` directive.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ImportClause {
    /// The path exactly as written between the angle brackets; it is not resolved here.
    path: PathBuf,
}

impl ImportClause {
    /// Text an import line must start with (after surrounding whitespace is trimmed).
    /// Note the trailing space: `#import<x>` is not an import clause.
    pub const PREFIX: &str = "#import ";
}

impl<P> From<P> for ImportClause
where
    P: Into<PathBuf>,
{
    fn from(path: P) -> Self {
        let path: PathBuf = path.into();
        Self { path }
    }
}
impl std::str::FromStr for ImportClause {
type Err = anyhow::Error;
fn from_str(clause_str: &str) -> Result<Self, Self::Err> {
let s = clause_str.trim();
ensure!(
s.starts_with(Self::PREFIX),
"import clause must start with {prefix:?}, got {s:?}",
prefix = Self::PREFIX,
);
let s = s.trim_start_matches(Self::PREFIX).trim();
let rs = s.chars().rev().collect::<String>();
let splits = s
.find('<')
.and_then(|i0| rs.find('>').map(|i1| (i0 + 1, rs.len() - i1 - 1)));
if let Some((i0, i1)) = splits {
let s = &s[i0..i1];
ensure!(!s.is_empty(), "import clause must contain a non-empty path");
return s
.parse()
.with_context(|| format!("couldn't parse {s:?} as PathBuf"))
.map(|path| Self { path });
}
bail!("misformatted import clause: {clause_str:?}")
}
}
impl std::fmt::Display for ImportClause {
    /// Writes the canonical `#import <path>` form (non-UTF-8 paths are converted lossily).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.path.to_string_lossy();
        write!(f, "#import <{path}>")
    }
}
#[cfg(test)]
mod tests_import_clause {
use super::*;
// Every case is parsed, compared against the expected path, then formatted back via `Display`.
#[test]
fn parsing_success() {
// Each case: (input clause, expected parsed path, expected `Display` output).
// `None` means the clause must round-trip through `Display` unchanged.
let testcases: [(&str, PathBuf, Option<&str>); 16] = [
(
"#import <my_constants>",
"my_constants".parse().unwrap(),
None,
),
(
"#import <my_constants.wgsl>",
"my_constants.wgsl".parse().unwrap(),
None,
),
(
"#import <x/y/z/my_constants>",
"x/y/z/my_constants".parse().unwrap(),
None,
),
(
"#import <x/y/z/my_constants.wgsl>",
"x/y/z/my_constants.wgsl".parse().unwrap(),
None,
),
(
"#import </x/y/z/my_constants>",
"/x/y/z/my_constants".parse().unwrap(),
None,
),
(
"#import </x/y/z/my_constants.wgsl>",
"/x/y/z/my_constants.wgsl".parse().unwrap(),
None,
),
// Spaces inside the brackets are part of the path.
(
"#import </x/y/z/my constants>",
"/x/y/z/my constants".parse().unwrap(),
None,
),
(
"#import </x/y/z/my constants.wgsl>",
"/x/y/z/my constants.wgsl".parse().unwrap(),
None,
),
// The path spans from the first `<` to the last `>`, so brackets inside it are kept.
(
"#import </x/y/z/my><constants>",
"/x/y/z/my><constants".parse().unwrap(),
None,
),
(
"#import </x/y/z/my><constants.wgsl>",
"/x/y/z/my><constants.wgsl".parse().unwrap(),
None,
),
// Whitespace outside the outermost brackets is dropped, so `Display` differs from the input.
(
" #import \t\t\t </x/y/z/my>\" \"<constants> \t\t\t",
"/x/y/z/my>\" \"<constants".parse().unwrap(),
"#import </x/y/z/my>\" \"<constants>".into(),
),
(
" #import \t\t\t </x/y/z/my>\" \"<constants.wgsl> \t\t\t",
"/x/y/z/my>\" \"<constants.wgsl".parse().unwrap(),
"#import </x/y/z/my>\" \"<constants.wgsl>".into(),
),
// Nested brackets: only the outermost pair delimits the path.
("#import <<>>", "<>".parse().unwrap(), None),
// Two clauses on one line are not split: everything between the first `<` and the
// last `>` becomes a single path.
(
"#import <my_constants.wgsl> <my_other_constants.wgsl>",
"my_constants.wgsl> <my_other_constants.wgsl"
.parse()
.unwrap(),
None,
),
(
"#import <my_constants.wgsl> \t\t\t #import <my_other_constants.wgsl>",
"my_constants.wgsl> \t\t\t #import <my_other_constants.wgsl"
.parse()
.unwrap(),
None,
),
// Line breaks inside the brackets are kept verbatim.
(
"#import <my_multiline\r\npath.wgsl>",
"my_multiline\r\npath.wgsl".parse().unwrap(),
None,
),
];
let testcases = testcases
.into_iter()
.map(|(clause_str, path, clause_str_clean)| {
(clause_str, ImportClause::from(path), clause_str_clean)
});
for (clause_str, expected, expected_clause) in testcases {
eprintln!("test case: ({clause_str:?}, {expected:?})");
let clause = clause_str.parse::<ImportClause>().unwrap();
assert_eq!(expected, clause);
// Formatting back must give either the explicit expectation or the original input.
let clause_str_clean = clause.to_string();
if let Some(expected_clause) = expected_clause {
assert_eq!(expected_clause, clause_str_clean);
} else {
assert_eq!(clause_str, clause_str_clean);
}
}
}
// Each input must be rejected with an `Err`.
#[test]
fn parsing_failure() {
let testcases = [
"#import <", // no closing `>`
"#import <>", // empty path
"import my_constants", // missing the `#import ` prefix
"my_constants", // not an import clause at all
];
for s in testcases {
eprintln!("test case: {s:?}");
assert!(s.parse::<ImportClause>().is_err());
}
}
}
/// Resolver type returned by [`new_recommended`] when the `load_shaders_from_disk` cfg is set:
/// backed by `crate::OsFileSystem`.
#[cfg(load_shaders_from_disk)]
pub type RecommendedFileResolver = FileResolver<crate::OsFileSystem>;

/// Resolver type returned by [`new_recommended`] when the `load_shaders_from_disk` cfg is not
/// set: backed by a `&'static crate::MemFileSystem`.
#[cfg(not(load_shaders_from_disk))]
pub type RecommendedFileResolver = FileResolver<&'static crate::MemFileSystem>;

/// Creates the default resolver over `crate::get_filesystem()`.
///
/// Its search path holds the directories from `RERUN_SHADER_PATH`, followed by
/// `crates/viewer/re_renderer/shader`.
pub fn new_recommended() -> RecommendedFileResolver {
    let search_path = {
        let mut dirs = SearchPath::from_env();
        // Appended last, so directories from the environment are searched before it.
        dirs.push("crates/viewer/re_renderer/shader");
        dirs
    };
    FileResolver::with_search_path(crate::get_filesystem(), search_path)
}
#[derive(Clone, Debug, Default)]
pub struct InterpolatedFile {
pub contents: String,
pub imports: HashSet<PathBuf>,
}
#[derive(Default)]
pub struct FileResolver<Fs> {
fs: Fs,
search_path: SearchPath,
}
impl<Fs: FileSystem> FileResolver<Fs> {
pub fn new(fs: Fs) -> Self {
Self {
fs,
search_path: Default::default(),
}
}
pub fn with_search_path(fs: Fs, search_path: SearchPath) -> Self {
Self { fs, search_path }
}
}
impl<Fs: FileSystem> FileResolver<Fs> {
    /// Reads the file at `path` and recursively inlines every import clause it contains.
    ///
    /// A line is an import clause when, once trimmed, it starts with [`ImportClause::PREFIX`].
    /// Within a single call each file is inlined at most once: later imports of a file that
    /// has already been populated expand to nothing. Empty lines are dropped, and the
    /// remaining lines are joined with `\n`.
    ///
    /// [`InterpolatedFile::imports`] of the result contains the resolved path of every file
    /// imported directly or transitively by `path`.
    ///
    /// # Errors
    /// Fails if a file cannot be read, an import clause is malformed or cannot be resolved,
    /// or an import cycle is detected.
    pub fn populate(&self, path: impl AsRef<Path>) -> anyhow::Result<InterpolatedFile> {
        re_tracing::profile_function!();

        /// Recursive worker for `populate`.
        ///
        /// - `interp_files`: every file already populated during this call (include-once cache).
        /// - `path_stack`: current import chain, used only for the cycle error message.
        /// - `visited_stack`: set view of `path_stack`, used for cycle detection.
        fn populate_rec<Fs: FileSystem>(
            this: &FileResolver<Fs>,
            path: impl AsRef<Path>,
            interp_files: &mut HashMap<PathBuf, Rc<InterpolatedFile>>,
            path_stack: &mut Vec<PathBuf>,
            visited_stack: &mut HashSet<PathBuf>,
        ) -> anyhow::Result<Rc<InterpolatedFile>> {
            let path = path.as_ref().clean();

            path_stack.push(path.clone());
            ensure!(
                visited_stack.insert(path.clone()),
                "import cycle detected: {path_stack:?}"
            );

            // Include-once semantics: a file that was already inlined contributes nothing.
            if interp_files.contains_key(&path) {
                path_stack.pop().unwrap();
                visited_stack.remove(&path);
                return Ok(Default::default());
            }

            let file_contents = this.fs.read_to_string(&path)?;

            // Directory of the importing file; relative clause paths are tried against it first.
            let cwd = path.join("..").clean();

            let mut imports = HashSet::new();
            // Built in place: re-joining the accumulated text (and re-collecting the import
            // set) once per line would be quadratic in the size of the output.
            let mut contents = String::with_capacity(file_contents.len());

            for line in file_contents.lines() {
                if line.trim().starts_with(ImportClause::PREFIX) {
                    let clause = line.parse::<ImportClause>()?;
                    let clause_path =
                        this.resolve_clause_path(&cwd, &clause.path).ok_or_else(|| {
                            anyhow!("couldn't resolve import clause path at {:?}", clause.path)
                        })?;
                    imports.insert(clause_path.clone());

                    let child = populate_rec(
                        this,
                        clause_path,
                        interp_files,
                        path_stack,
                        visited_stack,
                    )?;
                    imports.extend(child.imports.iter().cloned());
                    push_line(&mut contents, &child.contents);
                } else {
                    push_line(&mut contents, line);
                }
            }

            let interp = Rc::new(InterpolatedFile { contents, imports });

            path_stack.pop().unwrap();
            visited_stack.remove(&path);
            interp_files.insert(path, Rc::clone(&interp));

            Ok(interp)
        }

        /// Appends `line` to `out`, separated from earlier content by `\n`.
        /// Empty lines are skipped.
        fn push_line(out: &mut String, line: &str) {
            if line.is_empty() {
                return;
            }
            if !out.is_empty() {
                out.push('\n');
            }
            out.push_str(line);
        }

        let mut path_stack = Vec::new();
        let mut visited_stack = HashSet::new();
        let mut interp_files = HashMap::default();
        populate_rec(
            self,
            path,
            &mut interp_files,
            &mut path_stack,
            &mut visited_stack,
        )
        .map(|interp| (*interp).clone())
    }

    /// Resolves an import clause `path` to a file that exists in `self.fs`.
    ///
    /// Candidates are tried in this order (all paths are cleaned):
    /// 1. `path` itself, if it is absolute;
    /// 2. `path` relative to `cwd` (the importing file's directory);
    /// 3. `path` relative to each search-path directory, in order.
    ///
    /// Returns `None` if no candidate exists.
    fn resolve_clause_path(
        &self,
        cwd: impl AsRef<Path>,
        path: impl AsRef<Path>,
    ) -> Option<PathBuf> {
        let path = path.as_ref().clean();

        if path.is_absolute() && self.fs.exists(&path) {
            return Some(path);
        }

        let relative_to_cwd = cwd.as_ref().join(&path).clean();
        if self.fs.exists(&relative_to_cwd) {
            return Some(relative_to_cwd);
        }

        self.search_path
            .iter()
            .map(|dir| dir.join(&path).clean())
            .find(|candidate| self.fs.exists(candidate))
    }
}
#[cfg(test)]
mod tests_file_resolver {
use crate::MemFileSystem;
use unindent::unindent;
use super::*;
// Import graph: shader2 -> {shader1, shader3}, shader3 -> shader1, shader1 -> shader4.
// Imports use absolute, importing-file-relative and search-path-relative clause paths.
// Every file must be inlined exactly once, at its first import, and every import must be
// reported.
#[test]
fn acyclic_interpolation() {
// NOTE(review): `MemFileSystem::get()` is presumably a shared instance, which would explain
// why each test writes under its own root directory — confirm.
let fs = MemFileSystem::get();
{
fs.create_dir_all("/shaders1/common").unwrap();
fs.create_dir_all("/shaders1/a/b/c/d").unwrap();
fs.create_file(
"/shaders1/common/shader1.wgsl",
unindent(
r#"
my first shader!
#import </shaders1/common/shader4.wgsl>
"#,
)
.into(),
)
.unwrap();
// Imports shader1 and shader3 repeatedly through different clause styles; only the
// first occurrence of each may appear in the output.
fs.create_file(
"/shaders1/a/b/shader2.wgsl",
unindent(
r#"
#import </shaders1/common/shader1.wgsl>
#import <../../common/shader1.wgsl>
#import </shaders1/a/b/c/d/shader3.wgsl>
#import <c/d/shader3.wgsl>
my second shader!
#import <common/shader1.wgsl>
#import <shader1.wgsl>
#import <shader3.wgsl>
#import <a/b/c/d/shader3.wgsl>
"#,
)
.into(),
)
.unwrap();
fs.create_file(
"/shaders1/a/b/c/d/shader3.wgsl",
unindent(
r#"
#import </shaders1/common/shader1.wgsl>
#import <../../../../common/shader1.wgsl>
my third shader!
#import <common/shader1.wgsl>
#import <shader1.wgsl>
"#,
)
.into(),
)
.unwrap();
fs.create_file(
"/shaders1/common/shader4.wgsl",
unindent(r#"my fourth shader!"#).into(),
)
.unwrap();
}
// These directories make the bare `common/…`, `shader1.wgsl` and `shader3.wgsl` clauses
// resolvable; they are tried in the order they are pushed.
let resolver = FileResolver::with_search_path(fs, {
let mut search_path = SearchPath::default();
search_path.push("/shaders1");
search_path.push("/shaders1/common");
search_path.push("/shaders1/a/b/c/d");
search_path
});
// Populate repeatedly: every call must produce the same result.
for _ in 0..3 {
let shader1_interp = resolver.populate("/shaders1/common/shader1.wgsl").unwrap();
// `imports` is a hash set, so sort it before comparing.
let mut imports = shader1_interp.imports.into_iter().collect::<Vec<_>>();
imports.sort();
let expected: Vec<PathBuf> = vec!["/shaders1/common/shader4.wgsl".into()];
assert_eq!(expected, imports);
let contents = shader1_interp.contents;
let expected = unindent(
r#"
my first shader!
my fourth shader!"#,
);
assert_eq!(expected, contents);
// Transitive imports (shader4 via shader1) are reported too.
let shader2_interp = resolver.populate("/shaders1/a/b/shader2.wgsl").unwrap();
let mut imports = shader2_interp.imports.into_iter().collect::<Vec<_>>();
imports.sort();
let expected: Vec<PathBuf> = vec![
"/shaders1/a/b/c/d/shader3.wgsl".into(),
"/shaders1/common/shader1.wgsl".into(),
"/shaders1/common/shader4.wgsl".into(),
];
assert_eq!(expected, imports);
let contents = shader2_interp.contents;
let expected = unindent(
r#"
my first shader!
my fourth shader!
my third shader!
my second shader!"#,
);
assert_eq!(expected, contents);
let shader3_interp = resolver.populate("/shaders1/a/b/c/d/shader3.wgsl").unwrap();
let mut imports = shader3_interp.imports.into_iter().collect::<Vec<_>>();
imports.sort();
let expected: Vec<PathBuf> = vec![
"/shaders1/common/shader1.wgsl".into(),
"/shaders1/common/shader4.wgsl".into(),
];
assert_eq!(expected, imports);
let contents = shader3_interp.contents;
let expected = unindent(
r#"
my first shader!
my fourth shader!
my third shader!"#,
);
assert_eq!(expected, contents);
}
}
// Direct cycle (shader1 -> shader2 -> shader1): `populate` must return an error, which the
// final `unwrap` turns into the expected panic.
#[test]
#[allow(clippy::should_panic_without_expect)] #[should_panic]
fn cyclic_direct() {
let fs = MemFileSystem::get();
{
fs.create_dir_all("/shaders2").unwrap();
fs.create_file(
"/shaders2/shader1.wgsl",
unindent(
r#"
#import </shaders2/shader2.wgsl>
my first shader!
"#,
)
.into(),
)
.unwrap();
fs.create_file(
"/shaders2/shader2.wgsl",
unindent(
r#"
#import </shaders2/shader1.wgsl>
my second shader!
"#,
)
.into(),
)
.unwrap();
}
let resolver = FileResolver::new(fs);
resolver
.populate("/shaders2/shader1.wgsl")
.map_err(re_error::format)
.unwrap();
}
// Indirect cycle (shader1 -> shader2 -> shader3 -> shader1): `populate` must return an
// error, which the final `unwrap` turns into the expected panic.
#[test]
#[allow(clippy::should_panic_without_expect)] #[should_panic]
fn cyclic_indirect() {
let fs = MemFileSystem::get();
{
fs.create_dir_all("/shaders3").unwrap();
fs.create_file(
"/shaders3/shader1.wgsl",
unindent(
r#"
#import </shaders3/shader2.wgsl>
my first shader!
"#,
)
.into(),
)
.unwrap();
fs.create_file(
"/shaders3/shader2.wgsl",
unindent(
r#"
#import </shaders3/shader3.wgsl>
my second shader!
"#,
)
.into(),
)
.unwrap();
fs.create_file(
"/shaders3/shader3.wgsl",
unindent(
r#"
#import </shaders3/shader1.wgsl>
my third shader!
"#,
)
.into(),
)
.unwrap();
}
let resolver = FileResolver::new(fs);
resolver
.populate("/shaders3/shader1.wgsl")
.map_err(re_error::format)
.unwrap();
}
}