1use proc_macro::TokenStream;
2use quote::quote;
3
4const ATTRIBUTE_PATH: &str = "partial_close";
5
6#[proc_macro_derive(PartialClose, attributes(partial_close))]
7pub fn derive_partial_close(input: TokenStream) -> TokenStream {
8 let input = syn::parse_macro_input!(input as syn::DeriveInput);
9
10 match &input.data {
11 syn::Data::Struct(s) => derive_struct(&input.ident, s),
12 syn::Data::Enum(e) => derive_enum(&input.ident, e),
13 syn::Data::Union(_) => {
14 unimplemented!("#derive[PartialClose] is not implemented for unions")
15 }
16 }
17}
18
19fn derive_struct(ident: &syn::Ident, s: &syn::DataStruct) -> TokenStream {
20 let mut inner = quote!();
21
22 match &s.fields {
23 syn::Fields::Named(named) => {
24 for (i, field) in named.named.iter().enumerate() {
25 if i != 0 {
26 inner.extend(quote!(&& ))
27 }
28
29 let name = &field.ident;
30 match (attributes_get_resoultion(&field.attrs), type_is_option(&field.ty)) {
31 (Some(resolution), true) => inner.extend(quote!(
32 (match (self.#name, other.#name) {
33 (Some(inner_s), Some(inner_o)) => (inner_s - inner_o).abs() < #resolution,
34 (Some(_), None) => false,
35 (None, Some(_)) => false,
36 (None, None) => true,
37 }
38 ))),
39 (Some(resolution), false) => inner.extend(quote!(
40 (self.#name - other.#name).abs() < #resolution)
41 ),
42 (None, _) => inner.extend(quote!(
43 self.#name == other.#name
44 ))
45 }
46 }
47 }
48 _ => unimplemented!(
49 "#[derive(PartialClose)] is not implemented for structs with unnamed fields"
50 ),
51 }
52
53 quote!(
54 impl #ident {
55 fn partial_close(&self, other: &Self) -> bool {
56 #inner
57 }
58 }
59 )
60 .into()
61}
62
63fn derive_enum(ident: &syn::Ident, e: &syn::DataEnum) -> TokenStream {
64 let mut inner = quote!();
65
66 for variant in e.variants.iter() {
67 let name = &variant.ident;
68 match variant.fields {
69 syn::Fields::Named(_) => unimplemented!(
70 "#[derive(PartialClose)] is not implemented for enums with a named field"
71 ),
72 syn::Fields::Unnamed(_) => {
73 if let Some(resolution) = attributes_get_resoultion(&variant.attrs) {
74 inner.extend(
75 quote!((Self::#name(s), Self::#name(o)) => (*s - *o).abs() < #resolution,),
76 )
77 } else {
78 inner.extend(quote!((Self::#name(s), Self::#name(o)) => s.eq(o),))
79 }
80 }
81 syn::Fields::Unit => inner.extend(quote!((Self::#name, Self::#name) => true,)),
82 }
83 }
84
85 quote!(
86 impl #ident {
87 fn partial_close(&self, other: &Self) -> bool {
88 match (self, other) {
89 #inner
90 _ => false,
91 }
92 }
93 }
94 )
95 .into()
96}
97
98fn type_is_option(ty: &syn::Type) -> bool {
99 if let syn::Type::Path(path) = ty {
100 if let Some(segment) = path.path.segments.first() {
101 return segment.ident == "Option";
102 }
103 }
104
105 false
106}
107
108fn attributes_get_resoultion(attributes: &Vec<syn::Attribute>) -> Option<f32> {
109 for attr in attributes {
110 if attr.path().is_ident(ATTRIBUTE_PATH) {
111 let assign = attr.parse_args::<syn::ExprAssign>().unwrap();
112 let resolution = *(assign.right);
113 match resolution {
114 syn::Expr::Lit(r) => match r.lit {
115 syn::Lit::Float(f) => return Some(f.base10_parse().unwrap()),
116 _ => unimplemented!("Expected a resolution as a float"),
117 },
118 _ => unimplemented!("Expected a resolution `(resoultion = x.y)`"),
119 }
120 }
121 }
122 None
123}