1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
use proc_macro::TokenStream;

/// Generates a `::tskit::metadata::MetadataRoundtrip` impl for `name` that
/// encodes metadata as JSON bytes via `serde_json`.
///
/// The generated code refers to `::tskit` and `::serde_json` by absolute path,
/// so the crate using the derive must depend on both, and the type must be
/// serializable/deserializable by `serde_json`.
fn impl_serde_json_roundtrip(name: &syn::Ident) -> TokenStream {
    // Prelude items are spelled with absolute `::std` paths: the emitted tokens
    // resolve in the caller's scope, where e.g. a local `type Result<T> = ...`
    // would otherwise break the expansion.
    let expanded = quote::quote!(
        impl ::tskit::metadata::MetadataRoundtrip for #name {
            fn encode(
                &self,
            ) -> ::std::result::Result<::std::vec::Vec<u8>, ::tskit::metadata::MetadataError> {
                // `to_vec` writes the JSON bytes directly; it produces the same
                // bytes as `to_string(..)` without the intermediate `String` copy.
                ::serde_json::to_vec(self).map_err(|e| {
                    ::tskit::metadata::MetadataError::RoundtripError {
                        value: ::std::boxed::Box::new(e),
                    }
                })
            }

            fn decode(
                md: &[u8],
            ) -> ::std::result::Result<Self, ::tskit::metadata::MetadataError> {
                ::serde_json::from_slice::<Self>(md).map_err(|e| {
                    ::tskit::metadata::MetadataError::RoundtripError {
                        value: ::std::boxed::Box::new(e),
                    }
                })
            }
        }
    );
    expanded.into()
}

/// Generates a `::tskit::metadata::MetadataRoundtrip` impl for `name` that
/// encodes metadata with `bincode::serialize` / `bincode::deserialize`.
///
/// The generated code refers to `::tskit` and `::bincode` by absolute path,
/// so the crate using the derive must depend on both, and the type must be
/// serializable/deserializable by `bincode`.
fn impl_serde_bincode_roundtrip(name: &syn::Ident) -> TokenStream {
    // Prelude items are spelled with absolute `::std` paths so that shadowed
    // names (e.g. a local `Result` alias) in the caller's scope cannot break
    // the expansion.
    let expanded = quote::quote!(
        impl ::tskit::metadata::MetadataRoundtrip for #name {
            fn encode(
                &self,
            ) -> ::std::result::Result<::std::vec::Vec<u8>, ::tskit::metadata::MetadataError> {
                // `self` is already `&Self`; no extra reference is needed.
                ::bincode::serialize(self).map_err(|e| {
                    ::tskit::metadata::MetadataError::RoundtripError {
                        value: ::std::boxed::Box::new(e),
                    }
                })
            }

            fn decode(
                md: &[u8],
            ) -> ::std::result::Result<Self, ::tskit::metadata::MetadataError> {
                ::bincode::deserialize::<Self>(md).map_err(|e| {
                    ::tskit::metadata::MetadataError::RoundtripError {
                        value: ::std::boxed::Box::new(e),
                    }
                })
            }
        }
    );
    expanded.into()
}

/// Generates the `MetadataRoundtrip` impl selected by the item's
/// `#[serializer("...")]` attribute.
///
/// Supported values are `"serde_json"` and `"bincode"`. Attributes other than
/// `serializer` (doc comments, `#[serde(...)]`, lint attributes, ...) are
/// ignored: a derive macro sees every inert attribute on the item, so
/// rejecting them would make this derive unusable on documented types or
/// alongside other derives. If several `serializer` attributes are present,
/// the first one wins.
///
/// # Errors
///
/// Returns the `syn::Error` produced when the `serializer` attribute's
/// argument is not a single string literal.
///
/// Aborts compilation (via `proc_macro_error`) when the requested serializer
/// is not supported or when no `serializer` attribute is present.
fn impl_metadata_roundtrip_macro(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
    let name = &ast.ident;

    let attr = match ast
        .attrs
        .iter()
        .find(|attr| attr.path.is_ident("serializer"))
    {
        Some(attr) => attr,
        None => proc_macro_error::abort_call_site!("missing [serializer(...)] attribute"),
    };

    // A malformed argument (e.g. `#[serializer(bincode)]`) becomes a spanned
    // compile error instead of a panic inside the macro.
    let lit: syn::LitStr = attr.parse_args()?;

    match lit.value().as_str() {
        "serde_json" => Ok(impl_serde_json_roundtrip(name)),
        "bincode" => Ok(impl_serde_bincode_roundtrip(name)),
        // Span the error on the literal so it points at the offending value.
        _ => proc_macro_error::abort!(lit, "is not a supported protocol."),
    }
}

// Defines a `#[proc_macro_derive]` entry point named `$function` for the
// derive `$metadatatag`. The generated derive emits the `MetadataRoundtrip`
// impl chosen by `#[serializer(...)]` plus an empty
// `impl ::tskit::metadata::$metadatatag for Type {}`.
macro_rules! make_derive_metadata_tag {
    ($function: ident, $metadatatag: ident) => {
        #[proc_macro_error::proc_macro_error]
        #[proc_macro_derive($metadatatag, attributes(serializer))]
        /// Register a type as metadata.
        ///
        /// Implements `tskit::metadata::MetadataRoundtrip` for the type and adds
        /// an empty impl of the corresponding `tskit::metadata` trait.
        /// The type must carry `#[serializer("serde_json")]` or
        /// `#[serializer("bincode")]`.
        pub fn $function(input: TokenStream) -> TokenStream {
            let ast: syn::DeriveInput = match syn::parse(input) {
                Ok(ast) => ast,
                Err(err) => proc_macro_error::abort!(err),
            };
            // Surface a malformed `serializer` attribute as a spanned compile
            // error rather than panicking via `unwrap`.
            let mut expanded = match impl_metadata_roundtrip_macro(&ast) {
                Ok(tokens) => tokens,
                Err(err) => proc_macro_error::abort!(err),
            };
            let name = &ast.ident;
            let tag_impl: TokenStream = quote::quote!(
                impl ::tskit::metadata::$metadatatag for #name {}
            )
            .into();
            expanded.extend(tag_impl);
            expanded
        }
    };
}

// One derive per metadata-bearing table. Each invocation expands to an
// independent `#[proc_macro_derive]` function, so their order is irrelevant.
make_derive_metadata_tag!(node_metadata_derive, NodeMetadata);
make_derive_metadata_tag!(edge_metadata_derive, EdgeMetadata);
make_derive_metadata_tag!(migration_metadata_derive, MigrationMetadata);
make_derive_metadata_tag!(site_metadata_derive, SiteMetadata);
make_derive_metadata_tag!(mutation_metadata_derive, MutationMetadata);
make_derive_metadata_tag!(individual_metadata_derive, IndividualMetadata);
make_derive_metadata_tag!(population_metadata_derive, PopulationMetadata);