// float_derive_macros/lib.rs
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Index, Type};

6fn is_float_type(ty: &Type) -> bool {
7 if let Type::Path(path) = ty {
8 let segments = &path.path.segments;
9 if segments.len() == 1 {
10 return segments[0].ident == "f32" || segments[0].ident == "f64"
11 }
12 }
13 false
14}
16fn partial_eq_impl(ty: &Type, self_tokens: &impl ToTokens, other_tokens: &impl ToTokens, is_first: &mut bool) -> TokenStream2 {
17 let first_tokens = if *is_first {
18 TokenStream2::new()
19 } else {
20 quote! { && }
21 };
22
23 let result = if is_float_type(ty) {
24 quote! { #first_tokens ::float_derive::utils::eq(#self_tokens, #other_tokens)}
25 } else {
26 quote! {
27 #first_tokens #self_tokens == #other_tokens
28 }
29 };
30 *is_first = false;
31 result
32}
34#[proc_macro_derive(FloatPartialEq)]
35pub fn derive_partial_eq(input: TokenStream) -> TokenStream {
36 let input = parse_macro_input!(input as DeriveInput);
37 let ident = &input.ident;
38
39 TokenStream::from(match input.data {
40 Data::Enum(data) => {
41 let mut num_non_unit_variants = 0;
42 let mut num_unit_variants = 0;
43
44 let variants = data
45 .variants
46 .iter()
47 .map(|variant| {
48 let variant_ident = &variant.ident;
49 let mut is_first = true;
50
51 match &variant.fields {
52 Fields::Named(fields) => {
53 let self_args = fields.named.iter().enumerate().map(|(i, field)| {
54 let field_ident = &field.ident.as_ref().unwrap();
55 let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
56
57 quote! {
58 #field_ident: #arg_ident,
59 }
60 });
61
62 let other_args = fields.named.iter().enumerate().map(|(i, field)| {
63 let field_ident = &field.ident.as_ref().unwrap();
64 let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
65
66 quote! {
67 #field_ident: #arg_ident,
68 }
69 });
70
71 let impls = fields.named.iter().enumerate().map(|(i, field)| {
72 let self_ident = Ident::new(&format!("__self_{i}"), field.span());
73 let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
74
75 partial_eq_impl(&field.ty, "e! { *#self_ident}, "e! { *#other_ident }, &mut is_first)
76 });
77
78 num_non_unit_variants += 1;
79 quote! {
80 (
81 #ident::#variant_ident { #(#self_args)* },
82 #ident::#variant_ident { #(#other_args)* }
83 ) => {
84 #(#impls)*
85 }
86 }
87 }
88 Fields::Unnamed(fields) => {
89 let self_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
90 let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
91
92 quote! {
93 #arg_ident,
94 }
95 });
96
97 let other_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
98 let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
99
100 quote! {
101 #arg_ident,
102 }
103 });
104
105 let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
106 let self_ident = Ident::new(&format!("__self_{i}"), field.span());
107 let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
108
109 partial_eq_impl(&field.ty, "e! { *#self_ident}, "e! { *#other_ident }, &mut is_first)
110 });
111
112 num_non_unit_variants += 1;
113 quote! {
114 (
115 #ident::#variant_ident(#(#self_args)*),
116 #ident::#variant_ident(#(#other_args)*)
117 ) => {
118 #(#impls)*
119 }
120 }
121 }
122 Fields::Unit => {
123 num_unit_variants += 1;
124 TokenStream2::new()
125 }
126 }
127 })
128 .collect::<TokenStream2>();
129
130 let num_variants = num_non_unit_variants + num_unit_variants;
131
132 let body = if num_non_unit_variants == 0 && num_variants < 2 {
133 quote! { true }
134 } else {
135 let default_pattern = if num_unit_variants == 0 {
136 quote! {
137 _ => unsafe { ::core::intrinsics::unreachable() }
138 }
139 } else {
140 quote! {
141 _ => true
142 }
143 };
144
145 let matched_variants = quote! {
146 match (self, other) {
147 #variants
148 #default_pattern
149 }
150 };
151
152 if num_variants > 1 {
153 let tags = quote! {
154 let __self_tag = ::core::intrinsics::discriminant_value(self);
155 let __arg1_tag = ::core::intrinsics::discriminant_value(other);
156 __self_tag == __arg1_tag
157 };
158
159 if num_non_unit_variants > 0 {
160 quote! {
161 #tags && #matched_variants
162 }
163 } else {
164 tags
165 }
166 } else {
167 matched_variants
168 }
169 };
170
171 quote! {
172 #[automatically_derived]
173 impl ::core::marker::StructuralPartialEq for #ident {}
174 #[automatically_derived]
175 impl ::core::cmp::PartialEq for #ident {
176 #[inline]
177 fn eq(&self, other: &#ident) -> bool {
178 #body
179 }
180 }
181 }
182 }
183 Data::Struct(data) => {
184 let mut is_first = true;
185
186 let fields = match data.fields {
187 Fields::Named(fields) => {
188 fields
189 .named
190 .iter()
191 .map(|field| {
192 let ident = field.ident.as_ref().unwrap();
193
194 partial_eq_impl(&field.ty, "e! { self.#ident }, "e! { other.#ident }, &mut is_first)
195 })
196 .collect::<TokenStream2>()
197 }
198 Fields::Unnamed(fields) => {
199 fields
200 .unnamed
201 .iter()
202 .enumerate()
203 .map(|(i, field)| {
204 let index = Index { index: i as _, span: field.span() };
205 partial_eq_impl(&field.ty, "e! { self.#index }, "e! { other.#index }, &mut is_first)
206 })
207 .collect::<TokenStream2>()
208 }
209 Fields::Unit => TokenStream2::new()
210 };
211
212 quote! {
213 #[automatically_derived]
214 impl ::core::marker::StructuralPartialEq for #ident {}
215 #[automatically_derived]
216 impl ::core::cmp::PartialEq for #ident {
217 #[inline]
218 fn eq(&self, other: &#ident) -> bool {
219 #fields
220 }
221 }
222 }
223 }
224 Data::Union(_) => panic!("this trait cannot be derived for unions")
225 })
226}
228#[proc_macro_derive(FloatEq)]
229pub fn derive_eq(input: TokenStream) -> TokenStream {
230 let input = parse_macro_input!(input as DeriveInput);
231 let ident = &input.ident;
232
233 TokenStream::from(quote! {
234 #[automatically_derived]
235 impl ::std::cmp::Eq for #ident {}
236 })
237}
239fn hash_impl(ty: &Type, tokens: &impl ToTokens) -> TokenStream2 {
240 if is_float_type(ty) {
241 quote! {
242 ::float_derive::utils::hash(#tokens, state);
243 }
244 } else {
245 quote! {
246 ::core::hash::Hash::hash(#tokens, state);
247 }
248 }
249}
251#[proc_macro_derive(FloatHash)]
252pub fn derive_hash(input: TokenStream) -> TokenStream {
253 let input = parse_macro_input!(input as DeriveInput);
254 let ident = &input.ident;
255
256 TokenStream::from(match input.data {
257 Data::Enum(data) => {
258 let variants = {
259 let mut has_non_unit_variants = false;
260 let mut has_unit_variants = false;
261
262 let variants = data
263 .variants
264 .iter()
265 .map(|variant| {
266 let variant_ident = &variant.ident;
267
268 match &variant.fields {
269 Fields::Named(fields) => {
270 let args = fields.named.iter().map(|field| {
271 let field_ident = field.ident.as_ref().unwrap();
272
273 quote! {
274 #field_ident,
275 }
276 });
277
278 let impls = fields.named.iter().map(|field| {
279 let field_ident = field.ident.as_ref().unwrap();
280 hash_impl(&field.ty, "e! { #field_ident})
281 });
282
283 has_non_unit_variants = true;
284 quote! {
285 #ident::#variant_ident { #(#args)* } => { #(#impls)* }
286 }
287 }
288 Fields::Unnamed(fields) => {
289 let args = fields.unnamed.iter().enumerate().map(|(i, field)| {
290 let field_ident = Ident::new(&format!("__self_{i}"), field.span());
291
292 quote! {
293 #field_ident,
294 }
295 });
296
297 let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
298 let field_ident = Ident::new(&format!("__self_{i}"), field.span());
299 hash_impl(&field.ty, "e! { #field_ident})
300 });
301
302 has_non_unit_variants = true;
303 quote! {
304 #ident::#variant_ident(#(#args)*) => { #(#impls)* }
305 }
306 }
307 Fields::Unit => {
308 has_unit_variants = true;
309 TokenStream2::new()
310 }
311 }
312 })
313 .collect::<TokenStream2>();
314
315 let default_pattern = if has_unit_variants {
316 quote! { _ => () }
317 } else {
318 TokenStream2::new()
319 };
320
321 if has_non_unit_variants {
322 quote! {
323 match self {
324 #variants
325 #default_pattern
326 }
327 }
328 } else {
329 TokenStream2::new()
330 }
331 };
332
333 quote! {
334 #[automatically_derived]
335 impl ::std::hash::Hash for #ident {
336 fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
337 let __self_tag = ::core::intrinsics::discriminant_value(self);
338 ::core::hash::Hash::hash(&__self_tag, state);
339
340 #variants
341 }
342 }
343 }
344 }
345 Data::Struct(data) => {
346 let fields = match data.fields {
347 Fields::Named(fields) => {
348 fields
349 .named
350 .iter()
351 .map(|field| {
352 let ident = field.ident.as_ref().unwrap();
353 hash_impl(&field.ty, "e! { &self.#ident})
354 })
355 .collect::<TokenStream2>()
356 }
357 Fields::Unnamed(fields) => {
358 fields
359 .unnamed
360 .iter()
361 .enumerate()
362 .map(|(i, field)| {
363 let index = Index { index: i as _, span: field.span() };
364 hash_impl(&field.ty, "e! { &self.#index})
365 })
366 .collect::<TokenStream2>()
367 }
368 Fields::Unit => TokenStream2::new()
369 };
370
371 quote! {
372 #[automatically_derived]
373 impl ::std::hash::Hash for #ident {
374 fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
375 #fields
376 }
377 }
378 }
379 }
380 Data::Union(_) => panic!("this trait cannot be derived for unions")
381 })
382}