Skip to main content

include_zstd_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::time::UNIX_EPOCH;
7use syn::parse::{Parse, ParseStream};
8use syn::{LitByteStr, LitStr, Token, parse_macro_input};
9
10struct FileMacroInput {
11    source_file: Option<LitStr>,
12    target_path: LitStr,
13}
14
15impl Parse for FileMacroInput {
16    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
17        let first: LitStr = input.parse()?;
18        if input.is_empty() {
19            return Ok(Self {
20                source_file: None,
21                target_path: first,
22            });
23        }
24
25        let _comma: Token![,] = input.parse()?;
26        let second: LitStr = input.parse()?;
27        if !input.is_empty() {
28            return Err(input.error("expected one string literal path or 'source_file, path'"));
29        }
30
31        Ok(Self {
32            source_file: Some(first),
33            target_path: second,
34        })
35    }
36}
37
38#[proc_macro]
39pub fn r#str(input: TokenStream) -> TokenStream {
40    let value = parse_macro_input!(input as LitStr);
41    let data = value.value().into_bytes();
42    expand_from_data(data, true)
43}
44
45#[proc_macro]
46pub fn bytes(input: TokenStream) -> TokenStream {
47    let value = parse_macro_input!(input as LitByteStr);
48    let data = value.value();
49    expand_from_data(data, false)
50}
51
52#[proc_macro]
53pub fn file_str(input: TokenStream) -> TokenStream {
54    let input = parse_macro_input!(input as FileMacroInput);
55    let source_file = input.source_file.as_ref().map(LitStr::value);
56    let source_path = input.target_path.value();
57
58    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
59        Ok(path) => path,
60        Err(err) => {
61            return syn::Error::new(input.target_path.span(), err)
62                .to_compile_error()
63                .into();
64        }
65    };
66
67    let data = match fs::read(&absolute_path) {
68        Ok(data) => data,
69        Err(err) => {
70            return syn::Error::new(
71                input.target_path.span(),
72                format!("failed to read '{}': {err}", absolute_path.display()),
73            )
74            .to_compile_error()
75            .into();
76        }
77    };
78
79    expand_from_data(data, true)
80}
81
82#[proc_macro]
83pub fn file_bytes(input: TokenStream) -> TokenStream {
84    let input = parse_macro_input!(input as FileMacroInput);
85    let source_file = input.source_file.as_ref().map(LitStr::value);
86    let source_path = input.target_path.value();
87
88    let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
89        Ok(path) => path,
90        Err(err) => {
91            return syn::Error::new(input.target_path.span(), err)
92                .to_compile_error()
93                .into();
94        }
95    };
96
97    let data = match fs::read(&absolute_path) {
98        Ok(data) => data,
99        Err(err) => {
100            return syn::Error::new(
101                input.target_path.span(),
102                format!("failed to read '{}': {err}", absolute_path.display()),
103            )
104            .to_compile_error()
105            .into();
106        }
107    };
108
109    expand_from_data(data, false)
110}
111
112#[proc_macro]
113pub fn include_zstd(input: TokenStream) -> TokenStream {
114    let path = parse_macro_input!(input as LitStr);
115    let source_path = path.value();
116
117    // 对于 include_zstd! 宏,直接使用 invocation_source_file_abs 获取源文件路径
118    // 确保在 examples/ 目录中也能正确解析相对路径
119    let source_file_abs = invocation_source_file_abs();
120    let source_dir = source_file_abs.parent().unwrap_or(&source_file_abs);
121
122    let absolute_path = if Path::new(&source_path).is_absolute() {
123        PathBuf::from(&source_path)
124    } else {
125        source_dir.join(&source_path)
126    };
127
128    // 尝试读取文件元数据,如果失败则尝试在其他常见位置查找
129    let (metadata, absolute_path) = match fs::metadata(&absolute_path) {
130        Ok(m) => (m, absolute_path),
131        Err(_) => {
132            // Fallback: try to find the file in common locations
133            match find_file_in_candidates(&source_path, source_dir) {
134                Some(found_path) => match fs::metadata(&found_path) {
135                    Ok(m) => (m, found_path),
136                    Err(err) => {
137                        return syn::Error::new(
138                            path.span(),
139                            format!("failed to read metadata '{}': {err}", found_path.display()),
140                        )
141                        .to_compile_error()
142                        .into();
143                    }
144                },
145                None => {
146                    return syn::Error::new(
147                        path.span(),
148                        format!(
149                            "failed to read metadata '{}': file not found",
150                            absolute_path.display()
151                        ),
152                    )
153                    .to_compile_error()
154                    .into();
155                }
156            }
157        }
158    };
159
160    let data = match fs::read(&absolute_path) {
161        Ok(d) => d,
162        Err(err) => {
163            return syn::Error::new(
164                path.span(),
165                format!("failed to read file '{}': {err}", absolute_path.display()),
166            )
167            .to_compile_error()
168            .into();
169        }
170    };
171
172    let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
173        Ok(c) => c,
174        Err(err) => {
175            return syn::Error::new(
176                proc_macro2::Span::call_site(),
177                format!("failed to compress data: {err}"),
178            )
179            .to_compile_error()
180            .into();
181        }
182    };
183
184    let len = metadata.len();
185    let is_file = metadata.is_file();
186    let is_dir = metadata.is_dir();
187
188    let modified = timestamp_to_code(&metadata.modified());
189    let accessed = timestamp_to_code(&metadata.accessed());
190    let created = timestamp_to_code(&metadata.created());
191
192    let include_zstd_crate = match crate_name("include-zstd") {
193        Ok(FoundCrate::Itself) => quote!(::include_zstd),
194        Ok(FoundCrate::Name(name)) => {
195            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
196            quote!(::#ident)
197        }
198        Err(_) => quote!(::include_zstd),
199    };
200
201    let expanded = quote! {
202        {
203            static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
204
205            #include_zstd_crate::__private::create_zstd_asset(
206                #include_zstd_crate::ZstdMetadata {
207                    len: #len,
208                    modified: #modified,
209                    accessed: #accessed,
210                    created: #created,
211                    is_file: #is_file,
212                    is_dir: #is_dir,
213                },
214                __INCLUDE_ZSTD_COMPRESSED,
215            )
216        }
217    };
218
219    expanded.into()
220}
221
222fn timestamp_to_code(
223    time: &Result<std::time::SystemTime, std::io::Error>,
224) -> proc_macro2::TokenStream {
225    match time {
226        Ok(t) => match t.duration_since(UNIX_EPOCH) {
227            Ok(d) => {
228                let secs = d.as_secs();
229                let nanos = d.subsec_nanos();
230                quote!(Some(std::time::UNIX_EPOCH + std::time::Duration::new(#secs, #nanos)))
231            }
232            Err(_) => quote!(None),
233        },
234        Err(_) => quote!(None),
235    }
236}
237
238fn expand_from_data(data: Vec<u8>, decode_utf8: bool) -> TokenStream {
239    let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
240        Ok(compressed) => compressed,
241        Err(err) => {
242            return syn::Error::new(
243                proc_macro2::Span::call_site(),
244                format!("failed to compress data: {err}"),
245            )
246            .to_compile_error()
247            .into();
248        }
249    };
250
251    let include_zstd_crate = match crate_name("include-zstd") {
252        // In a package with both lib+bin, proc-macros expanded inside the bin
253        // should still target the library crate namespace.
254        Ok(FoundCrate::Itself) => quote!(::include_zstd),
255        Ok(FoundCrate::Name(name)) => {
256            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
257            quote!(::#ident)
258        }
259        Err(_) => quote!(::include_zstd),
260    };
261
262    let expanded = if decode_utf8 {
263        quote! {
264            {
265                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
266                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
267
268                #include_zstd_crate::__private::decode_utf8(
269                    __INCLUDE_ZSTD_CACHE
270                        .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
271                        .as_ref(),
272                )
273            }
274        }
275    } else {
276        quote! {
277            {
278                static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
279                static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
280
281                __INCLUDE_ZSTD_CACHE
282                    .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
283                    .as_ref()
284            }
285        }
286    };
287
288    expanded.into()
289}
290
291fn resolve_path(source_file: Option<&str>, source_path: &str) -> Result<PathBuf, String> {
292    let target_path = Path::new(source_path);
293    if target_path.is_absolute() {
294        return Ok(target_path.to_path_buf());
295    }
296
297    // Match `include_str!` semantics: always resolve relative paths against the
298    // parent directory of the invocation's source file, using an absolute path
299    // so the result is independent of the compiler's current working directory.
300    let source_file_abs = if let Some(source_file) = source_file {
301        absolutize_source_file(Path::new(source_file))
302    } else {
303        invocation_source_file_abs()
304    };
305
306    let source_dir = source_file_abs.parent().ok_or_else(|| {
307        format!(
308            "failed to resolve include path '{}': invocation source file '{}' has no parent directory",
309            source_path,
310            source_file_abs.display()
311        )
312    })?;
313
314    let absolute_path = source_dir.join(target_path);
315
316    // If the resolved path doesn't exist, try to find it in candidate locations
317    // (handles LSP analysis where path resolution may be inaccurate)
318    if !absolute_path.exists() {
319        if let Some(found_path) = find_file_in_candidates(source_path, source_dir) {
320            return Ok(found_path);
321        }
322    }
323
324    Ok(absolute_path)
325}
326
327/// Return the absolute path of the source file that contains the macro
328/// invocation, mirroring how `include_str!` locates its base directory.
329fn invocation_source_file_abs() -> PathBuf {
330    let call_site = proc_macro::Span::call_site();
331
332    // `local_file()` returns the canonical absolute on-disk path when the span
333    // originates from a real source file; this is the same information rustc
334    // uses internally to resolve `include_str!`.
335    if let Some(path) = call_site.local_file() {
336        // If local_file() returns a file path (not a directory), return it
337        if path.extension().is_some() || path.is_file() {
338            return path;
339        }
340        // If it returns a directory, it's likely from LSP analysis
341        // Fall through to try other methods
342    }
343
344    // Fallback: `Span::file()` typically yields a path relative to the crate
345    // root (e.g. "src/lib.rs" or "examples/example.rs").
346    let file = call_site.file();
347    let file_path = Path::new(&file);
348
349    if file_path.is_absolute() {
350        return file_path.to_path_buf();
351    }
352
353    // Use CARGO_MANIFEST_DIR (crate root) to anchor relative paths.
354    // In workspace projects, this points to the specific crate's directory.
355    if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
356        let manifest_path = PathBuf::from(&manifest_dir);
357        let candidate = manifest_path.join(file_path);
358
359        // Verify the candidate path's parent directory exists
360        if candidate.parent().map_or(false, |p| p.exists()) {
361            return candidate;
362        }
363    }
364
365    // Last resort: use current working directory
366    if let Ok(cwd) = std::env::current_dir() {
367        let candidate = cwd.join(file_path);
368        if candidate.parent().map_or(false, |p| p.exists()) {
369            return candidate;
370        }
371    }
372
373    // Final fallback: just return the relative path
374    file_path.to_path_buf()
375}
376
/// Searches fallback locations for a file when normal path resolution fails.
///
/// Only the final component (the file name) of `relative_path` is used; any
/// directory components are ignored. Candidates are tried in this order, and
/// the first that is an existing regular file wins:
/// 1. `<name>` (relative, so resolved against the current directory)
/// 2. `examples/<name>` (relative)
/// 3. `src/<name>` (relative)
/// 4. `<source_dir>/<name>`
/// 5. `$CARGO_MANIFEST_DIR/<name>`
/// 6. `$CARGO_MANIFEST_DIR/examples/<name>`
/// 7. `$CARGO_MANIFEST_DIR/src/<name>`
///
/// Candidates 5–7 are only tried when `CARGO_MANIFEST_DIR` is set.
///
/// Returns `None` in three cases:
/// - no candidate exists;
/// - `relative_path` has no file name (e.g. it ends in `..`);
/// - `relative_path` is absolute, because an explicit absolute path must never
///   be silently replaced by an unrelated file that merely shares its name.
fn find_file_in_candidates(relative_path: &str, source_dir: &Path) -> Option<PathBuf> {
    let requested = Path::new(relative_path);
    if requested.is_absolute() {
        return None;
    }
    let file_name = requested.file_name()?;

    let mut candidates = vec![
        PathBuf::from(file_name),
        PathBuf::from("examples").join(file_name),
        PathBuf::from("src").join(file_name),
        source_dir.join(file_name),
    ];

    // Crate-root-based locations, mainly for IDE/LSP analysis in workspaces.
    if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
        let manifest_path = PathBuf::from(manifest_dir);
        candidates.push(manifest_path.join(file_name));
        candidates.push(manifest_path.join("examples").join(file_name));
        candidates.push(manifest_path.join("src").join(file_name));
    }

    // `is_file()` is false for missing paths, so it also covers existence.
    candidates.into_iter().find(|candidate| candidate.is_file())
}
410
/// Anchors a relative `source_file` argument at `CARGO_MANIFEST_DIR`.
///
/// Absolute paths are returned unchanged. A relative path is also returned
/// unchanged when `CARGO_MANIFEST_DIR` cannot be read (unset or not valid
/// Unicode).
fn absolutize_source_file(source_file: &Path) -> PathBuf {
    if !source_file.is_absolute() {
        if let Ok(crate_root) = std::env::var("CARGO_MANIFEST_DIR") {
            return Path::new(&crate_root).join(source_file);
        }
    }
    source_file.to_path_buf()
}