1use proc_macro::TokenStream;
2use proc_macro_crate::{FoundCrate, crate_name};
3use quote::quote;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::time::UNIX_EPOCH;
7use syn::parse::{Parse, ParseStream};
8use syn::{LitByteStr, LitStr, Token, parse_macro_input};
9
10struct FileMacroInput {
11 source_file: Option<LitStr>,
12 target_path: LitStr,
13}
14
15impl Parse for FileMacroInput {
16 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
17 let first: LitStr = input.parse()?;
18 if input.is_empty() {
19 return Ok(Self {
20 source_file: None,
21 target_path: first,
22 });
23 }
24
25 let _comma: Token![,] = input.parse()?;
26 let second: LitStr = input.parse()?;
27 if !input.is_empty() {
28 return Err(input.error("expected one string literal path or 'source_file, path'"));
29 }
30
31 Ok(Self {
32 source_file: Some(first),
33 target_path: second,
34 })
35 }
36}
37
38#[proc_macro]
39pub fn r#str(input: TokenStream) -> TokenStream {
40 let value = parse_macro_input!(input as LitStr);
41 let data = value.value().into_bytes();
42 expand_from_data(data, true)
43}
44
45#[proc_macro]
46pub fn bytes(input: TokenStream) -> TokenStream {
47 let value = parse_macro_input!(input as LitByteStr);
48 let data = value.value();
49 expand_from_data(data, false)
50}
51
52#[proc_macro]
53pub fn file_str(input: TokenStream) -> TokenStream {
54 let input = parse_macro_input!(input as FileMacroInput);
55 let source_file = input.source_file.as_ref().map(LitStr::value);
56 let source_path = input.target_path.value();
57
58 let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
59 Ok(path) => path,
60 Err(err) => {
61 return syn::Error::new(input.target_path.span(), err)
62 .to_compile_error()
63 .into();
64 }
65 };
66
67 let data = match fs::read(&absolute_path) {
68 Ok(data) => data,
69 Err(err) => {
70 return syn::Error::new(
71 input.target_path.span(),
72 format!("failed to read '{}': {err}", absolute_path.display()),
73 )
74 .to_compile_error()
75 .into();
76 }
77 };
78
79 expand_from_data(data, true)
80}
81
82#[proc_macro]
83pub fn file_bytes(input: TokenStream) -> TokenStream {
84 let input = parse_macro_input!(input as FileMacroInput);
85 let source_file = input.source_file.as_ref().map(LitStr::value);
86 let source_path = input.target_path.value();
87
88 let absolute_path = match resolve_path(source_file.as_deref(), &source_path) {
89 Ok(path) => path,
90 Err(err) => {
91 return syn::Error::new(input.target_path.span(), err)
92 .to_compile_error()
93 .into();
94 }
95 };
96
97 let data = match fs::read(&absolute_path) {
98 Ok(data) => data,
99 Err(err) => {
100 return syn::Error::new(
101 input.target_path.span(),
102 format!("failed to read '{}': {err}", absolute_path.display()),
103 )
104 .to_compile_error()
105 .into();
106 }
107 };
108
109 expand_from_data(data, false)
110}
111
112#[proc_macro]
113pub fn include_zstd(input: TokenStream) -> TokenStream {
114 let path = parse_macro_input!(input as LitStr);
115 let source_path = path.value();
116
117 let source_file_abs = invocation_source_file_abs();
120 let source_dir = source_file_abs.parent().unwrap_or(&source_file_abs);
121
122 let absolute_path = if Path::new(&source_path).is_absolute() {
123 PathBuf::from(&source_path)
124 } else {
125 source_dir.join(&source_path)
126 };
127
128 let (metadata, absolute_path) = match fs::metadata(&absolute_path) {
130 Ok(m) => (m, absolute_path),
131 Err(_) => {
132 match find_file_in_candidates(&source_path, source_dir) {
134 Some(found_path) => match fs::metadata(&found_path) {
135 Ok(m) => (m, found_path),
136 Err(err) => {
137 return syn::Error::new(
138 path.span(),
139 format!("failed to read metadata '{}': {err}", found_path.display()),
140 )
141 .to_compile_error()
142 .into();
143 }
144 },
145 None => {
146 return syn::Error::new(
147 path.span(),
148 format!(
149 "failed to read metadata '{}': file not found",
150 absolute_path.display()
151 ),
152 )
153 .to_compile_error()
154 .into();
155 }
156 }
157 }
158 };
159
160 let data = match fs::read(&absolute_path) {
161 Ok(d) => d,
162 Err(err) => {
163 return syn::Error::new(
164 path.span(),
165 format!("failed to read file '{}': {err}", absolute_path.display()),
166 )
167 .to_compile_error()
168 .into();
169 }
170 };
171
172 let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
173 Ok(c) => c,
174 Err(err) => {
175 return syn::Error::new(
176 proc_macro2::Span::call_site(),
177 format!("failed to compress data: {err}"),
178 )
179 .to_compile_error()
180 .into();
181 }
182 };
183
184 let len = metadata.len();
185 let is_file = metadata.is_file();
186 let is_dir = metadata.is_dir();
187
188 let modified = timestamp_to_code(&metadata.modified());
189 let accessed = timestamp_to_code(&metadata.accessed());
190 let created = timestamp_to_code(&metadata.created());
191
192 let include_zstd_crate = match crate_name("include-zstd") {
193 Ok(FoundCrate::Itself) => quote!(::include_zstd),
194 Ok(FoundCrate::Name(name)) => {
195 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
196 quote!(::#ident)
197 }
198 Err(_) => quote!(::include_zstd),
199 };
200
201 let expanded = quote! {
202 {
203 static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
204
205 #include_zstd_crate::__private::create_zstd_asset(
206 #include_zstd_crate::ZstdMetadata {
207 len: #len,
208 modified: #modified,
209 accessed: #accessed,
210 created: #created,
211 is_file: #is_file,
212 is_dir: #is_dir,
213 },
214 __INCLUDE_ZSTD_COMPRESSED,
215 )
216 }
217 };
218
219 expanded.into()
220}
221
222fn timestamp_to_code(
223 time: &Result<std::time::SystemTime, std::io::Error>,
224) -> proc_macro2::TokenStream {
225 match time {
226 Ok(t) => match t.duration_since(UNIX_EPOCH) {
227 Ok(d) => {
228 let secs = d.as_secs();
229 let nanos = d.subsec_nanos();
230 quote!(Some(std::time::UNIX_EPOCH + std::time::Duration::new(#secs, #nanos)))
231 }
232 Err(_) => quote!(None),
233 },
234 Err(_) => quote!(None),
235 }
236}
237
238fn expand_from_data(data: Vec<u8>, decode_utf8: bool) -> TokenStream {
239 let compressed = match zstd::stream::encode_all(data.as_slice(), 0) {
240 Ok(compressed) => compressed,
241 Err(err) => {
242 return syn::Error::new(
243 proc_macro2::Span::call_site(),
244 format!("failed to compress data: {err}"),
245 )
246 .to_compile_error()
247 .into();
248 }
249 };
250
251 let include_zstd_crate = match crate_name("include-zstd") {
252 Ok(FoundCrate::Itself) => quote!(::include_zstd),
255 Ok(FoundCrate::Name(name)) => {
256 let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
257 quote!(::#ident)
258 }
259 Err(_) => quote!(::include_zstd),
260 };
261
262 let expanded = if decode_utf8 {
263 quote! {
264 {
265 static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
266 static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
267
268 #include_zstd_crate::__private::decode_utf8(
269 __INCLUDE_ZSTD_CACHE
270 .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
271 .as_ref(),
272 )
273 }
274 }
275 } else {
276 quote! {
277 {
278 static __INCLUDE_ZSTD_COMPRESSED: &[u8] = &[#(#compressed),*];
279 static __INCLUDE_ZSTD_CACHE: ::std::sync::OnceLock<::std::boxed::Box<[u8]>> = ::std::sync::OnceLock::new();
280
281 __INCLUDE_ZSTD_CACHE
282 .get_or_init(|| #include_zstd_crate::__private::decompress_bytes(__INCLUDE_ZSTD_COMPRESSED))
283 .as_ref()
284 }
285 }
286 };
287
288 expanded.into()
289}
290
291fn resolve_path(source_file: Option<&str>, source_path: &str) -> Result<PathBuf, String> {
292 let target_path = Path::new(source_path);
293 if target_path.is_absolute() {
294 return Ok(target_path.to_path_buf());
295 }
296
297 let source_file_abs = if let Some(source_file) = source_file {
301 absolutize_source_file(Path::new(source_file))
302 } else {
303 invocation_source_file_abs()
304 };
305
306 let source_dir = source_file_abs.parent().ok_or_else(|| {
307 format!(
308 "failed to resolve include path '{}': invocation source file '{}' has no parent directory",
309 source_path,
310 source_file_abs.display()
311 )
312 })?;
313
314 let absolute_path = source_dir.join(target_path);
315
316 if !absolute_path.exists() {
319 if let Some(found_path) = find_file_in_candidates(source_path, source_dir) {
320 return Ok(found_path);
321 }
322 }
323
324 Ok(absolute_path)
325}
326
327fn invocation_source_file_abs() -> PathBuf {
330 let call_site = proc_macro::Span::call_site();
331
332 if let Some(path) = call_site.local_file() {
336 if path.extension().is_some() || path.is_file() {
338 return path;
339 }
340 }
343
344 let file = call_site.file();
347 let file_path = Path::new(&file);
348
349 if file_path.is_absolute() {
350 return file_path.to_path_buf();
351 }
352
353 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
356 let manifest_path = PathBuf::from(&manifest_dir);
357 let candidate = manifest_path.join(file_path);
358
359 if candidate.parent().map_or(false, |p| p.exists()) {
361 return candidate;
362 }
363 }
364
365 if let Ok(cwd) = std::env::current_dir() {
367 let candidate = cwd.join(file_path);
368 if candidate.parent().map_or(false, |p| p.exists()) {
369 return candidate;
370 }
371 }
372
373 file_path.to_path_buf()
375}
376
/// Fallback lookup used when a relative include path does not exist as given.
///
/// Only the final component (file name) of `relative_path` is kept. It is
/// searched for, in order, in:
/// - the current directory, `examples/` and `src/` (cwd-relative);
/// - `source_dir`;
/// - `$CARGO_MANIFEST_DIR`, plus its `examples/` and `src/`, when that
///   variable is set.
///
/// The first candidate that is a regular file (symlinks followed) is returned.
/// Returns `None` when `relative_path` has no file name or nothing matches.
fn find_file_in_candidates(relative_path: &str, source_dir: &Path) -> Option<PathBuf> {
    let name = Path::new(relative_path).file_name()?;
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").ok().map(PathBuf::from);

    let search_dirs = [
        Some(PathBuf::new()),
        Some(PathBuf::from("examples")),
        Some(PathBuf::from("src")),
        Some(source_dir.to_path_buf()),
        manifest_dir.clone(),
        manifest_dir.as_ref().map(|dir| dir.join("examples")),
        manifest_dir.map(|dir| dir.join("src")),
    ];

    search_dirs
        .into_iter()
        .flatten()
        .map(|dir| dir.join(name))
        .find(|candidate| candidate.is_file())
}
410
/// Makes a user-supplied source-file path absolute where possible.
///
/// Absolute paths are returned unchanged. Relative paths are joined onto
/// `CARGO_MANIFEST_DIR` when that variable is set; otherwise they are
/// returned as-is.
fn absolutize_source_file(source_file: &Path) -> PathBuf {
    if source_file.is_absolute() {
        source_file.to_path_buf()
    } else {
        std::env::var("CARGO_MANIFEST_DIR")
            .map(|dir| PathBuf::from(dir).join(source_file))
            .unwrap_or_else(|_| source_file.to_path_buf())
    }
}