prest_embed_macro/
lib.rs

1#![recursion_limit = "1024"]
2#![forbid(unsafe_code)]
3#[macro_use]
4extern crate quote;
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use std::{collections::BTreeMap, env, path::Path};
10use syn::{parse_macro_input, Data, DeriveInput, Expr, ExprLit, Fields, Lit, Meta, MetaNameValue};
11
12fn embedded(
13    ident: &syn::Ident,
14    relative_folder_path: Option<&str>,
15    absolute_folder_path: String,
16    prefix: Option<&str>,
17    includes: &[String],
18    excludes: &[String],
19) -> syn::Result<TokenStream2> {
20    extern crate prest_embed_utils;
21    use prest_embed_utils::*;
22
23    let mut match_values = BTreeMap::new();
24    let mut list_values = Vec::<String>::new();
25
26    let includes: Vec<&str> = includes.iter().map(AsRef::as_ref).collect();
27    let excludes: Vec<&str> = excludes.iter().map(AsRef::as_ref).collect();
28    for FileEntry {
29        rel_path,
30        full_canonical_path,
31    } in get_files(absolute_folder_path.clone(), &includes, &excludes)
32    {
33        match_values.insert(
34            rel_path.clone(),
35            embed_file(
36                relative_folder_path.clone(),
37                ident,
38                &rel_path,
39                &full_canonical_path,
40            )?,
41        );
42
43        list_values.push(if let Some(prefix) = prefix {
44            format!("{}{}", prefix, rel_path)
45        } else {
46            rel_path
47        });
48    }
49
50    let array_len = list_values.len();
51
52    let not_debug_attr = quote! { #[cfg(any(not(debug_assertions), target_arch = "wasm32"))]};
53
54    let handle_prefix = if let Some(prefix) = prefix {
55        quote! {
56          let file_path = file_path.strip_prefix(#prefix)?;
57        }
58    } else {
59        TokenStream2::new()
60    };
61    let match_values = match_values.into_iter().map(|(path, bytes)| {
62        quote! {
63            (#path, #bytes),
64        }
65    });
66    let value_type = quote! { prest::embed::EmbeddedFile };
67    let get_value = quote! {|idx| ENTRIES[idx].1.clone()};
68
69    Ok(quote! {
70        #not_debug_attr
71        impl #ident {
72            /// Get an embedded file and its metadata.
73            pub fn get(file_path: &str) -> Option<prest::embed::EmbeddedFile> {
74              #handle_prefix
75              let key = file_path.replace("\\", "/");
76              const ENTRIES: &'static [(&'static str, #value_type)] = &[
77                  #(#match_values)*];
78              let position = ENTRIES.binary_search_by_key(&key.as_str(), |entry| entry.0);
79              position.ok().map(#get_value)
80
81            }
82
83            fn names() -> std::slice::Iter<'static, &'static str> {
84                const ITEMS: [&str; #array_len] = [#(#list_values),*];
85                ITEMS.iter()
86            }
87
88            /// Iterates over the file paths in the folder.
89            pub fn iter() -> impl Iterator<Item = std::borrow::Cow<'static, str>> {
90                Self::names().map(|x| std::borrow::Cow::from(*x))
91            }
92        }
93
94        #not_debug_attr
95        impl prest::EmbeddedStruct for #ident {
96          fn get(file_path: &str) -> Option<prest::embed::EmbeddedFile> {
97            #ident::get(file_path)
98          }
99          fn iter() -> prest::embed::__Filenames {
100            prest::embed::__Filenames::Embedded(#ident::names())
101          }
102        }
103    })
104}
105
106fn dynamic(
107    ident: &syn::Ident,
108    folder_path: String,
109    prefix: Option<&str>,
110    includes: &[String],
111    excludes: &[String],
112) -> TokenStream2 {
113    let (handle_prefix, map_iter) = if let Some(prefix) = prefix {
114        (
115            quote! { let file_path = file_path.strip_prefix(#prefix)?; },
116            quote! { std::borrow::Cow::Owned(format!("{}{}", #prefix, e.rel_path)) },
117        )
118    } else {
119        (
120            TokenStream2::new(),
121            quote! { std::borrow::Cow::from(e.rel_path) },
122        )
123    };
124
125    let declare_includes = quote! {
126      const INCLUDES: &[&str] = &[#(#includes),*];
127    };
128
129    let declare_excludes = quote! {
130      const EXCLUDES: &[&str] = &[#(#excludes),*];
131    };
132
133    let canonical_folder_path = Path::new(&folder_path)
134        .canonicalize()
135        .expect("folder path must resolve to an absolute path");
136    let canonical_folder_path = canonical_folder_path
137        .to_str()
138        .expect("absolute folder path must be valid unicode");
139
140    quote! {
141        #[cfg(all(debug_assertions, not(target_arch = "wasm32")))]
142        impl #ident {
143            /// Get an embedded file and its metadata.
144            pub fn get(file_path: &str) -> Option<prest::embed::EmbeddedFile> {
145                #handle_prefix
146
147                #declare_includes
148                #declare_excludes
149
150                let rel_file_path = file_path.replace("\\", "/");
151                let file_path = std::path::Path::new(#folder_path).join(&rel_file_path);
152
153                // Make sure the path requested does not escape the folder path
154                let canonical_file_path = file_path.canonicalize().ok()?;
155                if !canonical_file_path.starts_with(#canonical_folder_path) {
156                    // Tried to request a path that is not in the embedded folder
157                    return None;
158                }
159
160                if prest::embed::is_path_included(&rel_file_path, INCLUDES, EXCLUDES) {
161                  prest::embed::read_file_from_fs(&canonical_file_path).ok()
162                } else {
163                  None
164                }
165            }
166
167            /// Iterates over the file paths in the folder.
168            pub fn iter() -> impl Iterator<Item = std::borrow::Cow<'static, str>> {
169                use std::path::Path;
170
171                #declare_includes
172                #declare_excludes
173
174                prest::embed::get_files(String::from(#folder_path), INCLUDES, EXCLUDES)
175                    .map(|e| #map_iter)
176            }
177        }
178
179        #[cfg(all(debug_assertions, not(target_arch = "wasm32")))]
180        impl prest::EmbeddedStruct for #ident {
181          fn get(file_path: &str) -> Option<prest::embed::EmbeddedFile> {
182            #ident::get(file_path)
183          }
184          fn iter() -> prest::embed::__Filenames {
185            // the return type of iter() is unnamable, so we have to box it
186            prest::embed::__Filenames::Dynamic(Box::new(#ident::iter()))
187          }
188        }
189    }
190}
191
192fn generate_assets(
193    ident: &syn::Ident,
194    relative_folder_path: Option<&str>,
195    absolute_folder_path: String,
196    prefix: Option<String>,
197    includes: Vec<String>,
198    excludes: Vec<String>,
199) -> syn::Result<TokenStream2> {
200    let embedded_impl = embedded(
201        ident,
202        relative_folder_path,
203        absolute_folder_path.clone(),
204        prefix.as_deref(),
205        &includes,
206        &excludes,
207    );
208    let embedded_impl = embedded_impl?;
209    let dynamic_impl = dynamic(
210        ident,
211        absolute_folder_path,
212        prefix.as_deref(),
213        &includes,
214        &excludes,
215    );
216
217    Ok(quote! {
218        #embedded_impl
219        #dynamic_impl
220    })
221}
222
223fn embed_file(
224    _folder_path: Option<&str>,
225    _ident: &syn::Ident,
226    _rel_path: &str,
227    full_canonical_path: &str,
228) -> syn::Result<TokenStream2> {
229    let file = prest_embed_utils::read_file_from_fs(Path::new(full_canonical_path))
230        .expect("File should be readable");
231    let hash = file.metadata.sha256_hash();
232    let last_modified = match file.metadata.last_modified() {
233        Some(last_modified) => quote! { Some(#last_modified) },
234        None => quote! { None },
235    };
236    let mimetype_tokens = {
237        let mt = file.metadata.mimetype();
238        quote! { , #mt }
239    };
240
241    let embedding_code = quote! {
242      const BYTES: &'static [u8] = include_bytes!(#full_canonical_path);
243    };
244
245    let closure_args = quote! {};
246    Ok(quote! {
247         #closure_args {
248          #embedding_code
249
250          prest::embed::EmbeddedFile {
251              data: std::borrow::Cow::Borrowed(&BYTES),
252              metadata: prest::embed::EmbeddedFileMetadata::__rust_embed_new([#(#hash),*], #last_modified #mimetype_tokens)
253          }
254        }
255    })
256}
257
258/// Find all pairs of the `name = "value"` attribute from the derive input
259fn find_attribute_values(ast: &syn::DeriveInput, attr_name: &str) -> Vec<String> {
260    ast.attrs
261        .iter()
262        .filter(|value| value.path().is_ident(attr_name))
263        .filter_map(|attr| match &attr.meta {
264            Meta::NameValue(MetaNameValue {
265                value:
266                    Expr::Lit(ExprLit {
267                        lit: Lit::Str(val), ..
268                    }),
269                ..
270            }) => Some(val.value()),
271            _ => None,
272        })
273        .collect()
274}
275
276fn impl_rust_embed(ast: &syn::DeriveInput) -> syn::Result<TokenStream2> {
277    match ast.data {
278        Data::Struct(ref data) => match data.fields {
279            Fields::Unit => {}
280            _ => {
281                return Err(syn::Error::new_spanned(
282                    ast,
283                    "Embed can only be derived for unit structs",
284                ))
285            }
286        },
287        _ => {
288            return Err(syn::Error::new_spanned(
289                ast,
290                "Embed can only be derived for unit structs",
291            ))
292        }
293    };
294
295    let mut folder_paths = find_attribute_values(ast, "folder");
296    if folder_paths.len() != 1 {
297        return Err(syn::Error::new_spanned(
298            ast,
299            "Embed must contain one attribute like this #[folder = \"examples/public/\"]",
300        ));
301    }
302    let folder_path = folder_paths.remove(0);
303
304    let prefix = find_attribute_values(ast, "prefix").into_iter().next();
305    let includes = find_attribute_values(ast, "include");
306    let excludes = find_attribute_values(ast, "exclude");
307
308    let folder_path = shellexpand::full(&folder_path)
309        .map_err(|v| syn::Error::new_spanned(ast, v.to_string()))?
310        .to_string();
311
312    // Base relative paths on the Cargo.toml location
313    let (relative_path, absolute_folder_path) = if Path::new(&folder_path).is_relative() {
314        let absolute_path = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap())
315            .join(&folder_path)
316            .to_str()
317            .unwrap()
318            .to_owned();
319        (Some(folder_path.clone()), absolute_path)
320    } else {
321        (None, folder_path)
322    };
323
324    if !Path::new(&absolute_folder_path).exists() {
325        let message = format!(
326            "The embedded folder '{}' does not exist. cwd: '{}'",
327            absolute_folder_path,
328            std::env::current_dir().unwrap().to_str().unwrap()
329        );
330
331        return Err(syn::Error::new_spanned(ast, message));
332    };
333
334    generate_assets(
335        &ast.ident,
336        relative_path.as_deref(),
337        absolute_folder_path,
338        prefix,
339        includes,
340        excludes,
341    )
342}
343
344/// Derive macro that embeds files and provides access to them through the given struct
345#[proc_macro_derive(Embed, attributes(folder, prefix, include, exclude))]
346pub fn derive_input_object(input: TokenStream) -> TokenStream {
347    let ast = parse_macro_input!(input as DeriveInput);
348    match impl_rust_embed(&ast) {
349        Ok(ok) => ok.into(),
350        Err(e) => e.to_compile_error().into(),
351    }
352}