1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#![warn(clippy::pedantic)]
#![warn(rust_2018_idioms)]

mod codegen;
mod parser;

use codegen::generate_impl;
use parser::{is_binread_attr, Input};
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

/// Entry point for `#[derive(BinRead)]`.
///
/// Registers `binread` and `br` as helper attributes, parses the annotated
/// item as a [`DeriveInput`] and hands it to [`derive_binread_internal`].
/// If the item cannot be parsed, `parse_macro_input!` returns early with the
/// corresponding compile error tokens.
#[proc_macro_derive(BinRead, attributes(binread, br))]
pub fn derive_binread_trait(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as DeriveInput);
    let expanded = derive_binread_internal(item);
    expanded.into()
}

/// Shared implementation of the derive, expressed in `proc_macro2` terms.
///
/// Parses the binread directives of `input` with [`Input::from_input`] and
/// passes both the raw item and the parse result to [`generate_impl`].
/// The input is taken by value because this function is also passed directly
/// as a callback to the expansion emulator used by the tests below.
fn derive_binread_internal(input: DeriveInput) -> proc_macro2::TokenStream {
    generate_impl(&input, &Input::from_input(&input))
}

/// Attribute-macro form of the derive.
///
/// Unlike the derive macro, an attribute macro re-emits the annotated item
/// itself, so this function outputs the item followed by the generated impl.
/// Before re-emitting, it removes every attribute recognised by
/// `is_binread_attr` from the item, its variants and its fields, and drops
/// the fields that the parsed input reports as temporary.
///
/// The macro's own arguments (first parameter) are ignored.
#[proc_macro_attribute]
pub fn derive_binread(_: TokenStream, input: TokenStream) -> TokenStream {
    let mut item = parse_macro_input!(input as DeriveInput);

    // Parse and generate from the untouched item first: the cleanup below
    // strips the very attributes the parser needs to see.
    let parsed = Input::from_input(&item);
    let generated_impl = generate_impl(&item, &parsed);
    // Only a successfully parsed input can tell us which fields are temporary.
    let parsed = parsed.ok();

    clean_struct_attrs(&mut item.attrs);

    match &mut item.data {
        // A struct is treated as a single "variant" with index 0.
        syn::Data::Struct(data) => clean_field_attrs(&parsed, 0, &mut data.fields),
        syn::Data::Enum(data) => {
            data.variants
                .iter_mut()
                .enumerate()
                .for_each(|(index, variant)| {
                    clean_struct_attrs(&mut variant.attrs);
                    clean_field_attrs(&parsed, index, &mut variant.fields);
                });
        }
        // Union fields only have their attributes stripped; no field is removed.
        syn::Data::Union(data) => {
            data.fields
                .named
                .iter_mut()
                .for_each(|field| clean_struct_attrs(&mut field.attrs));
        }
    }

    let expanded = quote!(
        #item
        #generated_impl
    );
    expanded.into()
}

/// Removes temporary fields and binread attributes from one set of fields.
///
/// * `binread_input` - the successfully parsed input, or `None` if parsing
///   failed. With `None` the fields are left completely untouched.
/// * `variant_index` - index of the enum variant that owns `fields`; the
///   caller passes `0` for a plain struct.
/// * `fields` - the fields to rewrite in place. Unit fields are a no-op.
///
/// Every field for which `binread_input.is_temp_field(variant_index, index)`
/// is true is dropped; the remaining fields keep their relative order and
/// have their attributes filtered through `clean_struct_attrs`.
fn clean_field_attrs(
    binread_input: &Option<Input>,
    variant_index: usize,
    fields: &mut syn::Fields,
) {
    if let Some(binread_input) = binread_input {
        let fields = match fields {
            syn::Fields::Named(fields) => &mut fields.named,
            syn::Fields::Unnamed(fields) => &mut fields.unnamed,
            syn::Fields::Unit => return,
        };

        // Move the fields out instead of cloning each survivor: the old
        // collection is overwritten below anyway, so owning the values lets
        // us filter and clean them without any deep copies.
        let original = std::mem::replace(fields, syn::punctuated::Punctuated::new());

        *fields = original
            .into_iter()
            .enumerate()
            // `index` is the position in the original field list, which is
            // what `is_temp_field` is queried with.
            .filter_map(|(index, mut field)| {
                if binread_input.is_temp_field(variant_index, index) {
                    None
                } else {
                    clean_struct_attrs(&mut field.attrs);
                    Some(field)
                }
            })
            .collect();
    }
}

/// Drops, in place, every attribute that `is_binread_attr` recognises.
///
/// All other attributes are kept in their original order.
fn clean_struct_attrs(attrs: &mut Vec<syn::Attribute>) {
    attrs.retain(|candidate| {
        let belongs_to_binread = is_binread_attr(candidate);
        !belongs_to_binread
    });
}

#[cfg(test)]
mod tests {
    use runtime_macros_derive::emulate_derive_expansion_fallible;
    use std::{env, fs};

    /// Runs the derive implementation over every file in
    /// `../binread/tests/derive` (relative to the current working directory)
    /// through the runtime expansion emulator.
    ///
    /// Sub-directories and other non-file entries are skipped. The test fails
    /// if the emulator reports an error for any file, and lists the offending
    /// paths in the failure message.
    #[test]
    fn derive_code_coverage() {
        let derive_tests_folder = env::current_dir()
            .unwrap()
            .join("..")
            .join("binread")
            .join("tests")
            .join("derive");

        // Do not stop at the first failure: every file is still run through
        // the expansion, and the failures are reported together at the end.
        let mut failed_files = Vec::new();
        for entry in fs::read_dir(derive_tests_folder).unwrap() {
            let entry = entry.unwrap();
            if !entry.file_type().unwrap().is_file() {
                continue;
            }
            let path = entry.path();
            let file = fs::File::open(&path).unwrap();
            let is_ok =
                emulate_derive_expansion_fallible(file, "BinRead", super::derive_binread_internal)
                    .is_ok();
            if !is_ok {
                failed_files.push(path);
            }
        }

        assert!(
            failed_files.is_empty(),
            "derive expansion failed for: {:?}",
            failed_files
        );
    }
}