approx_derive/
lib.rs

1#![deny(missing_docs)]
2//! This crate provides derive macros for the
3//! [AbsDiffEq](https://docs.rs/approx/latest/approx/trait.AbsDiffEq.html) and
4//! [RelativeEq](https://docs.rs/approx/latest/approx/trait.RelativeEq.html) traits of the
5//! [approx](https://docs.rs/approx/latest/approx/) crate.
6//!
7//! These derive macros only implement both traits with `...<Rhs = Self>`.
8//! The macros infer the `EPSILON` type of the [AbsDiffEq] trait by looking
9//! at the type of the first struct or enum field or any type specified by the user.
10//!
11//! This table lists all attributes which can be used to customize the derived traits.
//! They are ordered by descending priority, meaning that setting `#[approx(equal)]` overrides
//! any specification made in the `#[approx(map = ...)]` attribute.
14//!
15//! | Field Attribute | Functionality |
16//! |:--- | --- |
17//! | [`#[approx(skip)]`](#skipping-fields) | Skips the field entirely |
//! | [`#[approx(equal)]`](#testing-for-equality) | Checks this field with `==` for equality. |
19//! | [`#[approx(cast_field)]`](#casting-fields) | Casts the field with `.. as ..` syntax. |
20//! | [`#[approx(map = ..)]`](#mapping-values) | Maps values before comparing them. |
21//! | [`#[approx(static_epsilon = ..)]`](#static-values) | Defines a static epsilon value for this particular field. |
22//! | | |
23//! | **Object Attribute** | |
24//! | [`#[approx(default_epsilon = ...)]`](#default-epsilon) | Sets the default epsilon value |
25//! | [`#[approx(default_max_relative = ...)]`](#default-max-relative) | Sets the default `max_relative` value. |
26//! | [`#[approx(epsilon_type = ...)]`](#epsilon-type) | Sets the type of the epsilon value |
27//!
28//! # Usage
29//!
30//! ```
31//! use approx_derive::AbsDiffEq;
32//!
33//! // Define a new type and derive the AbsDiffEq trait
34//! #[derive(AbsDiffEq, PartialEq, Debug)]
35//! struct Position {
36//!     x: f64,
37//!     y: f64
38//! }
39//!
40//! // Compare if two given positions match
//! // with respect to the given epsilon.
42//! let p1 = Position { x: 1.01, y: 2.36 };
43//! let p2 = Position { x: 0.99, y: 2.38 };
44//! approx::assert_abs_diff_eq!(p1, p2, epsilon = 0.021);
45//! ```
46//! In this case, the generated code looks something like this:
47//! ```ignore
48//! const _ : () =
49//! {
50//!     #[automatically_derived] impl approx :: AbsDiffEq for Position
51//!     {
52//!         type Epsilon = <f64 as approx::AbsDiffEq>::Epsilon;
53//!
54//!         fn default_epsilon() -> Self :: Epsilon {
55//!             <f64 as approx::AbsDiffEq>::default_epsilon()
56//!         }
57//!
58//!         fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
59//!             <f64 as approx::AbsDiffEq>::abs_diff_eq(
60//!                 &self.x,
61//!                 & other.x,
62//!                 epsilon.clone()
63//!             ) &&
64//!             <f64 as approx::AbsDiffEq>::abs_diff_eq(
65//!                 &self.y,
66//!                 &other.y,
67//!                 epsilon.clone()
68//!             ) && true
69//!         }
70//!     }
71//! };
72//! ```
73//! The [AbsDiffEq] derive macro calls the `abs_diff_eq` method repeatedly on all fields
74//! to determine if all are matching.
75//!
76//! ## Enums
//! `approx-derive` supports enums since version `0.2`.
78//!
79//! ```
80//! use approx_derive::AbsDiffEq;
81//!
82//! #[derive(AbsDiffEq, PartialEq, Debug)]
83//! enum Position {
84//!     Smooth { x: f32, y: f32, },
85//!     #[approx(cast_value)]
86//!     Lattice { x: isize, y: isize },
87//! }
88//!
89//! let p1 = Position::Smooth { x: 1.0, y: 1.1 };
90//! let p2 = Position::Smooth { x: 1.1, y: 1.0};
91//! let p3 = Position::Lattice { x: 1, y: 1 };
92//!
93//! approx::assert_abs_diff_eq!(p1, p2, epsilon=0.2);
94//! ```
95//!
96//! ```should_panic
97//! # use approx_derive::AbsDiffEq;
98//! # #[derive(AbsDiffEq, PartialEq, Debug)]
99//! # enum Position {
100//! #     Smooth { x: f32, y: f32, },
101//! #     #[approx(cast_value)]
102//! #     Lattice { x: isize, y: isize },
103//! # }
104//! # let p1 = Position::Smooth { x: 1.0, y: 1.1 };
105//! # let p3 = Position::Lattice { x: 1, y: 1 };
106//! // Note! Different enum variants can never be equal!
107//! approx::assert_abs_diff_eq!(p1, p3, epsilon = 1000.0);
108//! ```
109//!
110//!
111//! # Field Attributes
112//! ## Skipping Fields
113//!
114//! Sometimes, we only want to compare certain fields and omit others completely.
115//! ```
116//! # use approx_derive::*;
117//! #[derive(AbsDiffEq, PartialEq, Debug)]
118//! struct Player {
119//!     hit_points: f32,
120//!     pos_x: f32,
121//!     pos_y: f32,
122//!     #[approx(skip)]
123//!     id: (usize, usize),
124//! }
125//!
126//! let player1 = Player {
127//!     hit_points: 100.0,
128//!     pos_x: 2.0,
129//!     pos_y: -650.345,
130//!     id: (0, 1),
131//! };
132//!
133//! let player2 = Player {
134//!     hit_points: 99.9,
135//!     pos_x: 2.001,
136//!     pos_y: -649.898,
137//!     id: (22, 0),
138//! };
139//!
140//! approx::assert_abs_diff_eq!(player1, player2, epsilon = 0.5);
141//! ```
142//!
143//! ## Testing for [Equality](core::cmp::Eq)
144//!
//! When exact equality is desired, we can specify this with the `#[approx(equal)]` attribute.
146//!
147//! ```
148//! # use approx_derive::*;
149//! #[derive(AbsDiffEq, PartialEq, Debug)]
150//! struct Prediction {
151//!     confidence: f64,
152//!     #[approx(equal)]
153//!     category: String,
154//! }
155//! ```
156//!
157//! Note that in this case, the type of the epsilon value for the implementation of
158//! [AbsDiffEq](https://docs.rs/approx/latest/approx/trait.AbsDiffEq.html) is inferred from the
159//! first field of the `Prediction` struct.
//! This means that if we reorder the fields of the struct, we need to set the epsilon type manually.
161//!
162//! ```
163//! # use approx_derive::*;
164//! #[derive(AbsDiffEq, PartialEq, Debug)]
165//! #[approx(epsilon_type = f64)]
166//! struct Prediction {
167//!     #[approx(equal)]
168//!     category: String,
169//!     confidence: f64,
170//! }
171//! ```
172//!
173//! ## Casting Fields
174//!
//! Structs which consist of multiple fields with different
//! numeric types cannot be derived without additional hints.
177//! After all, we should specify how this type mismatch will be handled.
178//!
179//! ```compile_fail
180//! # use approx_derive::*;
181//! #[derive(AbsDiffEq, PartialEq, Debug)]
182//! struct MyStruct {
183//!     v1: f32,
184//!     v2: f64,
185//! }
186//! ```
187//!
188//! We can use the `#[approx(cast_field)]` and `#[approx(cast_value)]`
189//! attributes to achieve this goal.
190//! ```
191//! # use approx_derive::*;
192//! #[derive(AbsDiffEq, PartialEq, Debug)]
193//! struct MyStruct {
194//!     v1: f32,
195//!     #[approx(cast_field)]
196//!     v2: f64,
197//! }
198//! ```
//! Now the second field will be cast to the type of the inferred epsilon value (`f32`).
//! We can verify this by checking that a difference of `f64::MIN_POSITIVE` is lost
//! by this procedure.
202//! ```
203//! # use approx_derive::*;
204//! # #[derive(RelativeEq, PartialEq, Debug)]
205//! # struct MyStruct {
206//! #   v1: f32,
207//! #   #[approx(cast_field)]
208//! #   v2: f64,
209//! # }
210//! let ms1 = MyStruct {
211//!     v1: 1.0,
212//!     v2: 3.0,
213//! };
214//! let ms2 = MyStruct {
215//!     v1: 1.0,
216//!     v2: 3.0 + f64::MIN_POSITIVE,
217//! };
218//! approx::assert_relative_eq!(ms1, ms2);
219//! ```
220//!
221//! ## Mapping Values
222//!
223//! We can map values before comparing them.
224//! By default, we need to return an option of the value in question.
//! This allows computations which can fail.
226//! Although this error is not caught, the comparison will fail if any of the two compared objects
227//! return a `None` value.
228//! ```
229//! # use approx_derive::*;
230//! # use approx::*;
231//! #[derive(AbsDiffEq, PartialEq, Debug)]
232//! struct Tower {
233//!     height_in_meters: f32,
234//!     #[approx(map = |x: &f32| Some(x.sqrt()))]
235//!     area_in_meters_squared: f32,
236//! }
237//! # let t1 = Tower {
238//! #   height_in_meters: 100.0,
239//! #   area_in_meters_squared: 30.1,
240//! # };
241//! # let t2 = Tower {
242//! #   height_in_meters: 100.0,
243//! #   area_in_meters_squared: 30.5,
244//! # };
245//! # approx::assert_abs_diff_ne!(t1, t2, epsilon = 0.03);
246//! ```
247//!
248//! This functionality can also be useful when having more complex datatypes.
249//! ```
250//! # use approx_derive::*;
251//! # use approx::*;
252//! #[derive(PartialEq, Debug)]
253//! enum Time {
254//!     Years(u16),
255//!     Months(u16),
256//!     Weeks(u16),
257//!     Days(u16),
258//! }
259//!
260//! fn time_to_days(time: &Time) -> Option<u16> {
261//!     match time {
262//!         Time::Years(y) => Some(365 * y),
263//!         Time::Months(m) => Some(30 * m),
264//!         Time::Weeks(w) => Some(7 * w),
265//!         Time::Days(d) => Some(*d),
266//!     }
267//! }
268//!
269//! #[derive(AbsDiffEq, PartialEq, Debug)]
270//! #[approx(epsilon_type = u16)]
271//! struct Dog {
272//!     #[approx(map = time_to_days)]
273//!     age: Time,
274//!     #[approx(map = time_to_days)]
275//!     next_doctors_appointment: Time,
276//! }
277//! ```
278//!
279//! ## Static Values
280//! We can force a static `EPSILON` or `max_relative` value for individual fields.
281//! ```
282//! # use approx_derive::*;
283//! #[derive(AbsDiffEq, PartialEq, Debug)]
284//! struct Rectangle {
285//!     #[approx(static_epsilon = 5e-2)]
286//!     a: f64,
287//!     b: f64,
288//!     #[approx(static_epsilon = 7e-2)]
289//!     c: f64,
290//! }
291//!
292//! let r1 = Rectangle {
293//!     a: 100.01,
294//!     b: 40.0001,
295//!     c: 30.055,
296//! };
297//! let r2 = Rectangle {
298//!     a: 99.97,
299//!     b: 40.0005,
300//!     c: 30.049,
301//! };
302//!
//! // These assertions hold even when the epsilon is smaller than the differences
//! // in fields a and c, because those fields use their own static epsilon values.
305//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-1);
306//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-2);
307//! approx::assert_abs_diff_eq!(r1, r2, epsilon = 1e-3);
308//!
//! // Here, the epsilon value has become smaller than the difference between the
//! // b field values.
311//! approx::assert_abs_diff_ne!(r1, r2, epsilon = 1e-4);
312//! ```
313//! # Object Attributes
314//! ## Default Epsilon
//! The [AbsDiffEq] trait allows specifying a default epsilon value (via `default_epsilon`).
316//! We can control this value by specifying it on an object level.
317//!
318//! ```
319//! # use approx_derive::*;
320//! #[derive(AbsDiffEq, PartialEq, Debug)]
321//! #[approx(default_epsilon = 10)]
322//! struct Benchmark {
323//!     cycles: u64,
324//!     warm_up: u64,
325//! }
326//!
327//! let benchmark1 = Benchmark {
328//!     cycles: 248,
329//!     warm_up: 36,
330//! };
331//! let benchmark2 = Benchmark {
332//!     cycles: 239,
333//!     warm_up: 28,
334//! };
335//!
//! // When testing with no additional arguments, the results match.
337//! approx::assert_abs_diff_eq!(benchmark1, benchmark2);
338//! // Once we specify a lower epsilon, the values do not agree anymore.
339//! approx::assert_abs_diff_ne!(benchmark1, benchmark2, epsilon = 5);
340//! ```
341//!
342//! ## Default Max Relative
//! Similarly to [Default Epsilon](#default-epsilon), we can also choose a default `max_relative` deviation.
344//! ```
345//! # use approx_derive::*;
346//! #[derive(RelativeEq, PartialEq, Debug)]
347//! #[approx(default_max_relative = 0.1)]
348//! struct Benchmark {
349//!     time: f32,
350//!     warm_up: f32,
351//! }
352//!
353//! let bench1 = Benchmark {
354//!     time: 3.502785781,
355//!     warm_up: 0.58039458,
356//! };
357//! let bench2 = Benchmark {
358//!     time: 3.7023458,
359//!     warm_up: 0.59015897,
360//! };
361//!
362//! approx::assert_relative_eq!(bench1, bench2);
363//! approx::assert_relative_ne!(bench1, bench2, max_relative = 0.05);
364//! ```
365//! ## Epsilon Type
//! When nothing is specified, the macros infer the `EPSILON` type from the type of the
//! first struct/enum field that is not skipped (fields are scanned in declaration order).
368//! This can be problematic in certain scenarios which is why we can also manually specify this
369//! type.
370//!
371//! ```
372//! # use approx_derive::*;
373//! #[derive(RelativeEq, PartialEq, Debug)]
374//! #[approx(epsilon_type = f32)]
375//! struct Car {
376//!     #[approx(cast_field)]
377//!     produced_year: u32,
378//!     horse_power: f32,
379//! }
380//!
381//! let car1 = Car {
382//!     produced_year: 1992,
383//!     horse_power: 122.87,
384//! };
385//! let car2 = Car {
386//!     produced_year: 2000,
387//!     horse_power: 117.45,
388//! };
389//!
390//! approx::assert_relative_eq!(car1, car2, max_relative = 0.05);
391//! approx::assert_relative_ne!(car1, car2, max_relative = 0.01);
392//! ```
393
394mod args_parsing;
395
396use args_parsing::*;
397
/// The item a derive macro was applied to, together with the `#[approx(...)]`
/// arguments parsed from its fields.
enum BaseType {
    /// A struct with named or tuple fields (unit structs are rejected while parsing).
    Struct {
        /// The parsed struct item.
        item_struct: syn::ItemStruct,
        /// One entry per field, in declaration order.
        fields_with_args: Vec<FieldWithArgs>,
    },
    /// An enum; each variant's own `#[approx(...)]` arguments are patched into the
    /// arguments of its fields while parsing.
    Enum {
        /// The parsed enum item.
        item_enum: syn::ItemEnum,
        /// One entry per variant, in declaration order.
        variants_with_args: Vec<EnumVariant>,
    },
}
408
409impl syn::parse::Parse for BaseType {
410    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
411        if input.fork().parse::<syn::ItemStruct>().is_ok() {
412            use syn::spanned::Spanned;
413            let item_struct: syn::ItemStruct = input.parse()?;
414            let fields_with_args = match item_struct.fields.clone() {
415                syn::Fields::Named(named_fields) => named_fields
416                    .named
417                    .iter()
418                    .map(FieldWithArgs::from_field)
419                    .collect::<syn::Result<Vec<_>>>(),
420                syn::Fields::Unnamed(unnamed_fields) => unnamed_fields
421                    .unnamed
422                    .iter()
423                    .map(FieldWithArgs::from_field)
424                    .collect::<syn::Result<Vec<_>>>(),
425                syn::Fields::Unit => Err(syn::Error::new(
426                    item_struct.span(),
427                    "cannot derive from unit struct",
428                )),
429            }?;
430            Ok(BaseType::Struct {
431                item_struct,
432                fields_with_args,
433            })
434        } else if let Ok(item_enum) = input.parse::<syn::ItemEnum>() {
435            // let item_enum: syn::ItemEnum = input.parse()?;
436            let variants_with_args = item_enum
437                .variants
438                .iter()
439                .map(|v| {
440                    let args = FieldArgs::from_attrs(&v.attrs)?;
441                    let fields_with_args = v
442                        .fields
443                        .iter()
444                        .map(|f| {
445                            let mut fwa = FieldWithArgs::from_field(f)?;
446                            fwa.args.patch_if_not_exists(&args);
447                            Ok(fwa)
448                        })
449                        .collect::<syn::Result<Vec<_>>>()?;
450                    Ok(EnumVariant {
451                        fields_with_args,
452                        ident: v.ident.clone(),
453                        discriminant: v.discriminant.clone().map(|x| x.1),
454                    })
455                })
456                .collect::<syn::Result<Vec<_>>>()?;
457            Ok(BaseType::Enum {
458                item_enum,
459                variants_with_args,
460            })
461        } else {
462            Err(syn::Error::new(
463                input.span(),
464                "Could not parse enum or struct",
465            ))
466        }
467    }
468}
469
470impl BaseType {
471    fn attrs(&self) -> &Vec<syn::Attribute> {
472        match self {
473            #[allow(unused)]
474            BaseType::Struct {
475                item_struct,
476                fields_with_args,
477            } => &item_struct.attrs,
478            #[allow(unused)]
479            BaseType::Enum {
480                item_enum,
481                variants_with_args,
482            } => &item_enum.attrs,
483        }
484    }
485
486    fn generics(&self) -> &syn::Generics {
487        match self {
488            #[allow(unused)]
489            BaseType::Struct {
490                item_struct,
491                fields_with_args,
492            } => &item_struct.generics,
493            #[allow(unused)]
494            BaseType::Enum {
495                item_enum,
496                variants_with_args,
497            } => &item_enum.generics,
498        }
499    }
500
501    fn ident(&self) -> &syn::Ident {
502        match self {
503            #[allow(unused)]
504            BaseType::Struct {
505                item_struct,
506                fields_with_args,
507            } => &item_struct.ident,
508            #[allow(unused)]
509            BaseType::Enum {
510                item_enum,
511                variants_with_args,
512            } => &item_enum.ident,
513        }
514    }
515}
516
/// Parsed derive input: the annotated struct or enum together with its object-level
/// `#[approx(...)]` arguments (epsilon type, default epsilon and default max_relative).
struct AbsDiffEqParser {
    /// The struct or enum the macro was applied to, including per-field arguments.
    base_type: BaseType,
    /// Object-level arguments read from the item's attributes.
    struct_args: StructArgs,
}
521
522impl syn::parse::Parse for AbsDiffEqParser {
523    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
524        let base_type: BaseType = input.parse()?;
525        let struct_args = StructArgs::from_attrs(base_type.attrs())?;
526        Ok(Self {
527            base_type,
528            struct_args,
529        })
530    }
531}
532
/// Token fragments describing how one (non-skipped) field is compared.
/// Built by `AbsDiffEqParser::format_nth_field`.
#[derive(Debug)]
struct FieldFormatted {
    /// Type whose `approx` trait impl performs the comparison: the epsilon parent type,
    /// or the field's own type when `cast_value` is used.
    base_type: proc_macro2::TokenStream,
    /// Borrowed (`&...`) expression for the field of `self` (or its enum binding), possibly cast.
    own_field: proc_macro2::TokenStream,
    /// Borrowed (`&...`) expression for the field of `other` (or its enum binding), possibly cast.
    other_field: proc_macro2::TokenStream,
    /// Epsilon expression: the runtime `epsilon` argument or a static per-field value,
    /// cloned and possibly cast.
    epsilon: proc_macro2::TokenStream,
    /// `max_relative` expression, built the same way as `epsilon`.
    max_relative: proc_macro2::TokenStream,
    /// Mapping expression from `#[approx(map = ...)]`, if any.
    mapping: Option<proc_macro2::TokenStream>,
    /// `true` when `#[approx(equal)]` requests a plain `==` comparison.
    set_equal: bool,
}
543
544impl AbsDiffEqParser {
545    fn get_epsilon_parent_type(&self) -> proc_macro2::TokenStream {
546        self.struct_args
547            .epsilon_type
548            .clone()
549            .map(|x| quote::quote!(#x))
550            .or_else(|| {
551                #[allow(unused)]
552                match &self.base_type {
553                    BaseType::Struct {
554                        item_struct,
555                        fields_with_args,
556                    } => fields_with_args
557                        .iter()
558                        .find(|f| f.args.skip.is_none_or(|x| x == false)),
559                    BaseType::Enum {
560                        item_enum,
561                        variants_with_args,
562                    } => variants_with_args
563                        .iter()
564                        .flat_map(|v| v.fields_with_args.iter())
565                        .find(|f| f.args.skip.is_none_or(|x| x == false)),
566                }
567                .map(|field| {
568                    let field_type = &field.ty;
569                    quote::quote!(#field_type)
570                })
571            })
572            .or_else(|| Some(quote::quote!(f64)))
573            .unwrap()
574    }
575
576    fn get_derived_epsilon_type(&self) -> proc_macro2::TokenStream {
577        let parent = self.get_epsilon_parent_type();
578        quote::quote!(<#parent as approx::AbsDiffEq>::Epsilon)
579    }
580
581    fn get_epsilon_type_and_default_value(
582        &self,
583    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
584        let parent = self.get_epsilon_parent_type();
585        let epsilon_type = self.get_derived_epsilon_type();
586        let epsilon_default_value = self
587            .struct_args
588            .default_epsilon_value
589            .clone()
590            .map(|x| quote::quote!(#x))
591            .or_else(|| Some(quote::quote!(<#parent as approx::AbsDiffEq>::default_epsilon())))
592            .unwrap();
593        (epsilon_type, epsilon_default_value)
594    }
595
596    fn generics_involved(&self) -> bool {
597        let parent = self.get_epsilon_parent_type();
598        self.base_type
599            .generics()
600            .params
601            .iter()
602            .any(|param| quote::quote!(#param).to_string() == parent.to_string())
603    }
604
605    fn get_max_relative_default_value(&self) -> proc_macro2::TokenStream {
606        let epsilon_type = self.get_epsilon_parent_type();
607        self.struct_args
608            .default_max_relative_value
609            .clone()
610            .map(|x| quote::quote!(#x))
611            .or_else(|| {
612                Some(quote::quote!(<#epsilon_type as approx::RelativeEq>::default_max_relative()))
613            })
614            .unwrap()
615    }
616
617    fn format_nth_field(
618        &self,
619        n: usize,
620        field_with_args: &FieldWithArgs,
621        idents: Option<(syn::Ident, syn::Ident)>,
622    ) -> Option<FieldFormatted> {
623        // Determine if this field will be skipped and exit early
624        if let Some(true) = field_with_args.args.skip {
625            return None;
626        }
627
628        // Get types for epsilon and max_relative
629        let parent_type = self.get_epsilon_parent_type();
630
631        // Save field name and type in variables for easy access
632        use std::str::FromStr;
633        let (field_name1, field_name2) = match (&field_with_args.ident, idents) {
634            (Some(id), None) => (quote::quote!(self.#id), quote::quote!(other.#id)),
635            (None, None) => {
636                let field_number = proc_macro2::TokenStream::from_str(&format!("{}", n)).unwrap();
637                (
638                    quote::quote!(self.#field_number),
639                    quote::quote!(other.#field_number),
640                )
641            }
642            (_, Some((id1, id2))) => (quote::quote!(#id1), quote::quote!(#id2)),
643        };
644        let field_type = &field_with_args.ty;
645
646        // Determine if the field or the value will be casted in any way
647        let cast_strategy = &field_with_args.args.cast_strategy;
648
649        // Get static values (if present) for epsilon and max_relative
650        let epsilon = &field_with_args
651            .args
652            .epsilon_static_value
653            .clone()
654            .map(|x| quote::quote!(#x))
655            .or_else(|| Some(quote::quote!(epsilon)))
656            .unwrap();
657        let max_relative = field_with_args
658            .args
659            .max_relative_static_value
660            .clone()
661            .map(|x| quote::quote!(#x))
662            .or_else(|| Some(quote::quote!(max_relative)))
663            .unwrap();
664
665        // Use the casting strategy
666        let (base_type, own_field, other_field, epsilon, max_relative) = match cast_strategy {
667            Some(TypeCast::CastField) => (
668                quote::quote!(#parent_type),
669                quote::quote!(&(#field_name1.clone() as #parent_type)),
670                quote::quote!(&(#field_name2.clone() as #parent_type)),
671                quote::quote!(#epsilon.clone()),
672                quote::quote!(#max_relative.clone()),
673            ),
674            Some(TypeCast::CastValue) => (
675                quote::quote!(#field_type),
676                quote::quote!(&#field_name1),
677                quote::quote!(&#field_name2),
678                quote::quote!((#epsilon.clone() as #field_type)),
679                quote::quote!((#max_relative.clone() as #field_type)),
680            ),
681            None => (
682                quote::quote!(#parent_type),
683                quote::quote!(&#field_name1),
684                quote::quote!(&#field_name2),
685                quote::quote!(#epsilon.clone()),
686                quote::quote!(#max_relative.clone()),
687            ),
688        };
689
690        let mapping = field_with_args
691            .args
692            .mapping
693            .clone()
694            .map(|expr| quote::quote!(#expr));
695
696        // Return the fully formatted field
697        Some(FieldFormatted {
698            base_type,
699            own_field,
700            other_field,
701            epsilon,
702            max_relative,
703            set_equal: field_with_args.args.set_equal.unwrap_or(false),
704            mapping,
705        })
706    }
707
708    fn get_abs_diff_eq_struct_fields(
709        &self,
710        fields_with_args: &[FieldWithArgs],
711    ) -> Vec<proc_macro2::TokenStream> {
712        // We need to extend the where clause for all generics
713        let fields = fields_with_args
714            .iter()
715            .enumerate()
716            .filter_map(|(n, field_with_args)| {
717                if let Some(FieldFormatted {
718                    base_type,
719                    own_field,
720                    other_field,
721                    epsilon,
722                    #[allow(unused)]
723                    max_relative,
724                    set_equal,
725                    mapping,
726                }) = self.format_nth_field(n, field_with_args, None)
727                {
728                    if set_equal {
729                        Some(quote::quote!(#own_field == #other_field &&))
730                    } else if let Some(map) = mapping {
731                        Some(quote::quote!(
732                            (if let ((Some(a), Some(b))) = (
733                                (#map)(#own_field),
734                                (#map)(#other_field)
735                            ) {
736                                approx::AbsDiffEq::abs_diff_eq(&a, &b, #epsilon)
737                            } else {
738                                false
739                            }) &&
740                        ))
741                    } else {
742                        Some(quote::quote!(
743                            <#base_type as approx::AbsDiffEq>::abs_diff_eq(
744                                #own_field,
745                                #other_field,
746                                #epsilon
747                            ) &&
748                        ))
749                    }
750                } else {
751                    None
752                }
753            });
754        fields.collect()
755    }
756
757    fn get_abs_diff_eq_enum_variants(
758        &self,
759        variants_with_args: &[EnumVariant],
760    ) -> Vec<proc_macro2::TokenStream> {
761        variants_with_args
762            .iter()
763            .map(|variant_with_args| {
764                let variant = &variant_with_args.ident;
765                use syn::spanned::Spanned;
766
767                let gen_field_names = |var: &str| -> Vec<syn::Ident> {
768                    variant_with_args
769                        .fields_with_args
770                        .iter()
771                        .enumerate()
772                        .map(|(n, field)| syn::Ident::new(&format!("{var}{n}"), field.ident.span()))
773                        .collect()
774                };
775                if variant_with_args
776                    .fields_with_args
777                    .first()
778                    .and_then(|f| f.ident.clone())
779                    .is_some()
780                {
781                    let field_placeholders1 = gen_field_names("x");
782                    let field_placeholders2 = gen_field_names("y");
783                    let gen_combos = |iterator: Vec<syn::Ident>| {
784                        iterator
785                            .iter()
786                            .zip(&variant_with_args.fields_with_args)
787                            .map(|(fph, fwa)| {
788                                let id = &fwa.ident;
789                                quote::quote!(#id: #fph)
790                            })
791                            .collect::<Vec<_>>()
792                    };
793                    let comps: Vec<_> = field_placeholders1
794                        .iter()
795                        .zip(field_placeholders2.iter())
796                        .zip(variant_with_args.fields_with_args.iter())
797                        .map(|((xi, yi), field)| {
798                            self.get_abs_diff_eq_single_field(xi.clone(), yi.clone(), field)
799                        })
800                        .collect();
801                    let field_name_placeholder_combos1 = gen_combos(field_placeholders1);
802                    let field_name_placeholder_combos2 = gen_combos(field_placeholders2);
803                    quote::quote!(
804                        (
805                            Self:: #variant {
806                                #(#field_name_placeholder_combos1),*
807                            },
808                            Self:: #variant {
809                                #(#field_name_placeholder_combos2),*
810                            }
811                        ) => #(#comps) &&*,
812                    )
813                } else if !variant_with_args.fields_with_args.is_empty() {
814                    let field_names1 = gen_field_names("x");
815                    let field_names2 = gen_field_names("y");
816                    let comps: Vec<_> = field_names1
817                        .iter()
818                        .zip(field_names2.iter())
819                        .zip(variant_with_args.fields_with_args.iter())
820                        .map(|((xi, yi), field)| {
821                            self.get_abs_diff_eq_single_field(xi.clone(), yi.clone(), field)
822                        })
823                        .collect();
824                    quote::quote!(
825                        (
826                            Self:: #variant (#(#field_names1),*),
827                            Self:: #variant (#(#field_names2),*)
828                        ) => {#(#comps) &&*},
829                    )
830                } else {
831                    quote::quote!(
832                        (Self:: #variant, Self:: #variant) => true,
833                    )
834                }
835            })
836            .collect()
837    }
838
839    fn get_abs_diff_eq_single_field(
840        &self,
841        xi: syn::Ident,
842        yi: syn::Ident,
843        field_with_args: &FieldWithArgs,
844    ) -> Option<proc_macro2::TokenStream> {
845        if let Some(FieldFormatted {
846            base_type,
847            own_field,
848            other_field,
849            epsilon,
850            #[allow(unused)]
851            max_relative,
852            set_equal,
853            mapping,
854        }) = self.format_nth_field(0, field_with_args, Some((xi, yi)))
855        {
856            if set_equal {
857                Some(quote::quote!(#own_field == #other_field))
858            } else if let Some(map) = mapping {
859                Some(quote::quote!(
860                    (if let ((Some(a), Some(b))) = (
861                        (#map)(#own_field),
862                        (#map)(#other_field)
863                    ) {
864                        <#base_type as approx::AbsDiffEq>::abs_diff_eq(&a, &b, #epsilon)
865                    } else {
866                        false
867                    })
868                ))
869            } else {
870                Some(quote::quote!(
871                    <#base_type as approx::AbsDiffEq>::abs_diff_eq(
872                        &#own_field,
873                        &#other_field,
874                        #epsilon
875                    )
876                ))
877            }
878        } else {
879            None
880        }
881    }
882
883    fn get_rel_eq_single_field(
884        &self,
885        xi: syn::Ident,
886        yi: syn::Ident,
887        field_with_args: &FieldWithArgs,
888    ) -> Option<proc_macro2::TokenStream> {
889        if let Some(FieldFormatted {
890            base_type,
891            own_field,
892            other_field,
893            epsilon,
894            max_relative,
895            set_equal,
896            mapping,
897        }) = self.format_nth_field(0, field_with_args, Some((xi, yi)))
898        {
899            if set_equal {
900                Some(quote::quote!(#own_field == #other_field))
901            } else if let Some(map) = mapping {
902                Some(quote::quote!(
903                    (if let ((Some(a), Some(b))) = (
904                        (#map)(#own_field),
905                        (#map)(#other_field)
906                    ) {
907                        approx::RelativeEq::relative_eq(&a, &b, #epsilon, #max_relative)
908                    } else {
909                        false
910                    })
911                ))
912            } else {
913                Some(quote::quote!(
914                    <#base_type as approx::RelativeEq>::relative_eq(
915                        #own_field,
916                        #other_field,
917                        #epsilon,
918                        #max_relative
919                    )
920                ))
921            }
922        } else {
923            None
924        }
925    }
926
927    fn get_rel_eq_struct_fields(
928        &self,
929        fields_with_args: &[FieldWithArgs],
930    ) -> Vec<proc_macro2::TokenStream> {
931        // We need to extend the where clause for all generics
932        let fields = fields_with_args
933            .iter()
934            .enumerate()
935            .filter_map(|(n, field_with_args)| {
936                if let Some(FieldFormatted {
937                    base_type,
938                    own_field,
939                    other_field,
940                    epsilon,
941                    #[allow(unused)]
942                    max_relative,
943                    set_equal,
944                    mapping,
945                }) = self.format_nth_field(n, field_with_args, None)
946                {
947                    if set_equal {
948                        Some(quote::quote!(#own_field == #other_field &&))
949                    } else if let Some(map) = mapping {
950                        Some(quote::quote!(
951                            (if let ((Some(a), Some(b))) = (
952                                (#map)(#own_field),
953                                (#map)(#other_field)
954                            ) {
955                                approx::RelativeEq::relative_eq(&a, &b, #epsilon, #max_relative)
956                            } else {
957                                false
958                            }) &&
959                        ))
960                    } else {
961                        Some(quote::quote!(
962                            <#base_type as approx::RelativeEq>::relative_eq(
963                                #own_field,
964                                #other_field,
965                                #epsilon,
966                                #max_relative,
967                            ) &&
968                        ))
969                    }
970                } else {
971                    None
972                }
973            });
974        fields.collect()
975    }
976
977    fn get_rel_eq_variants(
978        &self,
979        variants_with_args: &[EnumVariant],
980    ) -> Vec<proc_macro2::TokenStream> {
981        variants_with_args
982            .iter()
983            .map(|variant_with_args| {
984                let variant = &variant_with_args.ident;
985                use syn::spanned::Spanned;
986
987                let gen_field_names = |var: &str| -> Vec<syn::Ident> {
988                    variant_with_args
989                        .fields_with_args
990                        .iter()
991                        .enumerate()
992                        .map(|(n, field)| syn::Ident::new(&format!("{var}{n}"), field.ident.span()))
993                        .collect()
994                };
995                if variant_with_args
996                    .fields_with_args
997                    .first()
998                    .and_then(|f| f.ident.clone())
999                    .is_some()
1000                {
1001                    let field_placeholders1 = gen_field_names("x");
1002                    let field_placeholders2 = gen_field_names("y");
1003                    let gen_combos = |iterator: Vec<syn::Ident>| {
1004                        iterator
1005                            .iter()
1006                            .zip(&variant_with_args.fields_with_args)
1007                            .map(|(fph, fwa)| {
1008                                let id = &fwa.ident;
1009                                quote::quote!(#id: #fph)
1010                            })
1011                            .collect::<Vec<_>>()
1012                    };
1013                    let comps: Vec<_> = field_placeholders1
1014                        .iter()
1015                        .zip(field_placeholders2.iter())
1016                        .zip(variant_with_args.fields_with_args.iter())
1017                        .map(|((xi, yi), field)| {
1018                            self.get_rel_eq_single_field(xi.clone(), yi.clone(), field)
1019                        })
1020                        .collect();
1021                    let field_name_placeholder_combos1 = gen_combos(field_placeholders1);
1022                    let field_name_placeholder_combos2 = gen_combos(field_placeholders2);
1023                    quote::quote!(
1024                        (
1025                            Self:: #variant {
1026                                #(#field_name_placeholder_combos1),*
1027                            },
1028                            Self:: #variant {
1029                                #(#field_name_placeholder_combos2),*
1030                            }
1031                        ) => #(#comps) &&*,
1032                    )
1033                } else if !variant_with_args.fields_with_args.is_empty() {
1034                    let field_names1 = gen_field_names("x");
1035                    let field_names2 = gen_field_names("y");
1036                    let comps: Vec<_> = field_names1
1037                        .iter()
1038                        .zip(field_names2.iter())
1039                        .zip(variant_with_args.fields_with_args.iter())
1040                        .map(|((xi, yi), field)| {
1041                            self.get_rel_eq_single_field(xi.clone(), yi.clone(), field)
1042                        })
1043                        .collect();
1044                    quote::quote!(
1045                        (
1046                            Self:: #variant (#(#field_names1),*),
1047                            Self:: #variant (#(#field_names2),*)
1048                        ) => {#(#comps) &&*},
1049                    )
1050                } else {
1051                    quote::quote!(
1052                        (Self::#variant, Self::#variant) => true,
1053                    )
1054                }
1055            })
1056            .collect()
1057    }
1058
1059    fn generate_where_clause(&self, abs_diff_eq: bool) -> proc_macro2::TokenStream {
1060        let (epsilon_type, _) = self.get_epsilon_type_and_default_value();
1061        let (_, _, where_clause) = self.base_type.generics().split_for_impl();
1062        let trait_bound = match abs_diff_eq {
1063            true => quote::quote!(approx::AbsDiffEq),
1064            false => quote::quote!(approx::RelativeEq),
1065        };
1066        if self.generics_involved() {
1067            let parent = self.get_epsilon_parent_type();
1068            match where_clause {
1069                Some(clause) => quote::quote!(
1070                    #clause
1071                        #parent: #trait_bound,
1072                        #parent: PartialEq,
1073                        #epsilon_type: Clone,
1074                ),
1075                None => quote::quote!(
1076                where
1077                    #parent: #trait_bound,
1078                    #parent: PartialEq,
1079                    #epsilon_type: Clone,
1080                ),
1081            }
1082        } else {
1083            quote::quote!(#where_clause)
1084        }
1085    }
1086
1087    fn implement_derive_abs_diff_eq(&self) -> proc_macro2::TokenStream {
1088        let struct_name = &self.base_type.ident();
1089        let (epsilon_type, epsilon_default_value) = self.get_epsilon_type_and_default_value();
1090
1091        let (impl_generics, ty_generics, _) = self.base_type.generics().split_for_impl();
1092        let where_clause = self.generate_where_clause(true);
1093
1094        match &self.base_type {
1095            #[allow(unused)]
1096            BaseType::Struct {
1097                item_struct,
1098                fields_with_args,
1099            } => {
1100                let fields = self.get_abs_diff_eq_struct_fields(fields_with_args);
1101
1102                quote::quote!(
1103                    const _ : () = {
1104                        #[automatically_derived]
1105                        impl #impl_generics approx::AbsDiffEq for #struct_name #ty_generics
1106                        #where_clause
1107                        {
1108                            type Epsilon = #epsilon_type;
1109
1110                            fn default_epsilon() -> Self::Epsilon {
1111                                #epsilon_default_value
1112                            }
1113
1114                            fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
1115                                #(#fields)*
1116                                true
1117                            }
1118                        }
1119                    };
1120                )
1121            }
1122            #[allow(unused)]
1123            BaseType::Enum {
1124                item_enum,
1125                variants_with_args,
1126            } => {
1127                let variants = self.get_abs_diff_eq_enum_variants(variants_with_args);
1128                quote::quote!(
1129                    const _: () = {
1130                        #[automatically_derived]
1131                        impl #impl_generics approx::AbsDiffEq for #struct_name #ty_generics
1132                        #where_clause
1133                        {
1134                            type Epsilon = #epsilon_type;
1135
1136                            fn default_epsilon() -> Self::Epsilon {
1137                                #epsilon_default_value
1138                            }
1139
1140                            fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
1141                                match (self, other) {
1142                                    #(#variants)*
1143                                    _ => false,
1144                                }
1145                            }
1146                        }
1147                    };
1148                )
1149            }
1150        }
1151    }
1152
1153    fn implement_derive_rel_diff_eq(&self) -> proc_macro2::TokenStream {
1154        let obj_name = &self.base_type.ident();
1155        let max_relative_default_value = self.get_max_relative_default_value();
1156
1157        let (impl_generics, ty_generics, _) = self.base_type.generics().split_for_impl();
1158        let where_clause = self.generate_where_clause(false);
1159
1160        match &self.base_type {
1161            #[allow(unused)]
1162            BaseType::Struct {
1163                item_struct,
1164                fields_with_args,
1165            } => {
1166                let fields = self.get_rel_eq_struct_fields(fields_with_args);
1167
1168                quote::quote!(
1169                    const _ : () = {
1170                        #[automatically_derived]
1171                        impl #impl_generics approx::RelativeEq for #obj_name #ty_generics
1172                        #where_clause
1173                        {
1174                            fn default_max_relative() -> Self::Epsilon {
1175                                #max_relative_default_value
1176                            }
1177
1178                            fn relative_eq(
1179                                &self,
1180                                other: &Self,
1181                                epsilon: Self::Epsilon,
1182                                max_relative: Self::Epsilon
1183                            ) -> bool {
1184                                #(#fields)*
1185                                true
1186                            }
1187                        }
1188                    };
1189                )
1190            }
1191            #[allow(unused)]
1192            BaseType::Enum {
1193                item_enum,
1194                variants_with_args,
1195            } => {
1196                let variants = self.get_rel_eq_variants(variants_with_args);
1197                quote::quote!(
1198                    const _: () = {
1199                        #[automatically_derived]
1200                        impl #impl_generics approx::RelativeEq for #obj_name #ty_generics
1201                        #where_clause
1202                        {
1203                            fn default_max_relative() -> Self::Epsilon {
1204                                #max_relative_default_value
1205                            }
1206
1207                            fn relative_eq(
1208                                &self,
1209                                other: &Self,
1210                                epsilon: Self::Epsilon,
1211                                max_relative: Self::Epsilon
1212                            ) -> bool {
1213                                match (self, other) {
1214                                    #(#variants)*
1215                                    _ => false,
1216                                }
1217                            }
1218                        }
1219                    };
1220                )
1221            }
1222        }
1223    }
1224}
1225
1226/// See the [crate] level documentation for a guide.
1227#[proc_macro_derive(AbsDiffEq, attributes(approx))]
1228pub fn derive_abs_diff_eq(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1229    let parsed = syn::parse_macro_input!(input as AbsDiffEqParser);
1230    parsed.implement_derive_abs_diff_eq().into()
1231}
1232
1233/// See the [crate] level documentation for a guide.
1234#[proc_macro_derive(RelativeEq, attributes(approx))]
1235pub fn derive_rel_diff_eq(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1236    let parsed = syn::parse_macro_input!(input as AbsDiffEqParser);
1237    let mut output = quote::quote!();
1238    output.extend(parsed.implement_derive_abs_diff_eq());
1239    output.extend(parsed.implement_derive_rel_diff_eq());
1240    output.into()
1241}