//! include-zstd-derive 0.1.1
//!
//! Procedural macro implementation for include-zstd compile-time compression macros.
use proc_macro::TokenStream;
use proc_macro_crate::{FoundCrate, crate_name};
use quote::quote;
use std::fs;
use std::path::{Path, PathBuf};
use syn::parse::{Parse, ParseStream};
use syn::{LitByteStr, LitStr, Token, parse_macro_input};

/// Arguments of the `file_str!` / `file_bytes!` macros.
///
/// Accepted forms (a single trailing comma is allowed after either):
/// - `"path"`
/// - `"source_file", "path"`
struct FileMacroInput {
    /// Optional file whose directory is used as the base for a relative
    /// `target_path`; when absent, the invoking source file is used instead.
    source_file: Option<LitStr>,
    /// Path of the file to embed.
    target_path: LitStr,
}

impl Parse for FileMacroInput {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let first: LitStr = input.parse()?;
        if input.is_empty() {
            return Ok(Self {
                source_file: None,
                target_path: first,
            });
        }

        let _comma: Token![,] = input.parse()?;
        // `file_str!("path",)`: a lone path followed by a trailing comma.
        if input.is_empty() {
            return Ok(Self {
                source_file: None,
                target_path: first,
            });
        }

        let second: LitStr = input.parse()?;
        // `file_str!("source_file", "path",)`: tolerate one trailing comma.
        if input.peek(Token![,]) {
            let _trailing: Token![,] = input.parse()?;
        }
        if !input.is_empty() {
            return Err(input.error("expected one string literal path or 'source_file, path'"));
        }

        Ok(Self {
            source_file: Some(first),
            target_path: second,
        })
    }
}

#[proc_macro]
pub fn r#str(input: TokenStream) -> TokenStream {
    let value = parse_macro_input!(input as LitStr);
    let data = value.value().into_bytes();
    expand_from_data(data, true)
}

#[proc_macro]
pub fn bytes(input: TokenStream) -> TokenStream {
    let value = parse_macro_input!(input as LitByteStr);
    let data = value.value();
    expand_from_data(data, false)
}

#[proc_macro]
pub fn file_str(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as FileMacroInput);
    let source_file = input.source_file.as_ref().map(LitStr::value);
    let source_path = input.target_path.value();

    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
        Ok(path) => path,
        Err(err) => {
            return syn::Error::new(input.target_path.span(), err)
                .to_compile_error()
                .into();
        }
    };

    let data = match fs::read(&absolute_path) {
        Ok(data) => data,
        Err(err) => {
            return syn::Error::new(
                input.target_path.span(),
                format!("failed to read '{}': {err}", absolute_path.display()),
            )
            .to_compile_error()
            .into();
        }
    };

    expand_from_data(data, true)
}

#[proc_macro]
pub fn file_bytes(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as FileMacroInput);
    let source_file = input.source_file.as_ref().map(LitStr::value);
    let source_path = input.target_path.value();

    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
        Ok(path) => path,
        Err(err) => {
            return syn::Error::new(input.target_path.span(), err)
                .to_compile_error()
                .into();
        }
    };

    let data = match fs::read(&absolute_path) {
        Ok(data) => data,
        Err(err) => {
            return syn::Error::new(
                input.target_path.span(),
                format!("failed to read '{}': {err}", absolute_path.display()),
            )
            .to_compile_error()
            .into();
        }
    };

    expand_from_data(data, false)
}

fn expand_from_data(data: Vec<u8>, decode_utf8: bool) -> TokenStream {
    let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
        Ok(compressed) => compressed,
        Err(err) => {
            return syn::Error::new(
                proc_macro2::Span::call_site(),
                format!("failed to compress data: {err}"),
            )
            .to_compile_error()
            .into();
        }
    };

    let include_zstd_crate = match crate_name("include-zstd") {
        // In a package with both lib+bin, proc-macros expanded inside the bin
        // should still target the library crate namespace.
        Ok(FoundCrate::Itself) => quote!(::include_zstd),
        Ok(FoundCrate::Name(name)) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote!(::#ident)
        }
        Err(_) => quote!(::include_zstd),
    };

    let expanded = if decode_utf8 {
        quote! {
            {
                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();

                #include_zstd_crate::__private::decode_utf8(
                    __INCLUDE_ZSTD_CACHE
                        .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
                        .as_ref(),
                )
            }
        }
    } else {
        quote! {
            {
                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();

                __INCLUDE_ZSTD_CACHE
                    .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
                    .as_ref()
            }
        }
    };

    expanded.into()
}

fn resolve_path(source_file: Option<&str>, source_path: &str) -> Result<PathBuf, String> {
    let target_path = Path::new(source_path);
    if target_path.is_absolute() {
        return Ok(target_path.to_path_buf());
    }

    // Match `include_str!` semantics: always resolve relative paths against the
    // parent directory of the invocation's source file, using an absolute path
    // so the result is independent of the compiler's current working directory.
    let source_file_abs = if let Some(source_file) = source_file {
        absolutize_source_file(Path::new(source_file))
    } else {
        invocation_source_file_abs()
    };

    let source_dir = source_file_abs.parent().ok_or_else(|| {
        format!(
            "failed to resolve include path '{}': invocation source file '{}' has no parent directory",
            source_path,
            source_file_abs.display()
        )
    })?;

    Ok(source_dir.join(target_path))
}

/// Return the absolute path of the source file that contains the macro
/// invocation, mirroring how `include_str!` locates its base directory.
///
/// `Span::local_file()` yields the on-disk path rustc was given for the file.
/// That path is not guaranteed to be absolute; it may be relative to the
/// compiler's working directory (e.g. `src/lib.rs`). Since this proc macro
/// runs inside rustc and shares that working directory, the path is anchored
/// to the current directory so callers always receive an absolute path.
fn invocation_source_file_abs() -> PathBuf {
    let call_site = proc_macro::Span::call_site();

    if let Some(path) = call_site.local_file() {
        // `std::path::absolute` does not touch the filesystem. If it fails
        // (e.g. empty path, unreadable current dir), keep the path as reported.
        return std::path::absolute(&path).unwrap_or(path);
    }

    // Fallback when the span does not come from a real file: `Span::file()` is
    // a display path that may be remapped or artificial. Anchor it to
    // `CARGO_MANIFEST_DIR` as a best effort.
    // NOTE(review): in a workspace this display path may be relative to the
    // workspace root rather than the crate root — confirm if this path matters.
    absolutize_source_file(Path::new(&call_site.file()))
}

/// Make a source-file path absolute when possible.
///
/// Absolute inputs are returned unchanged. Relative inputs are joined onto
/// `CARGO_MANIFEST_DIR` when that environment variable can be read; otherwise
/// the relative path is returned as-is.
fn absolutize_source_file(source_file: &Path) -> PathBuf {
    // Only consult the environment for paths that actually need anchoring.
    let manifest_dir = if source_file.is_absolute() {
        None
    } else {
        std::env::var("CARGO_MANIFEST_DIR").ok()
    };

    match manifest_dir {
        Some(dir) => Path::new(&dir).join(source_file),
        None => source_file.to_path_buf(),
    }
}