Skip to main content

include_zstd_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_crate::{crate_name, FoundCrate};
3use quote::quote;
4use std::fs;
5use std::path::{Path, PathBuf};
6use syn::parse::{Parse, ParseStream};
7use syn::{parse_macro_input, LitByteStr, LitStr, Token};
8
9struct FileMacroInput {
10    source_file: Option<LitStr>,
11    target_path: LitStr,
12}
13
14impl Parse for FileMacroInput {
15    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16        let first: LitStr = input.parse()?;
17        if input.is_empty() {
18            return Ok(Self {
19                source_file: None,
20                target_path: first,
21            });
22        }
23
24        let _comma: Token![,] = input.parse()?;
25        let second: LitStr = input.parse()?;
26        if !input.is_empty() {
27            return Err(input.error("expected one string literal path or 'source_file, path'"));
28        }
29
30        Ok(Self {
31            source_file: Some(first),
32            target_path: second,
33        })
34    }
35}
36
37#[proc_macro]
38pub fn r#str(input: TokenStream) -> TokenStream {
39    let value = parse_macro_input!(input as LitStr);
40    let data = value.value().into_bytes();
41    expand_from_data(data, true)
42}
43
44#[proc_macro]
45pub fn bytes(input: TokenStream) -> TokenStream {
46    let value = parse_macro_input!(input as LitByteStr);
47    let data = value.value();
48    expand_from_data(data, false)
49}
50
51#[proc_macro]
52pub fn file_str(input: TokenStream) -> TokenStream {
53    let input = parse_macro_input!(input as FileMacroInput);
54    let source_file = input.source_file.as_ref().map(LitStr::value);
55    let source_path = input.target_path.value();
56
57    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
58        Ok(path) => path,
59        Err(err) => {
60            return syn::Error::new(input.target_path.span(), err)
61                .to_compile_error()
62                .into();
63        }
64    };
65
66    let data = match fs::read(&absolute_path) {
67        Ok(data) => data,
68        Err(err) => {
69            return syn::Error::new(
70                input.target_path.span(),
71                format!("failed to read '{}': {err}", absolute_path.display()),
72            )
73            .to_compile_error()
74            .into();
75        }
76    };
77
78    expand_from_data(data, true)
79}
80
81#[proc_macro]
82pub fn file_bytes(input: TokenStream) -> TokenStream {
83    let input = parse_macro_input!(input as FileMacroInput);
84    let source_file = input.source_file.as_ref().map(LitStr::value);
85    let source_path = input.target_path.value();
86
87    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
88        Ok(path) => path,
89        Err(err) => {
90            return syn::Error::new(input.target_path.span(), err)
91                .to_compile_error()
92                .into();
93        }
94    };
95
96    let data = match fs::read(&absolute_path) {
97        Ok(data) => data,
98        Err(err) => {
99            return syn::Error::new(
100                input.target_path.span(),
101                format!("failed to read '{}': {err}", absolute_path.display()),
102            )
103            .to_compile_error()
104            .into();
105        }
106    };
107
108    expand_from_data(data, false)
109}
110
111fn expand_from_data(data: Vec<u8>, decode_utf8: bool) -> TokenStream {
112    let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
113        Ok(compressed) => compressed,
114        Err(err) => {
115            return syn::Error::new(
116                proc_macro2::Span::call_site(),
117                format!("failed to compress data: {err}"),
118            )
119            .to_compile_error()
120            .into();
121        }
122    };
123
124    let include_zstd_crate = match crate_name("include-zstd") {
125        // In a package with both lib+bin, proc-macros expanded inside the bin
126        // should still target the library crate namespace.
127        Ok(FoundCrate::Itself) => quote!(::include_zstd),
128        Ok(FoundCrate::Name(name)) => {
129            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
130            quote!(::#ident)
131        }
132        Err(_) => quote!(::include_zstd),
133    };
134
135    let expanded = if decode_utf8 {
136        quote! {
137            {
138                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
139                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
140
141                #include_zstd_crate::__private::decode_utf8(
142                    __INCLUDE_ZSTD_CACHE
143                        .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
144                        .as_ref(),
145                )
146            }
147        }
148    } else {
149        quote! {
150            {
151                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
152                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
153
154                __INCLUDE_ZSTD_CACHE
155                    .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
156                    .as_ref()
157            }
158        }
159    };
160
161    expanded.into()
162}
163
164fn resolve_path(source_file: Option<&str>, source_path: &str) -> Result<PathBuf, String> {
165    let target_path = Path::new(source_path);
166    if target_path.is_absolute() {
167        return Ok(target_path.to_path_buf());
168    }
169
170    let source_dir = if let Some(source_file) = source_file {
171        source_dir_from_source_file(source_file, source_path)?
172    } else {
173        source_dir_from_invocation(source_path)?
174    };
175
176    Ok(source_dir.join(target_path))
177}
178
/// Computes the base directory for an include that names its source file explicitly.
///
/// The base is the parent directory of `source_file`. If `source_file` is
/// absolute, that parent is used as-is. Otherwise the parent is interpreted
/// relative to `CARGO_MANIFEST_DIR`.
///
/// Errors if `source_file` has no parent (e.g. `""` or `/`) or if
/// `CARGO_MANIFEST_DIR` is needed but unreadable. `source_path` is used only
/// in error messages.
///
/// NOTE(review): if callers pass `file!()`, that path may be relative to the
/// workspace root rather than the package manifest dir — confirm how
/// `source_file` is produced.
fn source_dir_from_source_file(source_file: &str, source_path: &str) -> Result<PathBuf, String> {
    let file = Path::new(source_file);
    let parent = match file.parent() {
        Some(dir) => dir,
        None => {
            return Err(format!(
                "failed to resolve include path '{}': invocation source file '{}' has no parent directory",
                source_path,
                file.display()
            ));
        }
    };

    if !file.is_absolute() {
        let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
            Ok(dir) => PathBuf::from(dir),
            Err(err) => return Err(format!("failed to read CARGO_MANIFEST_DIR: {err}")),
        };
        return Ok(manifest_dir.join(parent));
    }

    Ok(parent.to_path_buf())
}
198
199fn source_dir_from_invocation(source_path: &str) -> Result<PathBuf, String> {
200    let source_file_path = invocation_source_file();
201    if let Some(source_dir) = source_file_path.parent() {
202        return Ok(source_dir.to_path_buf());
203    }
204
205    if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
206        return Ok(PathBuf::from(manifest_dir));
207    }
208
209    std::env::current_dir().map_err(|err| {
210        format!(
211            "failed to resolve include path '{}': no invocation source path and no usable base directory: {err}",
212            source_path
213        )
214    })
215}
216
217fn invocation_source_file() -> PathBuf {
218    // local_file provides the canonical on-disk path when available.
219    proc_macro::Span::call_site()
220        .local_file()
221        .unwrap_or_else(|| PathBuf::from(proc_macro::Span::call_site().file()))
222}