embed_rust/
lib.rs

1#![feature(proc_macro_span)]
2
3use std::{
4    collections::{hash_map::Entry, HashMap, HashSet},
5    env::{self, temp_dir},
6    fs::{self, File},
7    io::{Read, Write},
8    path::{Path, PathBuf},
9    process::Command,
10};
11
12use fs2::FileExt;
13use path_slash::PathBufExt;
14use proc_macro2::{Group, Span};
15use quote::quote;
16use serde::Deserialize;
17use syn::{
18    braced, bracketed,
19    punctuated::Punctuated,
20    token::{Brace, Bracket, Comma},
21    Error, Ident, LitStr, Token,
22};
23
/// Table header in `Cargo.toml` below which the user's `dependencies` text is inserted.
const CARGO_DEPENDENCIES_SECTION: &str = "[dependencies]";
/// `reason` value of the `cargo --message-format json` messages that describe built artifacts.
const COMPILER_ARTIFACT_MESSAGE_TYPE: &str = "compiler-artifact";

/// The subset of one `cargo --message-format json` output line needed to locate the binary.
#[derive(Deserialize)]
struct CompilerArtifactMessage {
    /// Message kind; only `compiler-artifact` messages are accepted by the caller.
    reason: String,
    /// Files produced for the artifact; the first entry is the one that gets embedded.
    filenames: Vec<String>,
}
32
/// Parses a brace-delimited, comma-separated list of `key: value` options from `$input`.
///
/// * `$args` is bound to the contents of the braces and `$key` to the current option name, so
///   the parser blocks can read the value that follows the colon from `$args`.
/// * `$argument_parser` runs for identifier keys (`$key: Ident`); repeating an identifier key
///   is rejected with "Duplicated parameter".
/// * `$string_argument_parser` runs for string-literal keys (`$key: LitStr`) and is only
///   reachable through the two-block form (`$allow_string_arguments == true`). String keys are
///   not checked for duplicates here.
/// * A trailing comma after the last option is accepted.
///
/// Must be expanded inside a function returning `syn::Result<_>` because it uses `?` and
/// `return Err(..)`.
macro_rules! parse_options {
    // Identifier keys only: string keys are disabled and their parser is an empty block.
    ($input: expr, $key: ident, $args: ident, $argument_parser: block) => {
        parse_options!($input, $key, $args, $argument_parser, {}, false);
    };
    // Identifier keys and string-literal keys.
    ($input: expr, $key: ident, $args: ident, $argument_parser: block, $string_argument_parser: block) => {
        parse_options!($input, $key, $args, $argument_parser, $string_argument_parser, true);
    };
    // Shared implementation.
    ($input: expr, $key: ident, $args: ident, $argument_parser: block, $string_argument_parser: block, $allow_string_arguments: expr) => {
        let mut seen_arguments = HashSet::new();
        let $args;
        let _: Brace = braced!($args in $input);
        while !$args.is_empty() {
            let lookahead = $args.lookahead1();
            if lookahead.peek(Ident) {
                let $key: Ident = $args.parse()?;
                let _colon: syn::token::Colon = $args.parse()?;
                if !seen_arguments.insert($key.to_string()) {
                    return Err(Error::new(
                        $key.span(),
                        format!("Duplicated parameter `{}`", $key),
                    ));
                }
                $argument_parser;
            // When string keys are disabled `peek(LitStr)` is never called, so the lookahead
            // error below only lists identifiers as expected tokens.
            } else if $allow_string_arguments && lookahead.peek(LitStr) {
                // The identifier-only form passes `{}` as string parser, leaving `$key` unused.
                #[allow(unused_variables)]
                let $key: LitStr = $args.parse()?;
                let _colon: syn::token::Colon = $args.parse()?;
                $string_argument_parser;
            } else {
                return Err(lookahead.error());
            }
            // Options are separated by commas; anything else before the closing brace is an error.
            let lookahead = $args.lookahead1();
            if lookahead.peek(Comma) {
                let _: Comma = $args.parse()?;
            } else if !$args.is_empty() {
                return Err(lookahead.error());
            }
        }
    };
}
73
/// A git repository that is shallow-cloned to obtain the embedded project.
#[derive(Debug)]
struct GitSource {
    /// Repository URL passed to `git clone`.
    url: String,
    /// Optional branch, passed to `git clone --branch`.
    branch: Option<String>,
    /// Optional sub-directory of the clone that holds the Cargo project to build.
    path: Option<PathBuf>,
}

/// Where the embedded project comes from. Sources are tried in declaration order until one of
/// them can be prepared.
#[derive(Debug)]
enum Source {
    /// Tokens written to `src/main.rs` of a freshly `cargo init`-ed crate.
    Inline(Group),
    /// A repository cloned into the generated project directory.
    Git(GitSource),
    /// An existing project on disk; relative paths are resolved against the invoking source
    /// file's directory.
    Path(PathBuf),
}
87
/// One whitespace-separated element of a `post_build` command line.
enum CommandItem {
    /// Literal text taken verbatim from a string literal.
    Raw(String),
    /// The `input_path` placeholder: the path of the built artifact.
    InputPath,
    /// The `output_path` placeholder: if any command references it, this path (instead of the
    /// artifact) is what ends up embedded.
    OutputPath,
}

impl CommandItem {
    /// Resolves this item to concrete text, substituting `input` / `output` for the placeholders.
    fn to_string<'a>(&'a self, input: &'a String, output: &'a String) -> &'a String {
        let resolved = match self {
            Self::InputPath => input,
            Self::OutputPath => output,
            Self::Raw(literal) => literal,
        };
        resolved
    }
}
103
/// Arguments for the [embed_rust] macro.
// NOTE(review): the type name looks copied from another macro; consider renaming it
// (e.g. `EmbedRustArgs`) together with its uses.
struct MatchTelecommandArgs {
    /// Candidate project sources (`source`, `"src/main.rs"`, `git`, `path`), tried in order.
    sources: Vec<Source>,
    /// Extra files (path relative to the project root → content) written into the project.
    extra_files: HashMap<PathBuf, String>,
    /// Raw text inserted below `[dependencies]` in `Cargo.toml`; empty when not given.
    dependencies: String,
    /// `post_build` commands; each command is a list of items joined with spaces and run in
    /// a shell.
    post_build_commands: Vec<Vec<CommandItem>>,
    /// Optional path (relative to the invoking source file) the built binary is copied to;
    /// it is also used as a fallback when no source can be prepared.
    binary_cache_path: Option<PathBuf>,
}
112
/// Parses the brace-delimited body of `embed_rust!({ ... })` into [`MatchTelecommandArgs`].
///
/// Identifier keys: `source` (inline `main.rs` tokens), `dependencies` (string), `git` (URL
/// string, or `{ url, branch, path }`), `path` (string), `post_build` (list of command lists)
/// and `binary_cache_path` (string). String-literal keys name extra project files; the key
/// `"src/main.rs"` instead adds another inline source. At least one source must be given.
impl syn::parse::Parse for MatchTelecommandArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut sources = Vec::new();
        let mut extra_files = HashMap::new();
        let mut dependencies = String::new();
        let mut post_build_commands = Vec::new();
        let mut binary_cache_path = None;
        parse_options!(
            input,
            key,
            args,
            // Identifier keys.
            {
                match key.to_string().as_str() {
                    // Inline project: the brace group becomes `src/main.rs`.
                    "source" => sources.push(Source::Inline(args.parse()?)),
                    "dependencies" => {
                        let dependencies_string: LitStr = args.parse()?;
                        dependencies = dependencies_string.value();
                    }
                    "git" => {
                        // Short form `git: "<url>"` or long form `git: { url: .., branch: .., path: .. }`.
                        let lookahead = args.lookahead1();
                        sources.push(Source::Git(if lookahead.peek(LitStr) {
                            let url: LitStr = args.parse()?;
                            GitSource {
                                url: url.value(),
                                branch: None,
                                path: None,
                            }
                        } else if lookahead.peek(Brace) {
                            let mut url = None;
                            let mut branch = None;
                            let mut path = None;
                            parse_options!(args, key, git_args, {
                                match key.to_string().as_str() {
                                    "url" => {
                                        let url_literal: LitStr = git_args.parse()?;
                                        url = Some(url_literal.value());
                                    }
                                    "branch" => {
                                        let branch_literal: LitStr = git_args.parse()?;
                                        branch = Some(branch_literal.value());
                                    }
                                    "path" => {
                                        let path_literal: LitStr = git_args.parse()?;
                                        path = Some(PathBuf::from_slash(path_literal.value()));
                                    }
                                    _ => {
                                        return Err(Error::new(
                                            key.span(),
                                            format!("Invalid parameter `{key}`"),
                                        ))
                                    }
                                }
                            });
                            // The inner `key` binding is scoped to the nested loop, so `key` is
                            // the outer `git` identifier again and the error points at it.
                            let Some(url) = url else {
                                return Err(Error::new(
                                    key.span(),
                                    format!("missing `url` key for `{key}` argument"),
                                ));
                            };
                            GitSource { url, branch, path }
                        } else {
                            return Err(lookahead.error());
                        }))
                    }
                    "path" => {
                        let path_literal: LitStr = args.parse()?;
                        sources.push(Source::Path(PathBuf::from_slash(path_literal.value())));
                    }
                    "post_build" => {
                        // NOTE(review): `parse_options!` already rejects a repeated `post_build`
                        // key, so this check looks redundant (it also never fires after
                        // `post_build: []`).
                        if !post_build_commands.is_empty() {
                            return Err(Error::new(
                                key.span(),
                                format!("Can only have one `{key}`"),
                            ));
                        }
                        // `post_build: [[item, ...], ...]`: each inner list is one non-empty
                        // command; an item is a string literal or `input_path` / `output_path`.
                        let commands;
                        let _: Bracket = bracketed!(commands in args);
                        post_build_commands = commands
                            .parse_terminated(
                                |command| {
                                    let command_items;
                                    let _: Bracket = bracketed!(command_items in command);
                                    Ok(Punctuated::<_, Token![,]>::parse_separated_nonempty_with(
                                        &command_items,
                                        |command_item| {
                                            let lookahead = command_item.lookahead1();
                                            if lookahead.peek(LitStr) {
                                                let item: LitStr = command_item.parse()?;
                                                Ok(CommandItem::Raw(item.value()))
                                            } else if lookahead.peek(Ident) {
                                                let item: Ident = command_item.parse()?;
                                                match item.to_string().as_str() {
                                                    "input_path" => Ok(CommandItem::InputPath),
                                                    "output_path" => Ok(CommandItem::OutputPath),
                                                    _ => Err(Error::new(
                                                        item.span(),
                                                        format!(
                                                            "Invalid command expansion variable `{item}`, only `input_path` and `output_path` are valid",
                                                            item = item.to_string().as_str(),
                                                        ),
                                                    )),
                                                }
                                            } else {
                                                Err(lookahead.error())
                                            }
                                        },
                                    )?
                                    .into_iter()
                                    .collect())
                                },
                                Token![,],
                            )?
                            .into_iter()
                            .collect();
                    }
                    "binary_cache_path" => {
                        let path_literal: LitStr = args.parse()?;
                        binary_cache_path = Some(PathBuf::from_slash(path_literal.value()));
                    }
                    _ => return Err(Error::new(key.span(), format!("Invalid parameter `{key}`"))),
                }
            },
            // String-literal keys: extra project files (or `"src/main.rs"`).
            {
                let key_value = key.value();
                let path = PathBuf::from_slash(key_value.clone());
                let extra_file_slot = match extra_files.entry(path) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => {
                        return Err(Error::new(
                            key.span(),
                            format!("Duplicated file `{key_value}`"),
                        ))
                    }
                };
                match key_value.as_str() {
                    // NOTE(review): the vacant slot is dropped here without being filled, so a
                    // repeated `"src/main.rs"` key is not reported as a duplicate; confirm this
                    // is intended.
                    "src/main.rs" => sources.push(Source::Inline(args.parse()?)),
                    _ => {
                        // `.rs` files must be token groups; other files accept a string literal
                        // (used verbatim) or a brace group (stringified tokens).
                        let content = if key_value.ends_with(".rs") {
                            let source: Group = args.parse()?;
                            let source = source.stream();
                            quote!(#source).to_string()
                        } else {
                            let lookahead = args.lookahead1();
                            if lookahead.peek(LitStr) {
                                let content: LitStr = args.parse()?;
                                content.value()
                            } else if lookahead.peek(Brace) {
                                let source: Group = args.parse()?;
                                let source = source.stream();
                                quote!(#source).to_string()
                            } else {
                                return Err(lookahead.error());
                            }
                        };
                        extra_file_slot.insert(content);
                    }
                }
            }
        );
        if sources.is_empty() {
            return Err(Error::new(Span::call_site(), "Missing `source` attribute"));
        };
        Ok(Self {
            sources,
            extra_files,
            dependencies,
            post_build_commands,
            binary_cache_path,
        })
    }
}
284
285#[proc_macro]
286pub fn embed_rust(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
287    let args = syn::parse_macro_input!(tokens as MatchTelecommandArgs);
288    let path = match compile_rust(args) {
289        Ok(path) => path,
290        Err(error) => return error.into_compile_error().into(),
291    };
292    let Some(path) = path.to_str() else {
293        return Error::new(
294            Span::call_site(),
295            "Generated binary path contains invalid UTF-8",
296        )
297        .into_compile_error()
298        .into();
299    };
300    quote! {
301        include_bytes!(#path)
302    }
303    .into()
304}
305
306fn lock_and_clear_directory(generated_project_dir: &Path) -> syn::Result<File> {
307    let mut lock_file = generated_project_dir.to_path_buf();
308    lock_file.set_extension(".lock");
309    let lock_file = File::options()
310        .read(true)
311        .write(true)
312        .create(true)
313        .open(lock_file)
314        .map_err(|error| {
315            Error::new(
316                Span::call_site(),
317                format!("Failed to open lock-file: {error:?}"),
318            )
319        })?;
320    if let Err(error) = lock_file.lock_exclusive() {
321        return Err(Error::new(
322            Span::call_site(),
323            format!("Failed to lock lock-file: {error:?}"),
324        ));
325    }
326    let _ = fs::remove_dir_all(generated_project_dir); // Ignore errors about non-existent directories.
327    if let Err(error) = fs::create_dir_all(generated_project_dir) {
328        return Err(Error::new(
329            Span::call_site(),
330            format!("Failed to create embedded project directory: {error:?}"),
331        ));
332    }
333    Ok(lock_file)
334}
335
336fn fill_project_template(
337    generated_project_dir: &Path,
338    extra_files: &HashMap<PathBuf, String>,
339    dependencies: &String,
340) -> syn::Result<()> {
341    for (path, content) in extra_files.iter() {
342        write_file(generated_project_dir.join(path), content)?;
343    }
344    if !dependencies.is_empty() {
345        let cargo_file = generated_project_dir.join("Cargo.toml");
346        let mut cargo_content = String::new();
347        File::open(cargo_file.clone())
348            .map_err(|error| {
349                Error::new(
350                    Span::call_site(),
351                    format!("Failed to open {cargo_file:?}: {error:?}"),
352                )
353            })?
354            .read_to_string(&mut cargo_content)
355            .map_err(|error| {
356                Error::new(
357                    Span::call_site(),
358                    format!("Failed to read {cargo_file:?}: {error:?}"),
359                )
360            })?;
361        if !cargo_content.contains(CARGO_DEPENDENCIES_SECTION) {
362            return Err(Error::new(
363                Span::call_site(),
364                "Generated Cargo.toml has no dependencies section",
365            ));
366        }
367        cargo_content = cargo_content.replace(
368            CARGO_DEPENDENCIES_SECTION,
369            (CARGO_DEPENDENCIES_SECTION.to_owned() + "\n" + dependencies).as_str(),
370        );
371        write_file(cargo_file, &cargo_content)?;
372    }
373    Ok(())
374}
375
376fn prepare_source(
377    source: &Source,
378    generated_project_dir: &Path,
379    source_file_dir: &Path,
380    args: &MatchTelecommandArgs,
381) -> syn::Result<(Option<File>, PathBuf)> {
382    Ok(match source {
383        Source::Inline(source) => {
384            let lock_file = lock_and_clear_directory(generated_project_dir)?;
385            run_command(
386                Command::new("cargo")
387                    .current_dir(generated_project_dir)
388                    .arg("init"),
389                "Failed to initialize embedded project crate",
390            )?;
391            let source_dir = generated_project_dir.join("src");
392            let main_source_file = source_dir.join("main.rs");
393            let main_source = source.stream();
394            write_file(main_source_file, &quote!(#main_source).to_string())?;
395            fill_project_template(generated_project_dir, &args.extra_files, &args.dependencies)?;
396            (Some(lock_file), generated_project_dir.to_path_buf())
397        }
398        Source::Git(git_source) => {
399            let lock_file = lock_and_clear_directory(generated_project_dir)?;
400            let mut clone_command = Command::new("git");
401            let mut clone_command = clone_command
402                .arg("clone")
403                .arg("--recurse-submodules")
404                .arg("--depth=1");
405            if let Some(ref branch) = git_source.branch {
406                clone_command = clone_command.arg("--branch").arg(branch);
407            }
408            run_command(
409                clone_command
410                    .arg(&git_source.url)
411                    .arg(generated_project_dir),
412                "Failed to clone embedded project",
413            )?;
414            let generated_project_dir = if let Some(ref path) = git_source.path {
415                generated_project_dir.join(path.clone())
416            } else {
417                generated_project_dir.to_path_buf()
418            };
419            fill_project_template(
420                &generated_project_dir,
421                &args.extra_files,
422                &args.dependencies,
423            )?;
424            (Some(lock_file), generated_project_dir)
425        }
426        Source::Path(path) => {
427            let generated_project_dir = if path.is_absolute() {
428                path.clone()
429            } else {
430                source_file_dir.join(path)
431            };
432            if !generated_project_dir.exists() {
433                return Err(Error::new(
434                    Span::call_site(),
435                    format!("Given path does not exist: {path:?}"),
436                ));
437            }
438            (None, generated_project_dir)
439        }
440    })
441}
442
443fn compile_rust(args: MatchTelecommandArgs) -> syn::Result<PathBuf> {
444    let call_site = Span::call_site().unwrap();
445    let call_site_location = call_site.start();
446    let source_file = call_site.source_file().path();
447    if !source_file.exists() {
448        return Err(Error::new(
449            Span::call_site(),
450            "Unable to get path of source file",
451        ));
452    }
453    let crate_dir = PathBuf::from(
454        env::var("CARGO_MANIFEST_DIR")
455            .expect("'CARGO_MANIFEST_DIR' environment variable is missing"),
456    );
457
458    let source_file_id = sanitize_filename::sanitize(
459        source_file
460            .strip_prefix(&crate_dir)
461            .unwrap_or(&source_file)
462            .to_string_lossy(),
463    )
464    .replace('.', "_");
465
466    let line = call_site_location.line();
467    let column = call_site_location.column();
468    let id = format!("{source_file_id}_{line}_{column}");
469    let mut generated_project_dir = env::var("OUT_DIR")
470        .map_or_else(|_| temp_dir(), PathBuf::from)
471        .join(id);
472    let source_file_dir = source_file
473        .parent()
474        .expect("Should be able to resolve the parent directory of the source file");
475
476    let mut i = 0;
477    let lock_file = loop {
478        match prepare_source(
479            &args.sources[i],
480            &generated_project_dir,
481            source_file_dir,
482            &args,
483        ) {
484            Ok((lock_file, new_generated_project_dir)) => {
485                generated_project_dir = new_generated_project_dir;
486                break lock_file;
487            }
488            Err(error) => {
489                if i + 1 == args.sources.len() {
490                    if let Some(compiled_binary_path) = args.binary_cache_path {
491                        return Ok(compiled_binary_path);
492                    }
493                    return Err(error);
494                }
495            }
496        }
497        i += 1;
498    };
499
500    let mut build_command = Command::new("cargo");
501    let build_command = build_command
502        .current_dir(&generated_project_dir)
503        .arg("build")
504        .arg("--release");
505    run_command(
506        build_command,
507        format!("Failed to build embedded project crate at {generated_project_dir:?}").as_str(),
508    )?;
509
510    // Find binary.
511    let output = run_command(
512        build_command.arg("--message-format").arg("json"),
513        "Failed to build embedded project crate",
514    )?;
515    let Ok(output) = core::str::from_utf8(&output) else {
516        return Err(Error::new(
517            Span::call_site(),
518            "Unable to parse cargo output: Invalid UTF-8",
519        ));
520    };
521    let mut artifact_message = None;
522    for line in output.lines() {
523        if line.contains(COMPILER_ARTIFACT_MESSAGE_TYPE) {
524            artifact_message = Some(line);
525        }
526    }
527    let Some(artifact_message) = artifact_message else {
528        return Err(Error::new(
529            Span::call_site(),
530            "Did not found an artifact message in cargo build output",
531        ));
532    };
533    let artifact_message: CompilerArtifactMessage = serde_json::from_str(artifact_message)
534        .map_err(|error| {
535            Error::new(
536                Span::call_site(),
537                format!("Failed to parse artifact message from cargo: {error:?}"),
538            )
539        })?;
540    if artifact_message.reason != COMPILER_ARTIFACT_MESSAGE_TYPE {
541        return Err(Error::new(
542            Span::call_site(),
543            "Invalid cargo artifact message: Wrong reason",
544        ));
545    }
546    let Some(mut artifact_path) = artifact_message.filenames.first() else {
547        return Err(Error::new(
548            Span::call_site(),
549            "Invalid cargo artifact message: No artifact path",
550        ));
551    };
552    let output_artifact_path = &(artifact_path.to_owned() + ".tmp");
553    let mut used_output_path = false;
554    for command_items in args.post_build_commands {
555        let (mut shell, first_arg) = if cfg!(target_os = "windows") {
556            (Command::new("powershell"), "/C")
557        } else {
558            (Command::new("sh"), "-c")
559        };
560        let command = shell.arg(first_arg).arg(
561            command_items
562                .iter()
563                .map(|item| item.to_string(artifact_path, output_artifact_path))
564                .fold(String::new(), |left, right| left + " " + right),
565        );
566        run_command(command, "Failed to run post_build command")?;
567        used_output_path |= command_items
568            .iter()
569            .any(|item| matches!(item, CommandItem::OutputPath));
570    }
571    if used_output_path {
572        artifact_path = output_artifact_path;
573    }
574    let artifact_path = PathBuf::from(artifact_path);
575    let artifact_path = if let Some(binary_cache_path) = args.binary_cache_path {
576        let absolute_binary_cache_path = source_file_dir.join(&binary_cache_path);
577        if let Some(parent) = absolute_binary_cache_path.parent() {
578            fs::create_dir_all(parent).map_err(|error| {
579                Error::new(
580                    Span::call_site(),
581                    format!("Failed to create directories for binary_cache_path at {parent:?}: {error:?}"),
582                )
583            })?;
584        }
585        fs::copy(artifact_path, &absolute_binary_cache_path).map_err(|error| {
586            Error::new(
587                Span::call_site(),
588                format!(
589                    "Failed to copy generated binary to binary_cache_path at {absolute_binary_cache_path:?}: {error:?}"
590                ),
591            )
592        })?;
593        binary_cache_path
594    } else {
595        artifact_path
596    };
597    drop(lock_file);
598
599    Ok(artifact_path)
600}
601
602fn run_command(command: &mut Command, error_message: &str) -> syn::Result<Vec<u8>> {
603    match command.output() {
604        Ok(output) => {
605            if !output.status.success() {
606                Err(Error::new(
607                    Span::call_site(),
608                    format!(
609                        "{error_message}: `{command:?}` failed with exit code {exit_code:?}\n# Stdout:\n{stdout}\n# Stderr:\n{stderr}",
610                        exit_code = output.status.code(),
611                        stdout = core::str::from_utf8(output.stdout.as_slice()).unwrap_or("<Invalid UTF-8>"),
612                        stderr = core::str::from_utf8(output.stderr.as_slice()).unwrap_or("<Invalid UTF-8>")
613                    ),
614                ))
615            } else {
616                Ok(output.stdout)
617            }
618        }
619        Err(error) => Err(Error::new(
620            Span::call_site(),
621            format!("{error_message}: `{command:?}` failed: {error:?}"),
622        )),
623    }
624}
625
626fn write_file(path: PathBuf, content: &String) -> syn::Result<()> {
627    if let Some(parent_dir) = path.parent() {
628        if let Err(error) = fs::create_dir_all(parent_dir) {
629            return Err(Error::new(
630                Span::call_site(),
631                format!("Failed to create parent directory for {path:?}: {error:?}"),
632            ));
633        }
634    }
635    let mut file = File::create(path.clone()).map_err(|error| {
636        Error::new(
637            Span::call_site(),
638            format!("Failed to open {path:?}: {error:?}"),
639        )
640    })?;
641    file.write_all(content.as_bytes()).map_err(|error| {
642        Error::new(
643            Span::call_site(),
644            format!("Failed to write {path:?} to project: {error:?}"),
645        )
646    })?;
647    Ok(())
648}