include_directory_macros/
lib.rs

//! Implementation details of the `include_directory` crate.
//!
//! You probably don't want to use this crate directly.
4#![cfg_attr(feature = "nightly", feature(track_path, proc_macro_tracked_env))]
5
6use proc_macro::{TokenStream, TokenTree};
7use proc_macro2::Literal;
8use quote::quote;
9use std::{
10    error::Error,
11    fmt::{self, Display, Formatter},
12    path::{Path, PathBuf},
13    time::SystemTime,
14};
15
16/// Embed the contents of a directory in your crate.
17#[proc_macro]
18pub fn include_directory(input: TokenStream) -> TokenStream {
19    let tokens: Vec<_> = input.into_iter().collect();
20
21    let path = match tokens.as_slice() {
22        [TokenTree::Literal(lit)] => unwrap_string_literal(lit),
23        _ => panic!("This macro only accepts a single, non-empty string argument"),
24    };
25
26    let path = resolve_path(&path, get_env).unwrap();
27
28    expand_dir(&path, &path).into()
29}
30
/// Extract the string value of a literal token.
///
/// The token's source representation is decoded. Regular literals have their
/// escape sequences processed (`"a\\b"` yields `a\b`), and raw literals
/// (`r"…"`, `r#"…"#`, …) are accepted verbatim.
///
/// # Panics
///
/// Panics if `lit` is not a well-formed string literal. This covers byte
/// strings, chars, numbers, suffixed literals and invalid escapes.
fn unwrap_string_literal(lit: &proc_macro::Literal) -> String {
    parse_string_literal(&lit.to_string())
        .expect("This macro only accepts a single, non-empty string argument")
}

/// Decode the source text of a (raw or regular) string literal.
///
/// Returns `None` when `repr` is not a string literal of either form.
fn parse_string_literal(repr: &str) -> Option<String> {
    if let Some(raw) = repr.strip_prefix('r') {
        // Raw strings contain no escapes; only the delimiters are stripped.
        let hashes = &raw[..raw.len() - raw.trim_start_matches('#').len()];
        let body = raw[hashes.len()..]
            .strip_prefix('"')?
            .strip_suffix(hashes)?
            .strip_suffix('"')?;
        return Some(body.to_string());
    }

    let body = repr.strip_prefix('"')?.strip_suffix('"')?;
    unescape(body)
}

/// Process the escape sequences allowed in a Rust string literal body.
///
/// Handles quote escapes, `\n \r \t \\ \0`, `\xNN` (ASCII only), `\u{…}` and
/// line continuations. Returns `None` on any malformed escape.
fn unescape(body: &str) -> Option<String> {
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }

        let decoded = match chars.next()? {
            'n' => '\n',
            'r' => '\r',
            't' => '\t',
            '0' => '\0',
            '\\' => '\\',
            '\'' => '\'',
            '"' => '"',
            'x' => {
                let hex: String = chars.by_ref().take(2).collect();
                if hex.len() != 2 {
                    return None;
                }
                let byte = u8::from_str_radix(&hex, 16).ok()?;
                // `\x` escapes in string literals are limited to ASCII.
                if byte > 0x7F {
                    return None;
                }
                char::from(byte)
            }
            'u' => {
                if chars.next()? != '{' {
                    return None;
                }
                let mut hex = String::new();
                loop {
                    match chars.next()? {
                        '}' => break,
                        '_' => {}
                        digit => hex.push(digit),
                    }
                }
                char::from_u32(u32::from_str_radix(&hex, 16).ok()?)?
            }
            '\n' | '\r' => {
                // Line continuation: drop the newline and the following
                // line's leading whitespace.
                while chars
                    .next_if(|&c| matches!(c, ' ' | '\t' | '\n' | '\r'))
                    .is_some()
                {}
                continue;
            }
            _ => return None,
        };
        out.push(decoded);
    }

    Some(out)
}
42
43fn expand_dir(root: &Path, path: &Path) -> proc_macro2::TokenStream {
44    let children = read_dir(path).unwrap_or_else(|e| {
45        panic!(
46            "Unable to read the entries in \"{}\": {}",
47            path.display(),
48            e
49        )
50    });
51
52    let mut child_tokens = Vec::new();
53
54    for child in children {
55        if child.is_dir() {
56            let tokens = expand_dir(root, &child);
57            child_tokens.push(quote! {
58                include_directory::DirEntry::Dir(#tokens)
59            });
60        } else if child.is_file() {
61            let tokens = expand_file(root, &child);
62            child_tokens.push(quote! {
63                include_directory::DirEntry::File(#tokens)
64            });
65        } else {
66            panic!("\"{}\" is neither a file nor a directory", child.display());
67        }
68    }
69
70    let path = normalize_path(root, path);
71
72    quote! {
73        include_directory::Dir::new(#path, &[ #(#child_tokens),* ])
74    }
75}
76
77fn expand_file(root: &Path, path: &Path) -> proc_macro2::TokenStream {
78    let abs = path
79        .canonicalize()
80        .unwrap_or_else(|e| panic!("failed to resolve \"{}\": {}", path.display(), e));
81    let literal = match abs.to_str() {
82        Some(abs) => quote!(include_bytes!(#abs)),
83        None => {
84            let contents = read_file(path);
85            let literal = Literal::byte_string(&contents);
86            quote!(#literal)
87        }
88    };
89
90    let normalized_path = normalize_path(root, path);
91
92    let mimetype = new_mime_guess::from_path(&normalized_path)
93        .first_or_text_plain()
94        .to_string();
95    let mimetype = mimetype.as_str();
96
97    let tokens = quote! {
98        include_directory::File::new(#normalized_path, #literal, #mimetype)
99    };
100
101    match metadata(path) {
102        Some(metadata) => quote!(#tokens.with_metadata(#metadata)),
103        None => tokens,
104    }
105}
106
107fn metadata(path: &Path) -> Option<proc_macro2::TokenStream> {
108    fn to_unix(t: SystemTime) -> u64 {
109        t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
110    }
111
112    if !cfg!(feature = "metadata") {
113        return None;
114    }
115
116    let meta = path.metadata().ok()?;
117    let accessed = meta.accessed().map(to_unix).ok()?;
118    let created = meta.created().map(to_unix).ok()?;
119    let modified = meta.modified().map(to_unix).ok()?;
120
121    Some(quote! {
122        include_directory::Metadata::new(
123            std::time::Duration::from_secs(#accessed),
124            std::time::Duration::from_secs(#created),
125            std::time::Duration::from_secs(#modified),
126        )
127    })
128}
129
/// Express `path` relative to `root`, always using `/` as the separator.
///
/// The output is the same whether the host is Windows or Unix. Non-UTF-8
/// components are converted lossily.
///
/// # Panics
///
/// Panics if `path` does not start with `root`, which would be a bug in the
/// caller.
fn normalize_path(root: &Path, path: &Path) -> String {
    let relative = path
        .strip_prefix(root)
        .expect("Should only ever be called using paths inside the root path");

    relative
        .to_string_lossy()
        .chars()
        .map(|c| if c == '\\' { '/' } else { c })
        .collect()
}
140
141fn read_dir(dir: &Path) -> Result<Vec<PathBuf>, Box<dyn Error>> {
142    if !dir.is_dir() {
143        panic!("\"{}\" is not a directory", dir.display());
144    }
145
146    track_path(dir);
147
148    let mut paths = Vec::new();
149
150    for entry in dir.read_dir()? {
151        let entry = entry?;
152        paths.push(entry.path());
153    }
154
155    paths.sort();
156
157    Ok(paths)
158}
159
160fn read_file(path: &Path) -> Vec<u8> {
161    track_path(path);
162    std::fs::read(path).unwrap_or_else(|e| panic!("Unable to read \"{}\": {}", path.display(), e))
163}
164
165fn resolve_path(
166    raw: &str,
167    get_env: impl Fn(&str) -> Option<String>,
168) -> Result<PathBuf, Box<dyn Error>> {
169    let mut unprocessed = raw;
170    let mut resolved = String::new();
171
172    while let Some(dollar_sign) = unprocessed.find('$') {
173        let (head, tail) = unprocessed.split_at(dollar_sign);
174        resolved.push_str(head);
175
176        match parse_identifier(&tail[1..]) {
177            Some((variable, rest)) => {
178                let value = get_env(variable).ok_or_else(|| MissingVariable {
179                    variable: variable.to_string(),
180                })?;
181                resolved.push_str(&value);
182                unprocessed = rest;
183            }
184            None => {
185                return Err(UnableToParseVariable { rest: tail.into() }.into());
186            }
187        }
188    }
189    resolved.push_str(unprocessed);
190
191    Ok(PathBuf::from(resolved))
192}
193
/// Error produced by `resolve_path` when `get_env` returns `None` for a
/// referenced `$NAME`.
#[derive(Debug, PartialEq)]
struct MissingVariable {
    /// The variable name, without the leading `$`.
    variable: String,
}

impl Display for MissingVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "Unable to resolve ${name}", name = self.variable)
    }
}

impl Error for MissingVariable {}
206
/// Error produced by `resolve_path` when a `$` is not followed by a valid
/// identifier.
#[derive(Debug, PartialEq)]
struct UnableToParseVariable {
    /// The unparsed input, starting at the offending `$`.
    rest: String,
}

impl Display for UnableToParseVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "Unable to parse a variable from \"{rest}\"", rest = self.rest)
    }
}

impl Error for UnableToParseVariable {}
219
/// Split `text` into a leading identifier and the remainder.
///
/// An identifier is a non-empty run of ASCII letters, ASCII digits and
/// underscores that does not begin with a digit. Returns `None` when `text`
/// does not start with one.
fn parse_identifier(text: &str) -> Option<(&str, &str)> {
    /// Whether `c` may appear at byte offset `position` of an identifier.
    fn is_identifier_char(position: usize, c: char) -> bool {
        c == '_' || c.is_ascii_alphabetic() || (position > 0 && c.is_ascii_digit())
    }

    let end = text
        .char_indices()
        .find(|&(position, c)| !is_identifier_char(position, c))
        .map_or(text.len(), |(position, _)| position);

    if end == 0 {
        None
    } else {
        Some(text.split_at(end))
    }
}
240
/// Split `s` just before the first character rejected by `predicate`.
///
/// `predicate` is called on each character in order, up to and including the
/// first one it rejects. If every character is accepted, the second half is
/// empty.
fn take_while(s: &str, mut predicate: impl FnMut(char) -> bool) -> (&str, &str) {
    let split = s
        .char_indices()
        .find(|&(_, c)| !predicate(c))
        .map_or(s.len(), |(offset, _)| offset);

    s.split_at(split)
}
254
/// Read the environment variable `variable`, returning `None` if it is unset
/// or not valid Unicode.
///
/// This variant goes through the unstable `proc_macro::tracked_env` API so
/// the compiler is told the expansion depends on the variable.
#[cfg(feature = "nightly")]
fn get_env(variable: &str) -> Option<String> {
    match proc_macro::tracked_env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}

/// Read the environment variable `variable`, returning `None` if it is unset
/// or not valid Unicode.
#[cfg(not(feature = "nightly"))]
fn get_env(variable: &str) -> Option<String> {
    match std::env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
264
/// Register `path` with the compiler's dependency tracking through the
/// unstable `proc_macro::tracked_path` API. Non-UTF-8 paths are converted
/// lossily.
#[cfg(feature = "nightly")]
fn track_path(path: &Path) {
    proc_macro::tracked_path::path(path.to_string_lossy());
}

/// Without the `nightly` feature paths cannot be tracked, so this does
/// nothing.
#[cfg(not(feature = "nightly"))]
fn track_path(_path: &Path) {}
269
#[cfg(test)]
mod tests {
    // Unit tests for the path-resolution and path-normalization helpers.
    use super::*;

    #[test]
    fn resolve_path_with_no_environment_variables() {
        let path = "./file.txt";

        let resolved = resolve_path(path, |_| unreachable!()).unwrap();

        assert_eq!(resolved.to_str().unwrap(), path);
    }

    #[test]
    fn simple_environment_variable() {
        let path = "./$VAR";

        let resolved = resolve_path(path, |name| {
            assert_eq!(name, "VAR");
            Some("file.txt".to_string())
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./file.txt");
    }

    #[test]
    fn dont_resolve_recursively() {
        let path = "./$TOP_LEVEL.txt";

        let resolved = resolve_path(path, |name| match name {
            "TOP_LEVEL" => Some("$NESTED".to_string()),
            // Names reach `get_env` without their leading `$`, so this is the
            // lookup a recursive expansion would actually perform.
            "NESTED" => unreachable!("Shouldn't resolve recursively"),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./$NESTED.txt");
    }

    #[test]
    fn multiple_variables_with_separators() {
        let resolved = resolve_path("$A-$B/c", |name| match name {
            "A" => Some("x".to_string()),
            "B" => Some("y".to_string()),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "x-y/c");
    }

    #[test]
    fn parse_valid_identifiers() {
        let inputs = vec![
            ("a", "a"),
            ("a_", "a_"),
            ("_asf", "_asf"),
            ("a1", "a1"),
            ("a1_#sd", "a1_"),
        ];

        for (src, expected) in inputs {
            let (got, rest) = parse_identifier(src).unwrap();
            assert_eq!(got.len() + rest.len(), src.len());
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn identifiers_cannot_start_with_a_digit() {
        assert_eq!(parse_identifier("9lives"), None);
        assert_eq!(parse_identifier(""), None);
    }

    #[test]
    fn take_while_handles_multibyte_characters() {
        assert_eq!(take_while("héllo1", |c| c.is_alphabetic()), ("héllo", "1"));
    }

    #[test]
    fn normalized_paths_are_relative_and_use_forward_slashes() {
        let root = std::path::Path::new("root");
        let nested = root.join("a").join("b.txt");

        assert_eq!(normalize_path(root, &nested), "a/b.txt");
        assert_eq!(normalize_path(root, root), "");
    }

    #[test]
    fn unknown_environment_variable() {
        let path = "$UNKNOWN";

        let err = resolve_path(path, |_| None).unwrap_err();

        let missing_variable = err.downcast::<MissingVariable>().unwrap();
        assert_eq!(
            *missing_variable,
            MissingVariable {
                variable: String::from("UNKNOWN"),
            }
        );
    }

    #[test]
    fn invalid_variables() {
        let inputs = &["$1", "$"];

        for input in inputs {
            let err = resolve_path(input, |_| unreachable!()).unwrap_err();

            let err = err.downcast::<UnableToParseVariable>().unwrap();
            assert_eq!(
                *err,
                UnableToParseVariable {
                    rest: input.to_string(),
                }
            );
        }
    }
}
358}