micropelt_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3
/// Name of the helper attribute (`#[partial_close(...)]`) inspected by this derive.
///
/// Must stay in sync with the `attributes(partial_close)` list registered on the
/// `#[proc_macro_derive]` entry point, otherwise the compiler rejects the attribute.
const ATTRIBUTE_PATH: &str = "partial_close";
5
6#[proc_macro_derive(PartialClose, attributes(partial_close))]
7pub fn derive_partial_close(input: TokenStream) -> TokenStream {
8    let input = syn::parse_macro_input!(input as syn::DeriveInput);
9
10    match &input.data {
11        syn::Data::Struct(s) => derive_struct(&input.ident, s),
12        syn::Data::Enum(e) => derive_enum(&input.ident, e),
13        syn::Data::Union(_) => {
14            unimplemented!("#derive[PartialClose] is not implemented for unions")
15        }
16    }
17}
18
19fn derive_struct(ident: &syn::Ident, s: &syn::DataStruct) -> TokenStream {
20    let mut inner = quote!();
21
22    match &s.fields {
23        syn::Fields::Named(named) => {
24            for (i, field) in named.named.iter().enumerate() {
25                if i != 0 {
26                    inner.extend(quote!(&& ))
27                }
28
29                let name = &field.ident;
30                match (attributes_get_resoultion(&field.attrs), type_is_option(&field.ty)) {
31                    (Some(resolution), true) => inner.extend(quote!(
32                        (match (self.#name, other.#name) {
33                            (Some(inner_s), Some(inner_o)) => (inner_s - inner_o).abs() < #resolution,
34                            (Some(_), None) => false,
35                            (None, Some(_)) => false,
36                            (None, None) => true,
37                        }
38                    ))),
39                    (Some(resolution), false) => inner.extend(quote!(
40                        (self.#name - other.#name).abs() < #resolution)
41                    ),
42                    (None, _) => inner.extend(quote!(
43                        self.#name == other.#name
44                    ))
45                }
46            }
47        }
48        _ => unimplemented!(
49            "#[derive(PartialClose)] is not implemented for structs with unnamed fields"
50        ),
51    }
52
53    quote!(
54        impl #ident {
55            fn partial_close(&self, other: &Self) -> bool {
56                #inner
57            }
58        }
59    )
60    .into()
61}
62
63fn derive_enum(ident: &syn::Ident, e: &syn::DataEnum) -> TokenStream {
64    let mut inner = quote!();
65
66    for variant in e.variants.iter() {
67        let name = &variant.ident;
68        match variant.fields {
69            syn::Fields::Named(_) => unimplemented!(
70                "#[derive(PartialClose)] is not implemented for enums with a named field"
71            ),
72            syn::Fields::Unnamed(_) => {
73                if let Some(resolution) = attributes_get_resoultion(&variant.attrs) {
74                    inner.extend(
75                        quote!((Self::#name(s), Self::#name(o)) => (*s - *o).abs() < #resolution,),
76                    )
77                } else {
78                    inner.extend(quote!((Self::#name(s), Self::#name(o)) => s.eq(o),))
79                }
80            }
81            syn::Fields::Unit => inner.extend(quote!((Self::#name, Self::#name) => true,)),
82        }
83    }
84
85    quote!(
86        impl #ident {
87            fn partial_close(&self, other: &Self) -> bool {
88                match (self, other) {
89                    #inner
90                    _ => false,
91                }
92            }
93        }
94    )
95    .into()
96}
97
98fn type_is_option(ty: &syn::Type) -> bool {
99    if let syn::Type::Path(path) = ty {
100        if let Some(segment) = path.path.segments.first() {
101            return segment.ident == "Option";
102        }
103    }
104
105    false
106}
107
108fn attributes_get_resoultion(attributes: &Vec<syn::Attribute>) -> Option<f32> {
109    for attr in attributes {
110        if attr.path().is_ident(ATTRIBUTE_PATH) {
111            let assign = attr.parse_args::<syn::ExprAssign>().unwrap();
112            let resolution = *(assign.right);
113            match resolution {
114                syn::Expr::Lit(r) => match r.lit {
115                    syn::Lit::Float(f) => return Some(f.base10_parse().unwrap()),
116                    _ => unimplemented!("Expected a resolution as a float"),
117                },
118                _ => unimplemented!("Expected a resolution `(resoultion = x.y)`"),
119            }
120        }
121    }
122    None
123}