use std::{
convert::Into,
fs,
io::{self, Write},
path::{Path, PathBuf},
};
use display_full_error::DisplayFullError;
use flate2::write::GzEncoder;
use glob::glob;
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, quote};
use sha2::{Digest as _, Sha256};
use syn::{
Ident, LitBool, LitByteStr, LitStr, Token, bracketed,
parse::{Parse, ParseStream},
parse_macro_input,
};
mod error;
use error::{Error, GzipType, ZstdType};
/// Entry point of the `embed_assets!` macro: embeds every file under a
/// directory and expands to a generated `static_router` function.
#[proc_macro]
pub fn embed_assets(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as EmbedAssets);
    parsed.into_token_stream().into()
}
/// Entry point of the `embed_asset!` macro: embeds a single file and expands
/// to a `static_method_router` expression.
#[proc_macro]
pub fn embed_asset(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as EmbedAsset);
    parsed.into_token_stream().into()
}
/// Parsed arguments of the `embed_asset!` macro: one file plus optional flags.
struct EmbedAsset {
    // Validated file path literal (exists, is not a directory).
    asset_file: AssetFile,
    // Whether to also embed gzip/zstd-compressed variants (default: false).
    should_compress: ShouldCompress,
    // Whether the asset is flagged as cache-busted (default: false).
    cache_busted: IsCacheBusted,
    // Whether an unknown file extension falls back to octet-stream instead of
    // failing the build (default: false).
    allow_unknown_extensions: LitBool,
}
/// String literal naming an asset file; checked to exist at parse time.
struct AssetFile(LitStr);
impl Parse for EmbedAsset {
fn parse(input: ParseStream) -> syn::Result<Self> {
let asset_file: AssetFile = input.parse()?;
let mut maybe_should_compress = None;
let mut maybe_is_cache_busted = None;
let mut maybe_allow_unknown_extensions = None;
while !input.is_empty() {
input.parse::<Token![,]>()?;
let key: Ident = input.parse()?;
input.parse::<Token![=]>()?;
match key.to_string().as_str() {
"compress" => {
let value = input.parse()?;
maybe_should_compress = Some(value);
}
"cache_bust" => {
let value = input.parse()?;
maybe_is_cache_busted = Some(value);
}
"allow_unknown_extensions" => {
let value = input.parse()?;
maybe_allow_unknown_extensions = Some(value);
}
_ => {
return Err(syn::Error::new(
key.span(),
format!(
"Unknown key in `embed_asset!` macro. Expected `compress`, `cache_bust`, or `allow_unknown_extensions` but got {key}"
),
));
}
}
}
let should_compress = maybe_should_compress.unwrap_or_else(|| {
ShouldCompress(LitBool {
value: false,
span: Span::call_site(),
})
});
let cache_busted = maybe_is_cache_busted.unwrap_or_else(|| {
IsCacheBusted(LitBool {
value: false,
span: Span::call_site(),
})
});
let allow_unknown_extensions = maybe_allow_unknown_extensions.unwrap_or(LitBool {
value: false,
span: Span::call_site(),
});
Ok(Self {
asset_file,
should_compress,
cache_busted,
allow_unknown_extensions,
})
}
}
impl Parse for AssetFile {
fn parse(input: ParseStream) -> syn::Result<Self> {
let input_span = input.span();
let asset_file: LitStr = input.parse()?;
let literal = asset_file.value();
let path = Path::new(&literal);
let metadata = match fs::metadata(path) {
Ok(meta) => meta,
Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
return Err(syn::Error::new(
input_span,
format!("The specified asset file ({literal}) does not exist."),
));
}
Err(e) => {
return Err(syn::Error::new(
input_span,
format!("Error reading file {literal}: {}", DisplayFullError(&e)),
));
}
};
if metadata.is_dir() {
return Err(syn::Error::new(
input_span,
"The specified asset is a directory, not a file. Did you mean to call `embed_assets!` instead?",
));
}
Ok(AssetFile(asset_file))
}
}
impl ToTokens for EmbedAsset {
fn to_tokens(&self, tokens: &mut TokenStream) {
let AssetFile(asset_file) = &self.asset_file;
let ShouldCompress(should_compress) = &self.should_compress;
let IsCacheBusted(cache_busted) = &self.cache_busted;
let allow_unknown_extensions = &self.allow_unknown_extensions;
let result = generate_static_handler(
asset_file,
should_compress,
cache_busted,
allow_unknown_extensions,
);
match result {
Ok(value) => {
tokens.extend(quote! {
#value
});
}
Err(err_message) => {
let error = syn::Error::new(Span::call_site(), err_message);
tokens.extend(error.to_compile_error());
}
}
}
}
/// Parsed arguments of the `embed_assets!` macro: a directory plus options.
struct EmbedAssets {
    // Validated, existing assets directory literal.
    assets_dir: AssetsDir,
    // Ignore paths already checked to exist under `assets_dir`.
    validated_ignore_paths: IgnorePaths,
    // Whether to embed gzip/zstd variants alongside raw bytes (default: false).
    should_compress: ShouldCompress,
    // Whether `.html`/`.htm` extensions are stripped from routes (default: false).
    should_strip_html_ext: ShouldStripHtmlExt,
    // Validated cache-busting paths, split into directories and files.
    cache_busted_paths: CacheBustedPaths,
    // Whether unknown extensions fall back to octet-stream (default: false).
    allow_unknown_extensions: LitBool,
}
impl Parse for EmbedAssets {
fn parse(input: ParseStream) -> syn::Result<Self> {
let assets_dir: AssetsDir = input.parse()?;
let mut maybe_should_compress = None;
let mut maybe_ignore_paths = None;
let mut maybe_should_strip_html_ext = None;
let mut maybe_cache_busted_paths = None;
let mut maybe_allow_unknown_extensions = None;
while !input.is_empty() {
input.parse::<Token![,]>()?;
let key: Ident = input.parse()?;
input.parse::<Token![=]>()?;
match key.to_string().as_str() {
"compress" => {
let value = input.parse()?;
maybe_should_compress = Some(value);
}
"ignore_paths" => {
let value = input.parse()?;
maybe_ignore_paths = Some(value);
}
"strip_html_ext" => {
let value = input.parse()?;
maybe_should_strip_html_ext = Some(value);
}
"cache_busted_paths" => {
let value = input.parse()?;
maybe_cache_busted_paths = Some(value);
}
"allow_unknown_extensions" => {
let value = input.parse()?;
maybe_allow_unknown_extensions = Some(value);
}
_ => {
return Err(syn::Error::new(
key.span(),
"Unknown key in embed_assets! macro. Expected `compress`, `ignore_paths`, `strip_html_ext`, `cache_busted_paths`, or `allow_unknown_extensions`",
));
}
}
}
let should_compress = maybe_should_compress.unwrap_or_else(|| {
ShouldCompress(LitBool {
value: false,
span: Span::call_site(),
})
});
let should_strip_html_ext = maybe_should_strip_html_ext.unwrap_or_else(|| {
ShouldStripHtmlExt(LitBool {
value: false,
span: Span::call_site(),
})
});
let ignore_paths_with_span = maybe_ignore_paths.unwrap_or(IgnorePathsWithSpan(vec![]));
let validated_ignore_paths = validate_ignore_paths(ignore_paths_with_span, &assets_dir.0)?;
let maybe_cache_busted_paths =
maybe_cache_busted_paths.unwrap_or(CacheBustedPathsWithSpan(vec![]));
let cache_busted_paths =
validate_cache_busted_paths(maybe_cache_busted_paths, &assets_dir.0)?;
let allow_unknown_extensions = maybe_allow_unknown_extensions.unwrap_or(LitBool {
value: false,
span: Span::call_site(),
});
Ok(Self {
assets_dir,
validated_ignore_paths,
should_compress,
should_strip_html_ext,
cache_busted_paths,
allow_unknown_extensions,
})
}
}
impl ToTokens for EmbedAssets {
fn to_tokens(&self, tokens: &mut TokenStream) {
let AssetsDir(assets_dir) = &self.assets_dir;
let ignore_paths = &self.validated_ignore_paths;
let ShouldCompress(should_compress) = &self.should_compress;
let ShouldStripHtmlExt(should_strip_html_ext) = &self.should_strip_html_ext;
let cache_busted_paths = &self.cache_busted_paths;
let allow_unknown_extensions = &self.allow_unknown_extensions;
let result = generate_static_routes(
assets_dir,
ignore_paths,
should_compress,
should_strip_html_ext,
cache_busted_paths,
allow_unknown_extensions.value,
);
match result {
Ok(value) => {
tokens.extend(quote! {
#value
});
}
Err(err_message) => {
let error = syn::Error::new(Span::call_site(), err_message);
tokens.extend(error.to_compile_error());
}
}
}
}
struct AssetsDir(LitStr);
impl Parse for AssetsDir {
fn parse(input: ParseStream) -> syn::Result<Self> {
let input_span = input.span();
let assets_dir: LitStr = input.parse()?;
let literal = assets_dir.value();
let path = Path::new(&literal);
let metadata = match fs::metadata(path) {
Ok(meta) => meta,
Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
return Err(syn::Error::new(
input_span,
"The specified assets directory does not exist",
));
}
Err(e) => {
return Err(syn::Error::new(
input_span,
format!(
"Error reading directory {literal}: {}",
DisplayFullError(&e)
),
));
}
};
if !metadata.is_dir() {
return Err(syn::Error::new(
input_span,
"The specified assets directory is not a directory",
));
}
Ok(AssetsDir(assets_dir))
}
}
/// Ignore paths after validation against the assets directory.
struct IgnorePaths(Vec<PathBuf>);
/// Ignore paths as written in the macro, each with its span for error reporting.
struct IgnorePathsWithSpan(Vec<(PathBuf, Span)>);
impl Parse for IgnorePathsWithSpan {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        parse_dirs(input).map(IgnorePathsWithSpan)
    }
}
/// Resolves each ignore path relative to `assets_dir` and confirms it exists,
/// returning the joined paths or a spanned error for the first bad entry.
fn validate_ignore_paths(
    ignore_paths: IgnorePathsWithSpan,
    assets_dir: &LitStr,
) -> syn::Result<IgnorePaths> {
    let base = PathBuf::from(assets_dir.value());
    let mut validated = Vec::with_capacity(ignore_paths.0.len());
    for (dir, span) in ignore_paths.0 {
        let full_path = base.join(&dir);
        match fs::metadata(&full_path) {
            Ok(_) => validated.push(full_path),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(syn::Error::new(
                    span,
                    "The specified ignored path does not exist",
                ));
            }
            Err(e) => {
                return Err(syn::Error::new(
                    span,
                    format!(
                        "Error reading ignored path {}: {}",
                        dir.to_string_lossy(),
                        DisplayFullError(&e)
                    ),
                ));
            }
        }
    }
    Ok(IgnorePaths(validated))
}
/// Newtype for the `compress` boolean flag.
struct ShouldCompress(LitBool);
impl Parse for ShouldCompress {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        input.parse().map(ShouldCompress)
    }
}
/// Newtype for the `strip_html_ext` boolean flag.
struct ShouldStripHtmlExt(LitBool);
impl Parse for ShouldStripHtmlExt {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        input.parse().map(ShouldStripHtmlExt)
    }
}
/// Newtype for the `cache_bust` boolean flag.
struct IsCacheBusted(LitBool);
impl Parse for IsCacheBusted {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        input.parse().map(IsCacheBusted)
    }
}
/// Validated cache-busting paths, split into directories and single files.
struct CacheBustedPaths {
    dirs: Vec<PathBuf>,
    files: Vec<PathBuf>,
}
/// Cache-busting paths as written in the macro, with spans for error reporting.
struct CacheBustedPathsWithSpan(Vec<(PathBuf, Span)>);
impl Parse for CacheBustedPathsWithSpan {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        parse_dirs(input).map(CacheBustedPathsWithSpan)
    }
}
/// Resolves each cache-busted path relative to `assets_dir`, confirms it
/// exists, and sorts it into the directory or file bucket.
fn validate_cache_busted_paths(
    tuples: CacheBustedPathsWithSpan,
    assets_dir: &LitStr,
) -> syn::Result<CacheBustedPaths> {
    let base = PathBuf::from(assets_dir.value());
    let mut dirs = Vec::new();
    let mut files = Vec::new();
    for (dir, span) in tuples.0 {
        let full_path = base.join(&dir);
        let meta = match fs::metadata(&full_path) {
            Ok(meta) => meta,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(syn::Error::new(
                    span,
                    "The specified directory for cache busting does not exist",
                ));
            }
            Err(e) => {
                return Err(syn::Error::new(
                    span,
                    format!(
                        "Error reading path {}: {}",
                        dir.to_string_lossy(),
                        DisplayFullError(&e)
                    ),
                ));
            }
        };
        if meta.is_dir() {
            dirs.push(full_path);
        } else {
            files.push(full_path);
        }
    }
    Ok(CacheBustedPaths { dirs, files })
}
/// Parses a bracketed, comma-separated list of string literals into paths,
/// keeping each literal's span for later error reporting. A trailing comma
/// is accepted.
fn parse_dirs(input: ParseStream) -> syn::Result<Vec<(PathBuf, Span)>> {
    let content;
    bracketed!(content in input);
    let mut entries = Vec::new();
    loop {
        if content.is_empty() {
            break;
        }
        let span = content.span();
        let lit: LitStr = content.parse()?;
        entries.push((PathBuf::from(lit.value()), span));
        if !content.is_empty() {
            content.parse::<Token![,]>()?;
        }
    }
    Ok(entries)
}
/// Builds the `static_router` function body for `embed_assets!`.
///
/// Walks every file under `assets_dir` (recursively, via glob), skips ignored
/// paths, embeds each file's bytes (plus optional compressed variants), and
/// emits one `::static_serve::static_route` registration per file. Errors are
/// returned so the caller can turn them into a compile error.
fn generate_static_routes(
    assets_dir: &LitStr,
    ignore_paths: &IgnorePaths,
    should_compress: &LitBool,
    should_strip_html_ext: &LitBool,
    cache_busted_paths: &CacheBustedPaths,
    allow_unknown_extensions: bool,
) -> Result<TokenStream, error::Error> {
    // Canonicalize the root and all configured paths up front so the
    // `starts_with` prefix checks below compare absolute, resolved paths.
    let assets_dir_abs = Path::new(&assets_dir.value())
        .canonicalize()
        .map_err(Error::CannotCanonicalizeDirectory)?;
    let assets_dir_abs_str = assets_dir_abs
        .to_str()
        .ok_or(Error::InvalidUnicodeInDirectoryName)?;
    let canon_ignore_paths = ignore_paths
        .0
        .iter()
        .map(|d| {
            d.canonicalize()
                .map_err(Error::CannotCanonicalizeIgnorePath)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let canon_cache_busted_dirs = cache_busted_paths
        .dirs
        .iter()
        .map(|d| {
            d.canonicalize()
                .map_err(Error::CannotCanonicalizeCacheBustedDir)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let canon_cache_busted_files = cache_busted_paths
        .files
        .iter()
        .map(|file| file.canonicalize().map_err(Error::CannotCanonicalizeFile))
        .collect::<Result<Vec<_>, _>>()?;
    let mut routes = Vec::new();
    // `**/*` recurses into subdirectories; directory entries are skipped below.
    for entry in glob(&format!("{assets_dir_abs_str}/**/*")).map_err(Error::Pattern)? {
        let entry = entry.map_err(Error::Glob)?;
        let metadata = entry.metadata().map_err(Error::CannotGetMetadata)?;
        if metadata.is_dir() {
            continue;
        }
        if canon_ignore_paths
            .iter()
            .any(|ignore_path| entry.starts_with(ignore_path))
        {
            continue;
        }
        // A file is cache-busted if it sits under a cache-busted directory
        // or is listed individually.
        let mut is_entry_cache_busted = false;
        if canon_cache_busted_dirs
            .iter()
            .any(|dir| entry.starts_with(dir))
            || canon_cache_busted_files.contains(&entry)
        {
            is_entry_cache_busted = true;
        }
        let entry = entry
            .canonicalize()
            .map_err(Error::CannotCanonicalizeFile)?;
        let entry_str = entry.to_str().ok_or(Error::FilePathIsNotUtf8)?;
        let EmbeddedFileInfo {
            entry_path,
            content_type,
            etag_str,
            lit_byte_str_contents,
            maybe_gzip,
            maybe_zstd,
            cache_busted,
        } = EmbeddedFileInfo::from_path(
            &entry,
            Some(assets_dir_abs_str),
            should_compress,
            should_strip_html_ext,
            is_entry_cache_busted,
            allow_unknown_extensions,
        )?;
        routes.push(quote! {
            router = ::static_serve::static_route(
                router,
                #entry_path,
                #content_type,
                #etag_str,
                {
                    // NOTE(review): the discarded `include_bytes!` appears to be
                    // here so the compiler tracks the asset file as a build
                    // dependency; the literal below carries the actual bytes.
                    const _: &[u8] = include_bytes!(#entry_str);
                    #lit_byte_str_contents
                },
                #maybe_gzip,
                #maybe_zstd,
                #cache_busted
            );
        });
    }
    // The generated function is generic over the axum router state `S`.
    Ok(quote! {
        pub fn static_router<S>() -> ::axum::Router<S>
        where S: ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static {
            let mut router = ::axum::Router::<S>::new();
            #(#routes)*
            router
        }
    })
}
/// Builds the `static_method_router` expression for `embed_asset!` — a single
/// embedded file with no routing path.
fn generate_static_handler(
    asset_file: &LitStr,
    should_compress: &LitBool,
    cache_busted: &LitBool,
    allow_unknown_extensions: &LitBool,
) -> Result<TokenStream, error::Error> {
    let asset_file_abs = Path::new(&asset_file.value())
        .canonicalize()
        .map_err(Error::CannotCanonicalizeFile)?;
    let asset_file_abs_str = asset_file_abs.to_str().ok_or(Error::FilePathIsNotUtf8)?;
    let EmbeddedFileInfo {
        // No assets dir is supplied below, so no web path is computed.
        entry_path: _,
        content_type,
        etag_str,
        lit_byte_str_contents,
        maybe_gzip,
        maybe_zstd,
        // NOTE: this binding shadows the `cache_busted` parameter from here on.
        cache_busted,
    } = EmbeddedFileInfo::from_path(
        &asset_file_abs,
        None,
        should_compress,
        // HTML-extension stripping only applies to routed paths; pass false.
        &LitBool {
            value: false,
            span: Span::call_site(),
        },
        cache_busted.value(),
        allow_unknown_extensions.value(),
    )?;
    let route = quote! {
        ::static_serve::static_method_router(
            #content_type,
            #etag_str,
            {
                // NOTE(review): the discarded `include_bytes!` appears to be
                // here so the compiler tracks the asset file as a build
                // dependency; the literal below carries the actual bytes.
                const _: &[u8] = include_bytes!(#asset_file_abs_str);
                #lit_byte_str_contents
            },
            #maybe_gzip,
            #maybe_zstd,
            #cache_busted
        )
    };
    Ok(route)
}
/// Optional compressed payload, rendered as a fully-qualified
/// `::std::option::Option::Some(b"...")` / `::std::option::Option::None`.
struct OptionBytesSlice(Option<LitByteStr>);
impl ToTokens for OptionBytesSlice {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        // Match on `&self.0` directly; the original `&self.0.as_ref()`
        // produced a redundant `&Option<&LitByteStr>` (clippy: needless_borrow).
        tokens.extend(match &self.0 {
            Some(inner) => quote! { ::std::option::Option::Some(#inner) },
            None => quote! { ::std::option::Option::None },
        });
    }
}
/// Everything needed to embed one file into the generated code.
struct EmbeddedFileInfo {
    // Web path to route (relative to the assets dir, leading '/');
    // `None` for the single-asset macro.
    entry_path: Option<String>,
    // MIME type derived from the file extension.
    content_type: String,
    // Quoted hex ETag derived from a SHA-256 of the contents (see `etag`).
    etag_str: String,
    // The raw file bytes as a byte-string literal.
    lit_byte_str_contents: LitByteStr,
    // Gzip-compressed bytes; present only when compression shrank the file.
    maybe_gzip: OptionBytesSlice,
    // Zstd-compressed bytes; present only when compression shrank the file.
    maybe_zstd: OptionBytesSlice,
    // Whether the route is served with cache-busting semantics.
    cache_busted: bool,
}
impl EmbeddedFileInfo {
fn from_path(
pathbuf: &PathBuf,
assets_dir_abs_str: Option<&str>,
should_compress: &LitBool,
should_strip_html_ext: &LitBool,
cache_busted: bool,
allow_unknown_extensions: bool,
) -> Result<Self, Error> {
let contents = fs::read(pathbuf).map_err(Error::CannotReadEntryContents)?;
let (maybe_gzip, maybe_zstd) = if should_compress.value {
let gzip = gzip_compress(&contents)?;
let zstd = zstd_compress(&contents)?;
(gzip, zstd)
} else {
(None, None)
};
let content_type = file_content_type(pathbuf, allow_unknown_extensions)?;
let entry_path = if let Some(dir) = assets_dir_abs_str {
let relative_entry = pathbuf
.strip_prefix(dir)
.ok()
.and_then(|p| p.to_str())
.ok_or(Error::InvalidUnicodeInEntryName)?;
let mut web_path = normalize_web_path(relative_entry);
if should_strip_html_ext.value && content_type == "text/html" {
strip_html_ext(&mut web_path);
}
Some(web_path)
} else {
None
};
let etag_str = etag(&contents);
let lit_byte_str_contents = LitByteStr::new(&contents, Span::call_site());
let maybe_gzip = OptionBytesSlice(maybe_gzip);
let maybe_zstd = OptionBytesSlice(maybe_zstd);
Ok(Self {
entry_path,
content_type,
etag_str,
lit_byte_str_contents,
maybe_gzip,
maybe_zstd,
cache_busted,
})
}
}
fn gzip_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
let mut compressor = GzEncoder::new(Vec::new(), flate2::Compression::best());
compressor
.write_all(contents)
.map_err(|e| Error::Gzip(GzipType::CompressorWrite(e)))?;
let compressed = compressor
.finish()
.map_err(|e| Error::Gzip(GzipType::EncoderFinish(e)))?;
Ok(maybe_get_compressed(&compressed, contents))
}
/// Zstd-compresses `contents` at the maximum supported level, returning the
/// result only when it is meaningfully smaller (see `maybe_get_compressed`).
fn zstd_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
    let level = *zstd::compression_level_range().end();
    // `Encoder::new` can only reject an invalid level, and `level` comes
    // straight from the library's own supported range — failure is a bug,
    // so state the invariant instead of a bare `unwrap()`.
    let mut encoder = zstd::Encoder::new(Vec::new(), level)
        .expect("max level from zstd's own compression_level_range should be valid");
    write_to_zstd_encoder(&mut encoder, contents)
        .map_err(|e| Error::Zstd(ZstdType::EncoderWrite(e)))?;
    let compressed = encoder
        .finish()
        .map_err(|e| Error::Zstd(ZstdType::EncoderFinish(e)))?;
    Ok(maybe_get_compressed(&compressed, contents))
}
/// Configures the zstd encoder and writes all of `contents` into it.
///
/// Checksum and content-size fields are explicitly omitted and long-distance
/// matching is disabled, keeping the emitted frame small and simple.
fn write_to_zstd_encoder(
    encoder: &mut zstd::Encoder<'static, Vec<u8>>,
    contents: &[u8],
) -> io::Result<()> {
    // Pledging the exact input size up front lets the encoder tune itself
    // for a one-shot compression of known length.
    encoder.set_pledged_src_size(Some(
        contents
            .len()
            .try_into()
            .expect("contents size should fit into u64"),
    ))?;
    // 2^23 = 8 MiB window. NOTE(review): presumably capped so default
    // decoders can decompress without special settings — confirm.
    encoder.window_log(23)?;
    encoder.include_checksum(false)?;
    encoder.include_contentsize(false)?;
    encoder.long_distance_matching(false)?;
    encoder.write_all(contents)?;
    Ok(())
}
/// Returns true when the compressed form is less than 90% of the original
/// size — i.e. compression saves enough to be worth embedding both variants.
fn is_compression_significant(compressed_len: usize, contents_len: usize) -> bool {
    // Exact integer form of `compressed < 0.9 * contents`. The previous
    // `contents_len / 10 * 9` floored before multiplying, understating the
    // threshold by up to 8 bytes (e.g. 19 bytes gave a threshold of 9, not 17).
    // u128 arithmetic rules out any overflow of the multiplications.
    (compressed_len as u128) * 10 < (contents_len as u128) * 9
}
/// Wraps `compressed` in a byte-string literal when compression meaningfully
/// shrank the input; otherwise returns `None` so only raw bytes are embedded.
fn maybe_get_compressed(compressed: &[u8], contents: &[u8]) -> Option<LitByteStr> {
    if is_compression_significant(compressed.len(), contents.len()) {
        Some(LitByteStr::new(compressed, Span::call_site()))
    } else {
        None
    }
}
fn file_content_type(path: &Path, allow_unknown_extensions: bool) -> Result<String, error::Error> {
let Some(ext) = path.extension() else {
return if allow_unknown_extensions {
Ok(mime_guess::mime::APPLICATION_OCTET_STREAM.to_string())
} else {
Err(error::Error::UnknownFileExtension(None))
};
};
let ext = ext
.to_str()
.ok_or(error::Error::InvalidFileExtension(path.into()))?;
let guess = mime_guess::MimeGuess::from_ext(ext);
if allow_unknown_extensions {
return Ok(guess.first_or_octet_stream().to_string());
}
guess
.first_raw()
.map(ToOwned::to_owned)
.ok_or(error::Error::UnknownFileExtension(Some(ext.into())))
}
/// Derives a short ETag by XOR-folding the SHA-256 of `contents` down to
/// 64 bits, formatted as a quoted 16-digit hex string.
fn etag(contents: &[u8]) -> String {
    let digest = Sha256::digest(contents);
    let mut folded: u64 = 0;
    // SHA-256 output is exactly 32 bytes, so this visits four 8-byte words.
    for word in digest.chunks_exact(8) {
        folded ^= u64::from_le_bytes(word.try_into().expect("chunk is 8 bytes"));
    }
    format!("\"{folded:016x}\"")
}
/// Converts a filesystem-relative path into a web path: keeps only normal
/// (UTF-8) components, joins them with '/', and prefixes a leading '/'.
fn normalize_web_path(relative_path: &str) -> String {
    let mut web_path = String::from("/");
    let mut first = true;
    for component in Path::new(relative_path).components() {
        // Non-`Normal` components (`.`, `..`, roots) and non-UTF-8 segments
        // are dropped, matching a filter over `Component::Normal`.
        if let std::path::Component::Normal(segment) = component {
            if let Some(text) = segment.to_str() {
                if !first {
                    web_path.push('/');
                }
                web_path.push_str(text);
                first = false;
            }
        }
    }
    web_path
}
/// Rewrites an HTML web path in place: drops a trailing `.html`/`.htm`
/// extension (case-insensitively) and then collapses a trailing `index`
/// segment, so `/docs/index.html` becomes `/docs/` and `/index.html` becomes `/`.
fn strip_html_ext(path: &mut String) {
    let ext = path.rsplit_once('.').map(|(_, ext)| ext);
    if ext.is_some_and(|ext| ext.eq_ignore_ascii_case("html")) {
        path.truncate(path.len() - ".html".len());
    } else if ext.is_some_and(|ext| ext.eq_ignore_ascii_case("htm")) {
        path.truncate(path.len() - ".htm".len());
    }
    // "/index" itself ends with "/index", so truncating the segment here also
    // reduces it to "/"; the former `else if path == "/index"` arm was
    // unreachable dead code and has been removed.
    if path.ends_with("/index") {
        path.truncate(path.len() - "index".len());
    }
}