batbox_collection_derive/
lib.rs

1#![recursion_limit = "128"]
2#![allow(unused_imports)]
3
4extern crate proc_macro;
5
6use darling::{FromDeriveInput, FromField};
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9
10#[proc_macro_derive(HasId, attributes(has_id))]
11pub fn derive_has_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12    let input: syn::DeriveInput = syn::parse_macro_input!(input);
13    match DeriveInput::from_derive_input(&input) {
14        Ok(input) => input.derive().into(),
15        Err(e) => e.write_errors().into(),
16    }
17}
18
19#[derive(FromDeriveInput)]
20#[darling(supports(struct_any))]
21struct DeriveInput {
22    ident: syn::Ident,
23    generics: syn::Generics,
24    data: darling::ast::Data<(), Field>,
25}
26
27#[derive(FromField)]
28#[darling(attributes(has_id))]
29struct Field {
30    ident: Option<syn::Ident>,
31    ty: syn::Type,
32    #[darling(default)]
33    id: bool,
34}
35
36impl DeriveInput {
37    fn derive(self) -> TokenStream {
38        let Self {
39            ident,
40            generics,
41            data,
42        } = self;
43        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
44        let data = data.take_struct().unwrap();
45        fn find_field_with_id_attr(fields: &[Field]) -> Option<(usize, &Field)> {
46            let mut result = None;
47            for (index, field) in fields.iter().enumerate() {
48                if field.id {
49                    assert!(result.is_none(), "Only one field must have id attr");
50                    result = Some((index, field));
51                }
52            }
53            result
54        }
55        fn find_field_with_id_name(fields: &[Field]) -> Option<(usize, &Field)> {
56            fields
57                .iter()
58                .enumerate()
59                .find(|(_, field)| field.ident.as_ref().map_or(false, |ident| ident == "id"))
60        }
61        let (id_field_index, id_field) = find_field_with_id_attr(&data.fields)
62            .or_else(|| find_field_with_id_name(&data.fields))
63            .expect("Expected field with #[id] attr or named `id`");
64        let id_field_ty = &id_field.ty;
65        let id_field_index = syn::Index::from(id_field_index);
66        let id_field_ref = match &id_field.ident {
67            Some(ident) => quote! { #ident },
68            None => quote! { #id_field_index },
69        };
70        quote! {
71            impl #impl_generics batbox::collection::HasId for #ident #ty_generics #where_clause {
72                type Id = #id_field_ty;
73                fn id(&self) -> &Self::Id {
74                    &self.#id_field_ref
75                }
76            }
77        }
78    }
79}