derive_partial_eq_extras/
lib.rs

1use proc_macro::TokenStream;
2use syn_helpers::{
3    derive_trait,
4    proc_macro2::{Ident, Span},
5    quote,
6    syn::{
7        parse_macro_input, parse_quote, BinOp, DeriveInput, Expr, ExprBinary, ExprLit, Lit,
8        LitBool, Stmt, Token, Type,
9    },
10    CommaSeparatedList, Constructable, Field, FieldMut, Fields, HasAttributes, Trait, TraitItem,
11};
12
13const LEFT_NAME_POSTFIX: &str = "_left";
14const RIGHT_NAME_POSTFIX: &str = "_right";
15
16#[proc_macro_derive(
17    PartialEqExtras,
18    attributes(partial_eq_ignore_types, partial_eq_ignore)
19)]
20pub fn partial_eq_extras(input: TokenStream) -> TokenStream {
21    let input = parse_macro_input!(input as DeriveInput);
22
23    let eq_item = TraitItem::new_method(
24        Ident::new("eq", Span::call_site()),
25        None,
26        syn_helpers::TypeOfSelf::Reference,
27        vec![parse_quote!(other: &Self)],
28        Some(parse_quote!(bool)),
29        |item| {
30            let attributes = item.structure.get_attributes();
31
32            let ignored_types: Vec<Type> = attributes
33                .iter()
34                .filter(|attr| attr.path().is_ident(IGNORE_TYPES))
35                .flat_map(|attr| {
36                    attr.parse_args::<CommaSeparatedList<Type>>()
37                        .unwrap()
38                        .into_iter()
39                })
40                .collect();
41
42            match item.structure {
43                syn_helpers::Structure::Struct(r#struct) => {
44                    let expr =
45                        build_comparison_for_fields(r#struct.get_fields_mut(), &ignored_types);
46
47                    let left_patterns = r#struct.get_fields().to_pattern_with_config(
48                        r#struct.get_constructor_path(),
49                        syn_helpers::TypeOfSelf::Reference,
50                        LEFT_NAME_POSTFIX,
51                    );
52                    let right_patterns = r#struct.get_fields().to_pattern_with_config(
53                        r#struct.get_constructor_path(),
54                        syn_helpers::TypeOfSelf::Reference,
55                        RIGHT_NAME_POSTFIX,
56                    );
57                    let declaration = parse_quote! {
58                        let (#left_patterns, #right_patterns) = (self, other);
59                    };
60
61                    Ok(vec![declaration, Stmt::Expr(expr, None)])
62                }
63                syn_helpers::Structure::Enum(r#enum) => {
64                    let branches = r#enum.get_variants_mut().iter_mut().map(|variant| {
65                        let expr =
66                            build_comparison_for_fields(variant.get_fields_mut(), &ignored_types);
67
68                        let left_patterns = variant.get_fields().to_pattern_with_config(
69                            variant.get_constructor_path(),
70                            syn_helpers::TypeOfSelf::Reference,
71                            LEFT_NAME_POSTFIX,
72                        );
73                        let right_patterns = variant.get_fields().to_pattern_with_config(
74                            variant.get_constructor_path(),
75                            syn_helpers::TypeOfSelf::Reference,
76                            RIGHT_NAME_POSTFIX,
77                        );
78                        let token_stream = quote! { (#left_patterns, #right_patterns) => #expr };
79                        token_stream
80                    });
81                    let match_stmt = parse_quote! {
82                        match (self, other) {
83                            #(#branches,)*
84                            (_, _) => false
85                        }
86                    };
87                    Ok(vec![match_stmt])
88                }
89            }
90        },
91    );
92
93    // PartialEq trait
94    let partial_eq_trait = Trait {
95        name: parse_quote!(::std::cmp::PartialEq),
96        generic_parameters: None,
97        items: vec![eq_item],
98    };
99
100    let derive_trait = derive_trait(input, partial_eq_trait);
101    derive_trait.into()
102}
103
104const IGNORE_TYPES: &str = "partial_eq_ignore_types";
105const IGNORE_FIELD: &str = "partial_eq_ignore";
106
107fn build_comparison_for_fields(fields: &mut Fields, ignored_types: &[Type]) -> Expr {
108    let mut top = None::<Expr>;
109
110    for mut field in fields.fields_iterator_mut() {
111        if field
112            .get_attributes()
113            .iter()
114            .any(|attr| attr.path().is_ident(IGNORE_FIELD))
115        {
116            continue;
117        }
118
119        let ignore_type_reference_and_on_tests =
120            ignored_types.iter().any(|ty| field.get_type() == ty);
121
122        if ignore_type_reference_and_on_tests {
123            continue;
124        }
125
126        let lhs = field.get_reference_with_config(true, LEFT_NAME_POSTFIX);
127        let rhs = field.get_reference_with_config(true, RIGHT_NAME_POSTFIX);
128
129        let expr = Expr::Binary(ExprBinary {
130            attrs: Vec::new(),
131            left: Box::new(lhs),
132            op: BinOp::Eq(Token!(==)(Span::call_site())),
133            right: Box::new(rhs),
134        });
135
136        if let Some(old_top) = top {
137            top = Some(Expr::Binary(ExprBinary {
138                attrs: Vec::new(),
139                left: Box::new(old_top),
140                op: BinOp::And(Token!(&&)(Span::call_site())),
141                right: Box::new(expr),
142            }));
143        } else {
144            top = Some(expr);
145        }
146    }
147    top.unwrap_or_else(|| {
148        Expr::Lit(ExprLit {
149            attrs: vec![],
150            lit: Lit::Bool(LitBool {
151                value: true,
152                span: Span::call_site(),
153            }),
154        })
155    })
156}