derive_partial_eq_extras/
lib.rs1use proc_macro::TokenStream;
2use syn_helpers::{
3 derive_trait,
4 proc_macro2::{Ident, Span},
5 quote,
6 syn::{
7 parse_macro_input, parse_quote, BinOp, DeriveInput, Expr, ExprBinary, ExprLit, Lit,
8 LitBool, Stmt, Token, Type,
9 },
10 CommaSeparatedList, Constructable, Field, FieldMut, Fields, HasAttributes, Trait, TraitItem,
11};
12
/// Postfix used to name the field bindings destructured from `self`, i.e. the
/// left-hand operand of every generated `==` comparison.
const LEFT_NAME_POSTFIX: &str = "_left";

/// Postfix used to name the field bindings destructured from `other`, i.e. the
/// right-hand operand of every generated `==` comparison.
const RIGHT_NAME_POSTFIX: &str = "_right";
15
16#[proc_macro_derive(
17 PartialEqExtras,
18 attributes(partial_eq_ignore_types, partial_eq_ignore)
19)]
20pub fn partial_eq_extras(input: TokenStream) -> TokenStream {
21 let input = parse_macro_input!(input as DeriveInput);
22
23 let eq_item = TraitItem::new_method(
24 Ident::new("eq", Span::call_site()),
25 None,
26 syn_helpers::TypeOfSelf::Reference,
27 vec![parse_quote!(other: &Self)],
28 Some(parse_quote!(bool)),
29 |item| {
30 let attributes = item.structure.get_attributes();
31
32 let ignored_types: Vec<Type> = attributes
33 .iter()
34 .filter(|attr| attr.path().is_ident(IGNORE_TYPES))
35 .flat_map(|attr| {
36 attr.parse_args::<CommaSeparatedList<Type>>()
37 .unwrap()
38 .into_iter()
39 })
40 .collect();
41
42 match item.structure {
43 syn_helpers::Structure::Struct(r#struct) => {
44 let expr =
45 build_comparison_for_fields(r#struct.get_fields_mut(), &ignored_types);
46
47 let left_patterns = r#struct.get_fields().to_pattern_with_config(
48 r#struct.get_constructor_path(),
49 syn_helpers::TypeOfSelf::Reference,
50 LEFT_NAME_POSTFIX,
51 );
52 let right_patterns = r#struct.get_fields().to_pattern_with_config(
53 r#struct.get_constructor_path(),
54 syn_helpers::TypeOfSelf::Reference,
55 RIGHT_NAME_POSTFIX,
56 );
57 let declaration = parse_quote! {
58 let (#left_patterns, #right_patterns) = (self, other);
59 };
60
61 Ok(vec![declaration, Stmt::Expr(expr, None)])
62 }
63 syn_helpers::Structure::Enum(r#enum) => {
64 let branches = r#enum.get_variants_mut().iter_mut().map(|variant| {
65 let expr =
66 build_comparison_for_fields(variant.get_fields_mut(), &ignored_types);
67
68 let left_patterns = variant.get_fields().to_pattern_with_config(
69 variant.get_constructor_path(),
70 syn_helpers::TypeOfSelf::Reference,
71 LEFT_NAME_POSTFIX,
72 );
73 let right_patterns = variant.get_fields().to_pattern_with_config(
74 variant.get_constructor_path(),
75 syn_helpers::TypeOfSelf::Reference,
76 RIGHT_NAME_POSTFIX,
77 );
78 let token_stream = quote! { (#left_patterns, #right_patterns) => #expr };
79 token_stream
80 });
81 let match_stmt = parse_quote! {
82 match (self, other) {
83 #(#branches,)*
84 (_, _) => false
85 }
86 };
87 Ok(vec![match_stmt])
88 }
89 }
90 },
91 );
92
93 let partial_eq_trait = Trait {
95 name: parse_quote!(::std::cmp::PartialEq),
96 generic_parameters: None,
97 items: vec![eq_item],
98 };
99
100 let derive_trait = derive_trait(input, partial_eq_trait);
101 derive_trait.into()
102}
103
/// Container-level attribute whose arguments list types to skip, e.g.
/// `#[partial_eq_ignore_types(Span, Position)]`: every field declared with one of these
/// types is left out of the generated `eq`. Must match the name registered in the
/// `#[proc_macro_derive(..., attributes(...))]` declaration.
const IGNORE_TYPES: &str = "partial_eq_ignore_types";

/// Field-level marker attribute (`#[partial_eq_ignore]`) that leaves a single field out of
/// the generated `eq`. Must match the name registered in the
/// `#[proc_macro_derive(..., attributes(...))]` declaration.
const IGNORE_FIELD: &str = "partial_eq_ignore";
106
107fn build_comparison_for_fields(fields: &mut Fields, ignored_types: &[Type]) -> Expr {
108 let mut top = None::<Expr>;
109
110 for mut field in fields.fields_iterator_mut() {
111 if field
112 .get_attributes()
113 .iter()
114 .any(|attr| attr.path().is_ident(IGNORE_FIELD))
115 {
116 continue;
117 }
118
119 let ignore_type_reference_and_on_tests =
120 ignored_types.iter().any(|ty| field.get_type() == ty);
121
122 if ignore_type_reference_and_on_tests {
123 continue;
124 }
125
126 let lhs = field.get_reference_with_config(true, LEFT_NAME_POSTFIX);
127 let rhs = field.get_reference_with_config(true, RIGHT_NAME_POSTFIX);
128
129 let expr = Expr::Binary(ExprBinary {
130 attrs: Vec::new(),
131 left: Box::new(lhs),
132 op: BinOp::Eq(Token!(==)(Span::call_site())),
133 right: Box::new(rhs),
134 });
135
136 if let Some(old_top) = top {
137 top = Some(Expr::Binary(ExprBinary {
138 attrs: Vec::new(),
139 left: Box::new(old_top),
140 op: BinOp::And(Token!(&&)(Span::call_site())),
141 right: Box::new(expr),
142 }));
143 } else {
144 top = Some(expr);
145 }
146 }
147 top.unwrap_or_else(|| {
148 Expr::Lit(ExprLit {
149 attrs: vec![],
150 lit: Lit::Bool(LitBool {
151 value: true,
152 span: Span::call_site(),
153 }),
154 })
155 })
156}