float_derive_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{quote, ToTokens};
4use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Index, Type};
5
6fn is_float_type(ty: &Type) -> bool {
7    if let Type::Path(path) = ty {
8        let segments = &path.path.segments;
9        if segments.len() == 1 {
10            return segments[0].ident == "f32" || segments[0].ident == "f64"
11        }
12    }
13    false
14}
15
16fn partial_eq_impl(ty: &Type, self_tokens: &impl ToTokens, other_tokens: &impl ToTokens, is_first: &mut bool) -> TokenStream2 {
17    let first_tokens = if *is_first {
18        TokenStream2::new()
19    } else {
20        quote! { && }
21    };
22
23    let result = if is_float_type(ty) {
24        quote! { #first_tokens ::float_derive::utils::eq(#self_tokens, #other_tokens)}
25    } else {
26        quote! {
27            #first_tokens #self_tokens == #other_tokens
28        }
29    };
30    *is_first = false;
31    result
32}
33
34#[proc_macro_derive(FloatPartialEq)]
35pub fn derive_partial_eq(input: TokenStream) -> TokenStream {
36    let input = parse_macro_input!(input as DeriveInput);
37    let ident = &input.ident;
38
39    TokenStream::from(match input.data {
40        Data::Enum(data) => {
41            let mut num_non_unit_variants = 0;
42            let mut num_unit_variants = 0;
43
44            let variants = data
45                .variants
46                .iter()
47                .map(|variant| {
48                    let variant_ident = &variant.ident;
49                    let mut is_first = true;
50
51                    match &variant.fields {
52                        Fields::Named(fields) => {
53                            let self_args = fields.named.iter().enumerate().map(|(i, field)| {
54                                let field_ident = &field.ident.as_ref().unwrap();
55                                let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
56
57                                quote! {
58                                    #field_ident: #arg_ident,
59                                }
60                            });
61
62                            let other_args = fields.named.iter().enumerate().map(|(i, field)| {
63                                let field_ident = &field.ident.as_ref().unwrap();
64                                let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
65
66                                quote! {
67                                    #field_ident: #arg_ident,
68                                }
69                            });
70
71                            let impls = fields.named.iter().enumerate().map(|(i, field)| {
72                                let self_ident = Ident::new(&format!("__self_{i}"), field.span());
73                                let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
74
75                                partial_eq_impl(&field.ty, &quote! { *#self_ident}, &quote! { *#other_ident }, &mut is_first)
76                            });
77
78                            num_non_unit_variants += 1;
79                            quote! {
80                                (
81                                    #ident::#variant_ident { #(#self_args)* },
82                                    #ident::#variant_ident { #(#other_args)* }
83                                ) => {
84                                    #(#impls)*
85                                }
86                            }
87                        }
88                        Fields::Unnamed(fields) => {
89                            let self_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
90                                let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
91
92                                quote! {
93                                    #arg_ident,
94                                }
95                            });
96
97                            let other_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
98                                let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
99
100                                quote! {
101                                    #arg_ident,
102                                }
103                            });
104
105                            let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
106                                let self_ident = Ident::new(&format!("__self_{i}"), field.span());
107                                let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
108
109                                partial_eq_impl(&field.ty, &quote! { *#self_ident}, &quote! { *#other_ident }, &mut is_first)
110                            });
111
112                            num_non_unit_variants += 1;
113                            quote! {
114                                (
115                                    #ident::#variant_ident(#(#self_args)*),
116                                    #ident::#variant_ident(#(#other_args)*)
117                                ) => {
118                                    #(#impls)*
119                                }
120                            }
121                        }
122                        Fields::Unit => {
123                            num_unit_variants += 1;
124                            TokenStream2::new()
125                        }
126                    }
127                })
128                .collect::<TokenStream2>();
129
130            let num_variants = num_non_unit_variants + num_unit_variants;
131
132            let body = if num_non_unit_variants == 0 && num_variants < 2 {
133                quote! { true }
134            } else {
135                let default_pattern = if num_unit_variants == 0 {
136                    quote! {
137                        _ => unsafe { ::core::intrinsics::unreachable() }
138                    }
139                } else {
140                    quote! {
141                        _ => true
142                    }
143                };
144
145                let matched_variants = quote! {
146                    match (self, other) {
147                        #variants
148                        #default_pattern
149                    }
150                };
151
152                if num_variants > 1 {
153                    let tags = quote! {
154                        let __self_tag = ::core::intrinsics::discriminant_value(self);
155                        let __arg1_tag = ::core::intrinsics::discriminant_value(other);
156                        __self_tag == __arg1_tag
157                    };
158
159                    if num_non_unit_variants > 0 {
160                        quote! {
161                            #tags && #matched_variants
162                        }
163                    } else {
164                        tags
165                    }
166                } else {
167                    matched_variants
168                }
169            };
170
171            quote! {
172                #[automatically_derived]
173                impl ::core::marker::StructuralPartialEq for #ident {}
174                #[automatically_derived]
175                impl ::core::cmp::PartialEq for #ident {
176                    #[inline]
177                    fn eq(&self, other: &#ident) -> bool {
178                        #body
179                    }
180                }
181            }
182        }
183        Data::Struct(data) => {
184            let mut is_first = true;
185
186            let fields = match data.fields {
187                Fields::Named(fields) => {
188                    fields
189                        .named
190                        .iter()
191                        .map(|field| {
192                            let ident = field.ident.as_ref().unwrap();
193
194                            partial_eq_impl(&field.ty, &quote! { self.#ident }, &quote! { other.#ident }, &mut is_first)
195                        })
196                        .collect::<TokenStream2>()
197                }
198                Fields::Unnamed(fields) => {
199                    fields
200                        .unnamed
201                        .iter()
202                        .enumerate()
203                        .map(|(i, field)| {
204                            let index = Index { index: i as _, span: field.span() };
205                            partial_eq_impl(&field.ty, &quote! { self.#index }, &quote! { other.#index }, &mut is_first)
206                        })
207                        .collect::<TokenStream2>()
208                }
209                Fields::Unit => TokenStream2::new()
210            };
211
212            quote! {
213                #[automatically_derived]
214                impl ::core::marker::StructuralPartialEq for #ident {}
215                #[automatically_derived]
216                impl ::core::cmp::PartialEq for #ident {
217                    #[inline]
218                    fn eq(&self, other: &#ident) -> bool {
219                        #fields
220                    }
221                }
222            }
223        }
224        Data::Union(_) => panic!("this trait cannot be derived for unions")
225    })
226}
227
228#[proc_macro_derive(FloatEq)]
229pub fn derive_eq(input: TokenStream) -> TokenStream {
230    let input = parse_macro_input!(input as DeriveInput);
231    let ident = &input.ident;
232
233    TokenStream::from(quote! {
234        #[automatically_derived]
235        impl ::std::cmp::Eq for #ident {}
236    })
237}
238
239fn hash_impl(ty: &Type, tokens: &impl ToTokens) -> TokenStream2 {
240    if is_float_type(ty) {
241        quote! {
242            ::float_derive::utils::hash(#tokens, state);
243        }
244    } else {
245        quote! {
246            ::core::hash::Hash::hash(#tokens, state);
247        }
248    }
249}
250
251#[proc_macro_derive(FloatHash)]
252pub fn derive_hash(input: TokenStream) -> TokenStream {
253    let input = parse_macro_input!(input as DeriveInput);
254    let ident = &input.ident;
255
256    TokenStream::from(match input.data {
257        Data::Enum(data) => {
258            let variants = {
259                let mut has_non_unit_variants = false;
260                let mut has_unit_variants = false;
261
262                let variants = data
263                    .variants
264                    .iter()
265                    .map(|variant| {
266                        let variant_ident = &variant.ident;
267
268                        match &variant.fields {
269                            Fields::Named(fields) => {
270                                let args = fields.named.iter().map(|field| {
271                                    let field_ident = field.ident.as_ref().unwrap();
272
273                                    quote! {
274                                        #field_ident,
275                                    }
276                                });
277
278                                let impls = fields.named.iter().map(|field| {
279                                    let field_ident = field.ident.as_ref().unwrap();
280                                    hash_impl(&field.ty, &quote! { #field_ident})
281                                });
282
283                                has_non_unit_variants = true;
284                                quote! {
285                                    #ident::#variant_ident { #(#args)* } => { #(#impls)* }
286                                }
287                            }
288                            Fields::Unnamed(fields) => {
289                                let args = fields.unnamed.iter().enumerate().map(|(i, field)| {
290                                    let field_ident = Ident::new(&format!("__self_{i}"), field.span());
291
292                                    quote! {
293                                        #field_ident,
294                                    }
295                                });
296
297                                let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
298                                    let field_ident = Ident::new(&format!("__self_{i}"), field.span());
299                                    hash_impl(&field.ty, &quote! { #field_ident})
300                                });
301
302                                has_non_unit_variants = true;
303                                quote! {
304                                    #ident::#variant_ident(#(#args)*) => { #(#impls)* }
305                                }
306                            }
307                            Fields::Unit => {
308                                has_unit_variants = true;
309                                TokenStream2::new()
310                            }
311                        }
312                    })
313                    .collect::<TokenStream2>();
314
315                let default_pattern = if has_unit_variants {
316                    quote! { _ => () }
317                } else {
318                    TokenStream2::new()
319                };
320
321                if has_non_unit_variants {
322                    quote! {
323                        match self {
324                            #variants
325                            #default_pattern
326                        }
327                    }
328                } else {
329                    TokenStream2::new()
330                }
331            };
332
333            quote! {
334                #[automatically_derived]
335                impl ::std::hash::Hash for #ident {
336                    fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
337                        let __self_tag = ::core::intrinsics::discriminant_value(self);
338                        ::core::hash::Hash::hash(&__self_tag, state);
339
340                        #variants
341                    }
342                }
343            }
344        }
345        Data::Struct(data) => {
346            let fields = match data.fields {
347                Fields::Named(fields) => {
348                    fields
349                        .named
350                        .iter()
351                        .map(|field| {
352                            let ident = field.ident.as_ref().unwrap();
353                            hash_impl(&field.ty, &quote! { &self.#ident})
354                        })
355                        .collect::<TokenStream2>()
356                }
357                Fields::Unnamed(fields) => {
358                    fields
359                        .unnamed
360                        .iter()
361                        .enumerate()
362                        .map(|(i, field)| {
363                            let index = Index { index: i as _, span: field.span() };
364                            hash_impl(&field.ty, &quote! { &self.#index})
365                        })
366                        .collect::<TokenStream2>()
367                }
368                Fields::Unit => TokenStream2::new()
369            };
370
371            quote! {
372                #[automatically_derived]
373                impl ::std::hash::Hash for #ident {
374                    fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
375                        #fields
376                    }
377                }
378            }
379        }
380        Data::Union(_) => panic!("this trait cannot be derived for unions")
381    })
382}