approx_derive/lib.rs
1#![deny(missing_docs)]
2//! This crate provides derive macros for the
3//! [AbsDiffEq](https://docs.rs/approx/latest/approx/trait.AbsDiffEq.html) and
4//! [RelativeEq](https://docs.rs/approx/latest/approx/trait.RelativeEq.html) traits of the
5//! [approx](https://docs.rs/approx/latest/approx/) crate.
6//!
7//! These derive macros only implement both traits with `...<Rhs = Self>`.
//! The macros infer the `EPSILON` type of the [AbsDiffEq] trait from the type of the first
//! non-skipped struct or enum field, unless a type is specified explicitly by the user.
10//!
11//! This table lists all attributes which can be used to customize the derived traits.
12//! They are ordered in descending priority, meaning setting the `#[approx(equal)]` will overwrite
13//! any specifications made in the `#[approx(map = ...)]` attribute.
14//!
15//! | Field Attribute | Functionality |
16//! |:--- | --- |
17//! | [`#[approx(skip)]`](#skipping-fields) | Skips the field entirely |
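//! | [`#[approx(equal)]`](#testing-for-equality) | Checks this field for equality with `==`. |
//! | [`#[approx(cast_field)]`](#casting-fields) | Casts the field with `.. as ..` syntax. |
//! | [`#[approx(cast_value)]`](#casting-fields) | Casts the epsilon and `max_relative` values to the type of the field. |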
20//! | [`#[approx(map = ..)]`](#mapping-values) | Maps values before comparing them. |
21//! | [`#[approx(static_epsilon = ..)]`](#static-values) | Defines a static epsilon value for this particular field. |
22//! | | |
23//! | **Object Attribute** | |
24//! | [`#[approx(default_epsilon = ...)]`](#default-epsilon) | Sets the default epsilon value |
25//! | [`#[approx(default_max_relative = ...)]`](#default-max-relative) | Sets the default `max_relative` value. |
26//! | [`#[approx(epsilon_type = ...)]`](#epsilon-type) | Sets the type of the epsilon value |
27//!
28//! # Usage
29//!
30//! ```
31//! use approx_derive::AbsDiffEq;
32//!
33//! // Define a new type and derive the AbsDiffEq trait
34//! #[derive(AbsDiffEq, PartialEq, Debug)]
35//! struct Position {
36//! x: f64,
37//! y: f64
38//! }
39//!
40//! // Compare if two given positions match
//! // with respect to the given epsilon.
42//! let p1 = Position { x: 1.01, y: 2.36 };
43//! let p2 = Position { x: 0.99, y: 2.38 };
44//! approx::assert_abs_diff_eq!(p1, p2, epsilon = 0.021);
45//! ```
46//! In this case, the generated code looks something like this:
47//! ```ignore
48//! const _ : () =
49//! {
50//! #[automatically_derived] impl approx :: AbsDiffEq for Position
51//! {
52//! type Epsilon = <f64 as approx::AbsDiffEq>::Epsilon;
53//!
54//! fn default_epsilon() -> Self :: Epsilon {
55//! <f64 as approx::AbsDiffEq>::default_epsilon()
56//! }
57//!
58//! fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
59//! <f64 as approx::AbsDiffEq>::abs_diff_eq(
60//! &self.x,
61//! & other.x,
62//! epsilon.clone()
63//! ) &&
64//! <f64 as approx::AbsDiffEq>::abs_diff_eq(
65//! &self.y,
66//! &other.y,
67//! epsilon.clone()
68//! ) && true
69//! }
70//! }
71//! };
72//! ```
73//! The [AbsDiffEq] derive macro calls the `abs_diff_eq` method repeatedly on all fields
74//! to determine if all are matching.
75//!
76//! ## Enums
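//! `approx-derive` supports enums since version `0.2`.
//! Attributes such as `#[approx(cast_value)]` can also be placed on a whole variant,
//! in which case they are applied to all fields of that variant.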
78//!
79//! ```
80//! use approx_derive::AbsDiffEq;
81//!
82//! #[derive(AbsDiffEq, PartialEq, Debug)]
83//! enum Position {
84//! Smooth { x: f32, y: f32, },
85//! #[approx(cast_value)]
86//! Lattice { x: isize, y: isize },
87//! }
88//!
89//! let p1 = Position::Smooth { x: 1.0, y: 1.1 };
90//! let p2 = Position::Smooth { x: 1.1, y: 1.0};
91//! let p3 = Position::Lattice { x: 1, y: 1 };
92//!
93//! approx::assert_abs_diff_eq!(p1, p2, epsilon=0.2);
94//! ```
95//!
96//! ```should_panic
97//! # use approx_derive::AbsDiffEq;
98//! # #[derive(AbsDiffEq, PartialEq, Debug)]
99//! # enum Position {
100//! # Smooth { x: f32, y: f32, },
101//! # #[approx(cast_value)]
102//! # Lattice { x: isize, y: isize },
103//! # }
104//! # let p1 = Position::Smooth { x: 1.0, y: 1.1 };
105//! # let p3 = Position::Lattice { x: 1, y: 1 };
106//! // Note! Different enum variants can never be equal!
107//! approx::assert_abs_diff_eq!(p1, p3, epsilon = 1000.0);
108//! ```
109//!
110//!
111//! # Field Attributes
112//! ## Skipping Fields
113//!
114//! Sometimes, we only want to compare certain fields and omit others completely.
115//! ```
116//! # use approx_derive::*;
117//! #[derive(AbsDiffEq, PartialEq, Debug)]
118//! struct Player {
119//! hit_points: f32,
120//! pos_x: f32,
121//! pos_y: f32,
122//! #[approx(skip)]
123//! id: (usize, usize),
124//! }
125//!
126//! let player1 = Player {
127//! hit_points: 100.0,
128//! pos_x: 2.0,
129//! pos_y: -650.345,
130//! id: (0, 1),
131//! };
132//!
133//! let player2 = Player {
134//! hit_points: 99.9,
135//! pos_x: 2.001,
136//! pos_y: -649.898,
137//! id: (22, 0),
138//! };
139//!
140//! approx::assert_abs_diff_eq!(player1, player2, epsilon = 0.5);
141//! ```
142//!
//! ## Testing for [Equality](core::cmp::PartialEq)
144//!
//! When exact equality is desired, we can specify this with the `#[approx(equal)]` attribute.
//! The field is then compared with `==` instead of an approximate comparison.
146//!
147//! ```
148//! # use approx_derive::*;
149//! #[derive(AbsDiffEq, PartialEq, Debug)]
150//! struct Prediction {
151//! confidence: f64,
152//! #[approx(equal)]
153//! category: String,
154//! }
155//! ```
156//!
157//! Note that in this case, the type of the epsilon value for the implementation of
158//! [AbsDiffEq](https://docs.rs/approx/latest/approx/trait.AbsDiffEq.html) is inferred from the
159//! first field of the `Prediction` struct.
//! This means that if we reorder the fields of the struct, we need to set the epsilon type manually.
161//!
162//! ```
163//! # use approx_derive::*;
164//! #[derive(AbsDiffEq, PartialEq, Debug)]
165//! #[approx(epsilon_type = f64)]
166//! struct Prediction {
167//! #[approx(equal)]
168//! category: String,
169//! confidence: f64,
170//! }
171//! ```
172//!
173//! ## Casting Fields
174//!
//! Structs which consist of multiple fields with different
//! numeric types cannot be derived without additional hints.
177//! After all, we should specify how this type mismatch will be handled.
178//!
179//! ```compile_fail
180//! # use approx_derive::*;
181//! #[derive(AbsDiffEq, PartialEq, Debug)]
182//! struct MyStruct {
183//! v1: f32,
184//! v2: f64,
185//! }
186//! ```
187//!
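//! We can use the `#[approx(cast_field)]` and `#[approx(cast_value)]`
//! attributes to achieve this goal: `cast_field` casts the field to the epsilon type, whereas
//! `cast_value` casts the epsilon (and `max_relative`) value to the type of the field.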
190//! ```
191//! # use approx_derive::*;
192//! #[derive(AbsDiffEq, PartialEq, Debug)]
193//! struct MyStruct {
194//! v1: f32,
195//! #[approx(cast_field)]
196//! v2: f64,
197//! }
198//! ```
//! Now the second field will be cast to the type of the inferred epsilon value (`f32`).
//! We can check this by verifying that a difference of `f64::MIN_POSITIVE` is lost by
//! this cast.
202//! ```
203//! # use approx_derive::*;
204//! # #[derive(RelativeEq, PartialEq, Debug)]
205//! # struct MyStruct {
206//! # v1: f32,
207//! # #[approx(cast_field)]
208//! # v2: f64,
209//! # }
210//! let ms1 = MyStruct {
211//! v1: 1.0,
212//! v2: 3.0,
213//! };
214//! let ms2 = MyStruct {
215//! v1: 1.0,
216//! v2: 3.0 + f64::MIN_POSITIVE,
217//! };
218//! approx::assert_relative_eq!(ms1, ms2);
219//! ```
220//!
221//! ## Mapping Values
222//!
//! We can map values before comparing them.
//! The mapping must return an `Option` of the value in question,
//! which allows computations that can fail.
//! Such failures are not reported; instead, the comparison fails if the mapping returns `None`
//! for either of the two compared objects.
228//! ```
229//! # use approx_derive::*;
230//! # use approx::*;
231//! #[derive(AbsDiffEq, PartialEq, Debug)]
232//! struct Tower {
233//! height_in_meters: f32,
234//! #[approx(map = |x: &f32| Some(x.sqrt()))]
235//! area_in_meters_squared: f32,
236//! }
237//! # let t1 = Tower {
238//! # height_in_meters: 100.0,
239//! # area_in_meters_squared: 30.1,
240//! # };
241//! # let t2 = Tower {
242//! # height_in_meters: 100.0,
243//! # area_in_meters_squared: 30.5,
244//! # };
245//! # approx::assert_abs_diff_ne!(t1, t2, epsilon = 0.03);
246//! ```
247//!
248//! This functionality can also be useful when having more complex datatypes.
249//! ```
250//! # use approx_derive::*;
251//! # use approx::*;
252//! #[derive(PartialEq, Debug)]
253//! enum Time {
254//! Years(u16),
255//! Months(u16),
256//! Weeks(u16),
257//! Days(u16),
258//! }
259//!
260//! fn time_to_days(time: &Time) -> Option<u16> {
261//! match time {
262//! Time::Years(y) => Some(365 * y),
263//! Time::Months(m) => Some(30 * m),
264//! Time::Weeks(w) => Some(7 * w),
265//! Time::Days(d) => Some(*d),
266//! }
267//! }
268//!
269//! #[derive(AbsDiffEq, PartialEq, Debug)]
270//! #[approx(epsilon_type = u16)]
271//! struct Dog {
272//! #[approx(map = time_to_days)]
273//! age: Time,
274//! #[approx(map = time_to_days)]
275//! next_doctors_appointment: Time,
276//! }
277//! ```
278//!
279//! ## Static Values
280//! We can force a static `EPSILON` or `max_relative` value for individual fields.
281//! ```
282//! # use approx_derive::*;
283//! #[derive(AbsDiffEq, PartialEq, Debug)]
284//! struct Rectangle {
285//! #[approx(static_epsilon = 5e-2)]
286//! a: f64,
287//! b: f64,
288//! #[approx(static_epsilon = 7e-2)]
289//! c: f64,
290//! }
291//!
292//! let r1 = Rectangle {
293//! a: 100.01,
294//! b: 40.0001,
295//! c: 30.055,
296//! };
297//! let r2 = Rectangle {
298//! a: 99.97,
299//! b: 40.0005,
300//! c: 30.049,
301//! };
302//!
//! // These assertions hold even when the epsilon is smaller than the differences
//! // of fields a and c, because those fields use their static epsilon values.
305//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-1);
306//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-2);
307//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-3);
308//!
//! // Here, the epsilon value has become smaller than the difference between the
//! // b field values.
311//! approx::assert_abs_diff_ne!(r1, r2, epsilon = 1e-4);
312//! ```
313//! # Object Attributes
314//! ## Default Epsilon
//! The [AbsDiffEq] trait provides a default epsilon value through its `default_epsilon` method.
//! We can control this value by specifying it on the object level.
317//!
318//! ```
319//! # use approx_derive::*;
320//! #[derive(AbsDiffEq, PartialEq, Debug)]
321//! #[approx(default_epsilon = 10)]
322//! struct Benchmark {
323//! cycles: u64,
324//! warm_up: u64,
325//! }
326//!
327//! let benchmark1 = Benchmark {
328//! cycles: 248,
329//! warm_up: 36,
330//! };
331//! let benchmark2 = Benchmark {
332//! cycles: 239,
333//! warm_up: 28,
334//! };
335//!
//! // When testing with no additional arguments, the results match.
337//! approx::assert_abs_diff_eq!(benchmark1, benchmark2);
338//! // Once we specify a lower epsilon, the values do not agree anymore.
339//! approx::assert_abs_diff_ne!(benchmark1, benchmark2, epsilon = 5);
340//! ```
341//!
342//! ## Default Max Relative
//! Similarly to [Default Epsilon](#default-epsilon), we can also choose a default `max_relative` deviation.
344//! ```
345//! # use approx_derive::*;
346//! #[derive(RelativeEq, PartialEq, Debug)]
347//! #[approx(default_max_relative = 0.1)]
348//! struct Benchmark {
349//! time: f32,
350//! warm_up: f32,
351//! }
352//!
353//! let bench1 = Benchmark {
354//! time: 3.502785781,
355//! warm_up: 0.58039458,
356//! };
357//! let bench2 = Benchmark {
358//! time: 3.7023458,
359//! warm_up: 0.59015897,
360//! };
361//!
362//! approx::assert_relative_eq!(bench1, bench2);
363//! approx::assert_relative_ne!(bench1, bench2, max_relative = 0.05);
364//! ```
365//! ## Epsilon Type
//! When specifying nothing, the macros will infer the `EPSILON` type from the type of the
//! first non-skipped struct/enum field (in declaration order).
368//! This can be problematic in certain scenarios which is why we can also manually specify this
369//! type.
370//!
371//! ```
372//! # use approx_derive::*;
373//! #[derive(RelativeEq, PartialEq, Debug)]
374//! #[approx(epsilon_type = f32)]
375//! struct Car {
376//! #[approx(cast_field)]
377//! produced_year: u32,
378//! horse_power: f32,
379//! }
380//!
381//! let car1 = Car {
382//! produced_year: 1992,
383//! horse_power: 122.87,
384//! };
385//! let car2 = Car {
386//! produced_year: 2000,
387//! horse_power: 117.45,
388//! };
389//!
390//! approx::assert_relative_eq!(car1, car2, max_relative = 0.05);
391//! approx::assert_relative_ne!(car1, car2, max_relative = 0.01);
392//! ```
393
394mod args_parsing;
395
396use args_parsing::*;
397
398enum BaseType {
399 Struct {
400 item_struct: syn::ItemStruct,
401 fields_with_args: Vec<FieldWithArgs>,
402 },
403 Enum {
404 item_enum: syn::ItemEnum,
405 variants_with_args: Vec<EnumVariant>,
406 },
407}
408
409impl syn::parse::Parse for BaseType {
410 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
411 if input.fork().parse::<syn::ItemStruct>().is_ok() {
412 use syn::spanned::Spanned;
413 let item_struct: syn::ItemStruct = input.parse()?;
414 let fields_with_args = match item_struct.fields.clone() {
415 syn::Fields::Named(named_fields) => named_fields
416 .named
417 .iter()
418 .map(FieldWithArgs::from_field)
419 .collect::<syn::Result<Vec<_>>>(),
420 syn::Fields::Unnamed(unnamed_fields) => unnamed_fields
421 .unnamed
422 .iter()
423 .map(FieldWithArgs::from_field)
424 .collect::<syn::Result<Vec<_>>>(),
425 syn::Fields::Unit => Err(syn::Error::new(
426 item_struct.span(),
427 "cannot derive from unit struct",
428 )),
429 }?;
430 Ok(BaseType::Struct {
431 item_struct,
432 fields_with_args,
433 })
434 } else if let Ok(item_enum) = input.parse::<syn::ItemEnum>() {
435 // let item_enum: syn::ItemEnum = input.parse()?;
436 let variants_with_args = item_enum
437 .variants
438 .iter()
439 .map(|v| {
440 let args = FieldArgs::from_attrs(&v.attrs)?;
441 let fields_with_args = v
442 .fields
443 .iter()
444 .map(|f| {
445 let mut fwa = FieldWithArgs::from_field(f)?;
446 fwa.args.patch_if_not_exists(&args);
447 Ok(fwa)
448 })
449 .collect::<syn::Result<Vec<_>>>()?;
450 Ok(EnumVariant {
451 fields_with_args,
452 ident: v.ident.clone(),
453 discriminant: v.discriminant.clone().map(|x| x.1),
454 })
455 })
456 .collect::<syn::Result<Vec<_>>>()?;
457 Ok(BaseType::Enum {
458 item_enum,
459 variants_with_args,
460 })
461 } else {
462 Err(syn::Error::new(
463 input.span(),
464 "Could not parse enum or struct",
465 ))
466 }
467 }
468}
469
470impl BaseType {
471 fn attrs(&self) -> &Vec<syn::Attribute> {
472 match self {
473 #[allow(unused)]
474 BaseType::Struct {
475 item_struct,
476 fields_with_args,
477 } => &item_struct.attrs,
478 #[allow(unused)]
479 BaseType::Enum {
480 item_enum,
481 variants_with_args,
482 } => &item_enum.attrs,
483 }
484 }
485
486 fn generics(&self) -> &syn::Generics {
487 match self {
488 #[allow(unused)]
489 BaseType::Struct {
490 item_struct,
491 fields_with_args,
492 } => &item_struct.generics,
493 #[allow(unused)]
494 BaseType::Enum {
495 item_enum,
496 variants_with_args,
497 } => &item_enum.generics,
498 }
499 }
500
501 fn ident(&self) -> &syn::Ident {
502 match self {
503 #[allow(unused)]
504 BaseType::Struct {
505 item_struct,
506 fields_with_args,
507 } => &item_struct.ident,
508 #[allow(unused)]
509 BaseType::Enum {
510 item_enum,
511 variants_with_args,
512 } => &item_enum.ident,
513 }
514 }
515}
516
517struct AbsDiffEqParser {
518 base_type: BaseType,
519 struct_args: StructArgs,
520}
521
522impl syn::parse::Parse for AbsDiffEqParser {
523 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
524 let base_type: BaseType = input.parse()?;
525 let struct_args = StructArgs::from_attrs(base_type.attrs())?;
526 Ok(Self {
527 base_type,
528 struct_args,
529 })
530 }
531}
532
533#[derive(Debug)]
534struct FieldFormatted {
535 base_type: proc_macro2::TokenStream,
536 own_field: proc_macro2::TokenStream,
537 other_field: proc_macro2::TokenStream,
538 epsilon: proc_macro2::TokenStream,
539 max_relative: proc_macro2::TokenStream,
540 mapping: Option<proc_macro2::TokenStream>,
541 set_equal: bool,
542}
543
544impl AbsDiffEqParser {
545 fn get_epsilon_parent_type(&self) -> proc_macro2::TokenStream {
546 self.struct_args
547 .epsilon_type
548 .clone()
549 .map(|x| quote::quote!(#x))
550 .or_else(|| {
551 #[allow(unused)]
552 match &self.base_type {
553 BaseType::Struct {
554 item_struct,
555 fields_with_args,
556 } => fields_with_args
557 .iter()
558 .find(|f| f.args.skip.is_none_or(|x| x == false)),
559 BaseType::Enum {
560 item_enum,
561 variants_with_args,
562 } => variants_with_args
563 .iter()
564 .flat_map(|v| v.fields_with_args.iter())
565 .find(|f| f.args.skip.is_none_or(|x| x == false)),
566 }
567 .map(|field| {
568 let field_type = &field.ty;
569 quote::quote!(#field_type)
570 })
571 })
572 .or_else(|| Some(quote::quote!(f64)))
573 .unwrap()
574 }
575
576 fn get_derived_epsilon_type(&self) -> proc_macro2::TokenStream {
577 let parent = self.get_epsilon_parent_type();
578 quote::quote!(<#parent as approx::AbsDiffEq>::Epsilon)
579 }
580
581 fn get_epsilon_type_and_default_value(
582 &self,
583 ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
584 let parent = self.get_epsilon_parent_type();
585 let epsilon_type = self.get_derived_epsilon_type();
586 let epsilon_default_value = self
587 .struct_args
588 .default_epsilon_value
589 .clone()
590 .map(|x| quote::quote!(#x))
591 .or_else(|| Some(quote::quote!(<#parent as approx::AbsDiffEq>::default_epsilon())))
592 .unwrap();
593 (epsilon_type, epsilon_default_value)
594 }
595
596 fn generics_involved(&self) -> bool {
597 let parent = self.get_epsilon_parent_type();
598 self.base_type
599 .generics()
600 .params
601 .iter()
602 .any(|param| quote::quote!(#param).to_string() == parent.to_string())
603 }
604
605 fn get_max_relative_default_value(&self) -> proc_macro2::TokenStream {
606 let epsilon_type = self.get_epsilon_parent_type();
607 self.struct_args
608 .default_max_relative_value
609 .clone()
610 .map(|x| quote::quote!(#x))
611 .or_else(|| {
612 Some(quote::quote!(<#epsilon_type as approx::RelativeEq>::default_max_relative()))
613 })
614 .unwrap()
615 }
616
617 fn format_nth_field(
618 &self,
619 n: usize,
620 field_with_args: &FieldWithArgs,
621 idents: Option<(syn::Ident, syn::Ident)>,
622 ) -> Option<FieldFormatted> {
623 // Determine if this field will be skipped and exit early
624 if let Some(true) = field_with_args.args.skip {
625 return None;
626 }
627
628 // Get types for epsilon and max_relative
629 let parent_type = self.get_epsilon_parent_type();
630
631 // Save field name and type in variables for easy access
632 use std::str::FromStr;
633 let (field_name1, field_name2) = match (&field_with_args.ident, idents) {
634 (Some(id), None) => (quote::quote!(self.#id), quote::quote!(other.#id)),
635 (None, None) => {
636 let field_number = proc_macro2::TokenStream::from_str(&format!("{}", n)).unwrap();
637 (
638 quote::quote!(self.#field_number),
639 quote::quote!(other.#field_number),
640 )
641 }
642 (_, Some((id1, id2))) => (quote::quote!(#id1), quote::quote!(#id2)),
643 };
644 let field_type = &field_with_args.ty;
645
646 // Determine if the field or the value will be casted in any way
647 let cast_strategy = &field_with_args.args.cast_strategy;
648
649 // Get static values (if present) for epsilon and max_relative
650 let epsilon = &field_with_args
651 .args
652 .epsilon_static_value
653 .clone()
654 .map(|x| quote::quote!(#x))
655 .or_else(|| Some(quote::quote!(epsilon)))
656 .unwrap();
657 let max_relative = field_with_args
658 .args
659 .max_relative_static_value
660 .clone()
661 .map(|x| quote::quote!(#x))
662 .or_else(|| Some(quote::quote!(max_relative)))
663 .unwrap();
664
665 // Use the casting strategy
666 let (base_type, own_field, other_field, epsilon, max_relative) = match cast_strategy {
667 Some(TypeCast::CastField) => (
668 quote::quote!(#parent_type),
669 quote::quote!(&(#field_name1.clone() as #parent_type)),
670 quote::quote!(&(#field_name2.clone() as #parent_type)),
671 quote::quote!(#epsilon.clone()),
672 quote::quote!(#max_relative.clone()),
673 ),
674 Some(TypeCast::CastValue) => (
675 quote::quote!(#field_type),
676 quote::quote!(&#field_name1),
677 quote::quote!(&#field_name2),
678 quote::quote!((#epsilon.clone() as #field_type)),
679 quote::quote!((#max_relative.clone() as #field_type)),
680 ),
681 None => (
682 quote::quote!(#parent_type),
683 quote::quote!(&#field_name1),
684 quote::quote!(&#field_name2),
685 quote::quote!(#epsilon.clone()),
686 quote::quote!(#max_relative.clone()),
687 ),
688 };
689
690 let mapping = field_with_args
691 .args
692 .mapping
693 .clone()
694 .map(|expr| quote::quote!(#expr));
695
696 // Return the fully formatted field
697 Some(FieldFormatted {
698 base_type,
699 own_field,
700 other_field,
701 epsilon,
702 max_relative,
703 set_equal: field_with_args.args.set_equal.unwrap_or(false),
704 mapping,
705 })
706 }
707
708 fn get_abs_diff_eq_struct_fields(
709 &self,
710 fields_with_args: &[FieldWithArgs],
711 ) -> Vec<proc_macro2::TokenStream> {
712 // We need to extend the where clause for all generics
713 let fields = fields_with_args
714 .iter()
715 .enumerate()
716 .filter_map(|(n, field_with_args)| {
717 if let Some(FieldFormatted {
718 base_type,
719 own_field,
720 other_field,
721 epsilon,
722 #[allow(unused)]
723 max_relative,
724 set_equal,
725 mapping,
726 }) = self.format_nth_field(n, field_with_args, None)
727 {
728 if set_equal {
729 Some(quote::quote!(#own_field == #other_field &&))
730 } else if let Some(map) = mapping {
731 Some(quote::quote!(
732 (if let ((Some(a), Some(b))) = (
733 (#map)(#own_field),
734 (#map)(#other_field)
735 ) {
736 approx::AbsDiffEq::abs_diff_eq(&a, &b, #epsilon)
737 } else {
738 false
739 }) &&
740 ))
741 } else {
742 Some(quote::quote!(
743 <#base_type as approx::AbsDiffEq>::abs_diff_eq(
744 #own_field,
745 #other_field,
746 #epsilon
747 ) &&
748 ))
749 }
750 } else {
751 None
752 }
753 });
754 fields.collect()
755 }
756
757 fn get_abs_diff_eq_enum_variants(
758 &self,
759 variants_with_args: &[EnumVariant],
760 ) -> Vec<proc_macro2::TokenStream> {
761 variants_with_args
762 .iter()
763 .map(|variant_with_args| {
764 let variant = &variant_with_args.ident;
765 use syn::spanned::Spanned;
766
767 let gen_field_names = |var: &str| -> Vec<syn::Ident> {
768 variant_with_args
769 .fields_with_args
770 .iter()
771 .enumerate()
772 .map(|(n, field)| syn::Ident::new(&format!("{var}{n}"), field.ident.span()))
773 .collect()
774 };
775 if variant_with_args
776 .fields_with_args
777 .first()
778 .and_then(|f| f.ident.clone())
779 .is_some()
780 {
781 let field_placeholders1 = gen_field_names("x");
782 let field_placeholders2 = gen_field_names("y");
783 let gen_combos = |iterator: Vec<syn::Ident>| {
784 iterator
785 .iter()
786 .zip(&variant_with_args.fields_with_args)
787 .map(|(fph, fwa)| {
788 let id = &fwa.ident;
789 quote::quote!(#id: #fph)
790 })
791 .collect::<Vec<_>>()
792 };
793 let comps: Vec<_> = field_placeholders1
794 .iter()
795 .zip(field_placeholders2.iter())
796 .zip(variant_with_args.fields_with_args.iter())
797 .map(|((xi, yi), field)| {
798 self.get_abs_diff_eq_single_field(xi.clone(), yi.clone(), field)
799 })
800 .collect();
801 let field_name_placeholder_combos1 = gen_combos(field_placeholders1);
802 let field_name_placeholder_combos2 = gen_combos(field_placeholders2);
803 quote::quote!(
804 (
805 Self:: #variant {
806 #(#field_name_placeholder_combos1),*
807 },
808 Self:: #variant {
809 #(#field_name_placeholder_combos2),*
810 }
811 ) => #(#comps) &&*,
812 )
813 } else if !variant_with_args.fields_with_args.is_empty() {
814 let field_names1 = gen_field_names("x");
815 let field_names2 = gen_field_names("y");
816 let comps: Vec<_> = field_names1
817 .iter()
818 .zip(field_names2.iter())
819 .zip(variant_with_args.fields_with_args.iter())
820 .map(|((xi, yi), field)| {
821 self.get_abs_diff_eq_single_field(xi.clone(), yi.clone(), field)
822 })
823 .collect();
824 quote::quote!(
825 (
826 Self:: #variant (#(#field_names1),*),
827 Self:: #variant (#(#field_names2),*)
828 ) => {#(#comps) &&*},
829 )
830 } else {
831 quote::quote!(
832 (Self:: #variant, Self:: #variant) => true,
833 )
834 }
835 })
836 .collect()
837 }
838
839 fn get_abs_diff_eq_single_field(
840 &self,
841 xi: syn::Ident,
842 yi: syn::Ident,
843 field_with_args: &FieldWithArgs,
844 ) -> Option<proc_macro2::TokenStream> {
845 if let Some(FieldFormatted {
846 base_type,
847 own_field,
848 other_field,
849 epsilon,
850 #[allow(unused)]
851 max_relative,
852 set_equal,
853 mapping,
854 }) = self.format_nth_field(0, field_with_args, Some((xi, yi)))
855 {
856 if set_equal {
857 Some(quote::quote!(#own_field == #other_field))
858 } else if let Some(map) = mapping {
859 Some(quote::quote!(
860 (if let ((Some(a), Some(b))) = (
861 (#map)(#own_field),
862 (#map)(#other_field)
863 ) {
864 <#base_type as approx::AbsDiffEq>::abs_diff_eq(&a, &b, #epsilon)
865 } else {
866 false
867 })
868 ))
869 } else {
870 Some(quote::quote!(
871 <#base_type as approx::AbsDiffEq>::abs_diff_eq(
872 &#own_field,
873 &#other_field,
874 #epsilon
875 )
876 ))
877 }
878 } else {
879 None
880 }
881 }
882
883 fn get_rel_eq_single_field(
884 &self,
885 xi: syn::Ident,
886 yi: syn::Ident,
887 field_with_args: &FieldWithArgs,
888 ) -> Option<proc_macro2::TokenStream> {
889 if let Some(FieldFormatted {
890 base_type,
891 own_field,
892 other_field,
893 epsilon,
894 max_relative,
895 set_equal,
896 mapping,
897 }) = self.format_nth_field(0, field_with_args, Some((xi, yi)))
898 {
899 if set_equal {
900 Some(quote::quote!(#own_field == #other_field))
901 } else if let Some(map) = mapping {
902 Some(quote::quote!(
903 (if let ((Some(a), Some(b))) = (
904 (#map)(#own_field),
905 (#map)(#other_field)
906 ) {
907 approx::RelativeEq::relative_eq(&a, &b, #epsilon, #max_relative)
908 } else {
909 false
910 })
911 ))
912 } else {
913 Some(quote::quote!(
914 <#base_type as approx::RelativeEq>::relative_eq(
915 #own_field,
916 #other_field,
917 #epsilon,
918 #max_relative
919 )
920 ))
921 }
922 } else {
923 None
924 }
925 }
926
927 fn get_rel_eq_struct_fields(
928 &self,
929 fields_with_args: &[FieldWithArgs],
930 ) -> Vec<proc_macro2::TokenStream> {
931 // We need to extend the where clause for all generics
932 let fields = fields_with_args
933 .iter()
934 .enumerate()
935 .filter_map(|(n, field_with_args)| {
936 if let Some(FieldFormatted {
937 base_type,
938 own_field,
939 other_field,
940 epsilon,
941 #[allow(unused)]
942 max_relative,
943 set_equal,
944 mapping,
945 }) = self.format_nth_field(n, field_with_args, None)
946 {
947 if set_equal {
948 Some(quote::quote!(#own_field == #other_field &&))
949 } else if let Some(map) = mapping {
950 Some(quote::quote!(
951 (if let ((Some(a), Some(b))) = (
952 (#map)(#own_field),
953 (#map)(#other_field)
954 ) {
955 approx::RelativeEq::relative_eq(&a, &b, #epsilon, #max_relative)
956 } else {
957 false
958 }) &&
959 ))
960 } else {
961 Some(quote::quote!(
962 <#base_type as approx::RelativeEq>::relative_eq(
963 #own_field,
964 #other_field,
965 #epsilon,
966 #max_relative,
967 ) &&
968 ))
969 }
970 } else {
971 None
972 }
973 });
974 fields.collect()
975 }
976
977 fn get_rel_eq_variants(
978 &self,
979 variants_with_args: &[EnumVariant],
980 ) -> Vec<proc_macro2::TokenStream> {
981 variants_with_args
982 .iter()
983 .map(|variant_with_args| {
984 let variant = &variant_with_args.ident;
985 use syn::spanned::Spanned;
986
987 let gen_field_names = |var: &str| -> Vec<syn::Ident> {
988 variant_with_args
989 .fields_with_args
990 .iter()
991 .enumerate()
992 .map(|(n, field)| syn::Ident::new(&format!("{var}{n}"), field.ident.span()))
993 .collect()
994 };
995 if variant_with_args
996 .fields_with_args
997 .first()
998 .and_then(|f| f.ident.clone())
999 .is_some()
1000 {
1001 let field_placeholders1 = gen_field_names("x");
1002 let field_placeholders2 = gen_field_names("y");
1003 let gen_combos = |iterator: Vec<syn::Ident>| {
1004 iterator
1005 .iter()
1006 .zip(&variant_with_args.fields_with_args)
1007 .map(|(fph, fwa)| {
1008 let id = &fwa.ident;
1009 quote::quote!(#id: #fph)
1010 })
1011 .collect::<Vec<_>>()
1012 };
1013 let comps: Vec<_> = field_placeholders1
1014 .iter()
1015 .zip(field_placeholders2.iter())
1016 .zip(variant_with_args.fields_with_args.iter())
1017 .map(|((xi, yi), field)| {
1018 self.get_rel_eq_single_field(xi.clone(), yi.clone(), field)
1019 })
1020 .collect();
1021 let field_name_placeholder_combos1 = gen_combos(field_placeholders1);
1022 let field_name_placeholder_combos2 = gen_combos(field_placeholders2);
1023 quote::quote!(
1024 (
1025 Self:: #variant {
1026 #(#field_name_placeholder_combos1),*
1027 },
1028 Self:: #variant {
1029 #(#field_name_placeholder_combos2),*
1030 }
1031 ) => #(#comps) &&*,
1032 )
1033 } else if !variant_with_args.fields_with_args.is_empty() {
1034 let field_names1 = gen_field_names("x");
1035 let field_names2 = gen_field_names("y");
1036 let comps: Vec<_> = field_names1
1037 .iter()
1038 .zip(field_names2.iter())
1039 .zip(variant_with_args.fields_with_args.iter())
1040 .map(|((xi, yi), field)| {
1041 self.get_rel_eq_single_field(xi.clone(), yi.clone(), field)
1042 })
1043 .collect();
1044 quote::quote!(
1045 (
1046 Self:: #variant (#(#field_names1),*),
1047 Self:: #variant (#(#field_names2),*)
1048 ) => {#(#comps) &&*},
1049 )
1050 } else {
1051 quote::quote!(
1052 (Self::#variant, Self::#variant) => true,
1053 )
1054 }
1055 })
1056 .collect()
1057 }
1058
1059 fn generate_where_clause(&self, abs_diff_eq: bool) -> proc_macro2::TokenStream {
1060 let (epsilon_type, _) = self.get_epsilon_type_and_default_value();
1061 let (_, _, where_clause) = self.base_type.generics().split_for_impl();
1062 let trait_bound = match abs_diff_eq {
1063 true => quote::quote!(approx::AbsDiffEq),
1064 false => quote::quote!(approx::RelativeEq),
1065 };
1066 if self.generics_involved() {
1067 let parent = self.get_epsilon_parent_type();
1068 match where_clause {
1069 Some(clause) => quote::quote!(
1070 #clause
1071 #parent: #trait_bound,
1072 #parent: PartialEq,
1073 #epsilon_type: Clone,
1074 ),
1075 None => quote::quote!(
1076 where
1077 #parent: #trait_bound,
1078 #parent: PartialEq,
1079 #epsilon_type: Clone,
1080 ),
1081 }
1082 } else {
1083 quote::quote!(#where_clause)
1084 }
1085 }
1086
1087 fn implement_derive_abs_diff_eq(&self) -> proc_macro2::TokenStream {
1088 let struct_name = &self.base_type.ident();
1089 let (epsilon_type, epsilon_default_value) = self.get_epsilon_type_and_default_value();
1090
1091 let (impl_generics, ty_generics, _) = self.base_type.generics().split_for_impl();
1092 let where_clause = self.generate_where_clause(true);
1093
1094 match &self.base_type {
1095 #[allow(unused)]
1096 BaseType::Struct {
1097 item_struct,
1098 fields_with_args,
1099 } => {
1100 let fields = self.get_abs_diff_eq_struct_fields(fields_with_args);
1101
1102 quote::quote!(
1103 const _ : () = {
1104 #[automatically_derived]
1105 impl #impl_generics approx::AbsDiffEq for #struct_name #ty_generics
1106 #where_clause
1107 {
1108 type Epsilon = #epsilon_type;
1109
1110 fn default_epsilon() -> Self::Epsilon {
1111 #epsilon_default_value
1112 }
1113
1114 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
1115 #(#fields)*
1116 true
1117 }
1118 }
1119 };
1120 )
1121 }
1122 #[allow(unused)]
1123 BaseType::Enum {
1124 item_enum,
1125 variants_with_args,
1126 } => {
1127 let variants = self.get_abs_diff_eq_enum_variants(variants_with_args);
1128 quote::quote!(
1129 const _: () = {
1130 #[automatically_derived]
1131 impl #impl_generics approx::AbsDiffEq for #struct_name #ty_generics
1132 #where_clause
1133 {
1134 type Epsilon = #epsilon_type;
1135
1136 fn default_epsilon() -> Self::Epsilon {
1137 #epsilon_default_value
1138 }
1139
1140 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
1141 match (self, other) {
1142 #(#variants)*
1143 _ => false,
1144 }
1145 }
1146 }
1147 };
1148 )
1149 }
1150 }
1151 }
1152
1153 fn implement_derive_rel_diff_eq(&self) -> proc_macro2::TokenStream {
1154 let obj_name = &self.base_type.ident();
1155 let max_relative_default_value = self.get_max_relative_default_value();
1156
1157 let (impl_generics, ty_generics, _) = self.base_type.generics().split_for_impl();
1158 let where_clause = self.generate_where_clause(false);
1159
1160 match &self.base_type {
1161 #[allow(unused)]
1162 BaseType::Struct {
1163 item_struct,
1164 fields_with_args,
1165 } => {
1166 let fields = self.get_rel_eq_struct_fields(fields_with_args);
1167
1168 quote::quote!(
1169 const _ : () = {
1170 #[automatically_derived]
1171 impl #impl_generics approx::RelativeEq for #obj_name #ty_generics
1172 #where_clause
1173 {
1174 fn default_max_relative() -> Self::Epsilon {
1175 #max_relative_default_value
1176 }
1177
1178 fn relative_eq(
1179 &self,
1180 other: &Self,
1181 epsilon: Self::Epsilon,
1182 max_relative: Self::Epsilon
1183 ) -> bool {
1184 #(#fields)*
1185 true
1186 }
1187 }
1188 };
1189 )
1190 }
1191 #[allow(unused)]
1192 BaseType::Enum {
1193 item_enum,
1194 variants_with_args,
1195 } => {
1196 let variants = self.get_rel_eq_variants(variants_with_args);
1197 quote::quote!(
1198 const _: () = {
1199 #[automatically_derived]
1200 impl #impl_generics approx::RelativeEq for #obj_name #ty_generics
1201 #where_clause
1202 {
1203 fn default_max_relative() -> Self::Epsilon {
1204 #max_relative_default_value
1205 }
1206
1207 fn relative_eq(
1208 &self,
1209 other: &Self,
1210 epsilon: Self::Epsilon,
1211 max_relative: Self::Epsilon
1212 ) -> bool {
1213 match (self, other) {
1214 #(#variants)*
1215 _ => false,
1216 }
1217 }
1218 }
1219 };
1220 )
1221 }
1222 }
1223 }
1224}
1225
1226/// See the [crate] level documentation for a guide.
1227#[proc_macro_derive(AbsDiffEq, attributes(approx))]
1228pub fn derive_abs_diff_eq(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1229 let parsed = syn::parse_macro_input!(input as AbsDiffEqParser);
1230 parsed.implement_derive_abs_diff_eq().into()
1231}
1232
1233/// See the [crate] level documentation for a guide.
1234#[proc_macro_derive(RelativeEq, attributes(approx))]
1235pub fn derive_rel_diff_eq(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1236 let parsed = syn::parse_macro_input!(input as AbsDiffEqParser);
1237 let mut output = quote::quote!();
1238 output.extend(parsed.implement_derive_abs_diff_eq());
1239 output.extend(parsed.implement_derive_rel_diff_eq());
1240 output.into()
1241}