multi_eq/
lib.rs

1//! # `multi_eq`
2//! `multi_eq` is a macro library for creating custom equality derives.
3//!
4//! ## Description
5//! This crate exports two macros:
6//! [`multi_eq_make_trait!()`](multi_eq_make_trait), and
7//! [`multi_eq_make_derive!()`](multi_eq_make_derive). The first is for creating
8//! custom equality traits. The second is for creating a derive macro for a
9//! custom equality trait. Since derive macros can only be exported by a crate
//! with the `proc-macro` crate type, a typical usage of this library is in
//! multi-crate projects: a `proc-macro` crate that defines the derive macros,
//! and a main crate that defines the traits and imports the derive macros.
//! Both crates depend on `multi_eq`.
13//!
14//! ## Example
15//!
16//! ### File tree
17//! ```text
18//! custom-eq-example
19//! ├── Cargo.lock
20//! ├── Cargo.toml
21//! ├── custom-eq-derive
22//! │   ├── Cargo.lock
23//! │   ├── Cargo.toml
24//! │   └── src
25//! │       └── lib.rs
26//! └── src
27//!     └── lib.rs
28//! ```
29//!
30//! #### `custom-eq-example/custom-eq-derive/Cargo.toml`
31//! ```toml
32//! # ...
33//!
34//! [lib]
35//! proc-macro = true
36//!
37//! # ...
38//! ```
39//!
40//! #### `custom-eq-example/custom-eq-derive/src/lib.rs`
41//! ```ignore
42//! use multi_eq::*;
43//!
//! // Derive macro for a comparison trait `CustomEq` with a method `custom_eq`
45//! multi_eq_make_derive!(pub, CustomEq, custom_eq);
46//! ```
47//!
48//! #### `custom-eq-example/Cargo.toml`
49//! ```toml
50//! # ...
51//!
52//! [dependencies.custom-eq-derive]
53//! path = "custom-eq-derive"
54//!
55//! # ...
56//! ```
57//!
58//! #### `custom-eq-example/src/lib.rs`
59//! ```ignore
60//! use multi_eq::*;
61//! use custom_eq_derive::*;
62//!
//! // Custom comparison trait `CustomEq` with a method `custom_eq`
64//! multi_eq_make_trait!(CustomEq, custom_eq);
65//!
66//! #[derive(CustomEq)]
67//! struct MyStruct {
68//!   // Use `PartialEq` to compare this field
69//!   #[custom_eq(cmp = "eq")]
70//!   a: u32,
71//!
72//!   // Ignore value of this field when checking equality
73//!   #[custom_eq(ignore)]
74//!   b: bool,
75//! }
76//! ```
77
78pub extern crate proc_macro as multi_eq_proc_macro;
79pub extern crate proc_macro2 as multi_eq_proc_macro2;
80pub extern crate quote as multi_eq_quote;
81pub extern crate syn as multi_eq_syn;
82
/// Macro to define a comparison trait
///
/// The generated trait has exactly one required method, whose signature
/// mirrors [`PartialEq::eq`](std::cmp::PartialEq::eq) with `Rhs = Self`:
/// `fn method(&self, other: &Self) -> bool`. Only the trait and method names
/// are chosen by the caller.
///
/// ## Parameters:
///   * `vis` - optional visibility specifier; when omitted the trait is
///     private to the module that invokes the macro
///   * `trait_name` - name of the trait being defined
///   * `method_name` - name of the method in the trait
///
/// ## Example:
/// ```rust
/// use multi_eq::*;
///
/// multi_eq_make_trait!(pub, PublicCustomEq, custom_eq);
/// multi_eq_make_trait!(PrivateCustomEq, eq);
/// ```
///
/// ## Generated code:
/// ```rust
/// pub trait PublicCustomEq {
///     fn custom_eq(&self, other: &Self) -> bool;
/// }
///
/// trait PrivateCustomEq {
///     fn eq(&self, other: &Self) -> bool;
/// }
/// ```
#[macro_export]
macro_rules! multi_eq_make_trait {
    // Form with an explicit visibility, e.g. `multi_eq_make_trait!(pub, Name, method)`.
    ($visibility:vis, $name:ident, $method:ident) => {
        $visibility trait $name {
            fn $method(&self, other: &Self) -> bool;
        }
    };
    // Form without a visibility: the trait keeps Rust's default (private) visibility.
    ($name:ident, $method:ident) => {
        trait $name {
            fn $method(&self, other: &Self) -> bool;
        }
    };
}
124
/// Macro to define a derive macro for a comparison trait
///
/// (Yes, this macro generates another macro that generates code.) The derived
/// implementation has the same shape as [`PartialEq`](std::cmp::PartialEq),
/// but with potentially different trait and method names.
///
/// ## Note:
/// This macro can only be used in crates with the `proc-macro` crate type.
///
/// ## Parameters:
///   * `vis` - visibility specifier of the generated derive macro
///   * `trait_name` - name of the trait to derive
///   * `method_name` - name of the method in the trait, also used as the name
///     of the proc macro and of its helper attribute
///
/// ## Field attributes:
/// Here `method_name` stands for the `method_name` argument supplied to the
/// macro.
///   * `#[method_name(cmp = "custom_comparison_method")]`
///     * Instead of using the derived trait's method to compare this field,
///       use `custom_comparison_method`.
///   * `#[method_name(ignore)]`
///     * When doing equality checking, ignore this field.
///
/// ## Generated implementation:
///   * Structs (named, tuple or unit) compare each non-ignored field of `self`
///     with the same field of `other`; a type with no compared fields is
///     always equal.
///   * Enums compare equal only when both values are the same variant and all
///     of that variant's non-ignored fields compare equal.
///   * The input's generic parameters and `where` clause are reproduced on the
///     `impl`; no additional trait bounds are added.
///   * Unions and unparsable input are reported as compile errors.
///
/// ## Example:
/// ```ignore
/// use multi_eq::*; // brings `multi_eq_make_derive!` into scope
///
/// multi_eq_make_derive!(pub, CustomEq, custom_eq);
/// ```
///
/// ## Derive usage example:
/// ```ignore
/// #[derive(CustomEq)]
/// struct MyStruct {
///   // Use `PartialEq` to compare this field
///   #[custom_eq(cmp = "eq")]
///   a: u32,
///
///   // Ignore value of this field when checking equality
///   #[custom_eq(ignore)]
///   b: bool,
/// }
/// ```
#[macro_export]
macro_rules! multi_eq_make_derive {
    ($vis:vis, $trait_name:ident, $method_name:ident) => {
        #[proc_macro_derive($trait_name, attributes($method_name))]
        $vis fn $method_name(
            input: $crate::multi_eq_proc_macro::TokenStream
        ) -> $crate::multi_eq_proc_macro::TokenStream {
            // All dependency paths go through `$crate`, so the expansion
            // resolves no matter what the invoking crate has imported.
            use $crate::multi_eq_proc_macro2::TokenStream as TokenStream2;
            use $crate::multi_eq_quote::{format_ident, quote};
            use $crate::multi_eq_syn as syn;

            /// Name of the helper attribute (`#[method_name(...)]`), which is
            /// also the method used to compare fields without a `cmp` override.
            const METHOD_NAME: &str = stringify!($method_name);

            /// Returns the `cmp = "..."` value of one helper attribute, if any.
            fn get_cmp_method_name(attr: &syn::Attribute) -> Option<String> {
                if !attr.path.is_ident(METHOD_NAME) {
                    return None;
                }
                match attr.parse_meta() {
                    Ok(syn::Meta::List(meta_list)) => {
                        meta_list.nested.iter().find_map(|nested_meta| match nested_meta {
                            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
                                path, lit: syn::Lit::Str(lit_str), ..
                            })) if path.is_ident("cmp") => Some(lit_str.value()),
                            _ => None,
                        })
                    }
                    _ => None,
                }
            }

            /// Whether one helper attribute contains the bare `ignore` flag.
            fn is_ignore(attr: &syn::Attribute) -> bool {
                if !attr.path.is_ident(METHOD_NAME) {
                    return false;
                }
                match attr.parse_meta() {
                    Ok(syn::Meta::List(meta_list)) => {
                        meta_list.nested.iter().any(|nested_meta| match nested_meta {
                            syn::NestedMeta::Meta(syn::Meta::Path(path)) => path.is_ident("ignore"),
                            _ => false,
                        })
                    }
                    _ => false,
                }
            }

            /// Method used to compare `field`, or `None` when the field is ignored.
            /// `ignore` wins over `cmp`; the first `cmp` found is used.
            fn cmp_method(field: &syn::Field) -> Option<syn::Ident> {
                if field.attrs.iter().any(is_ignore) {
                    return None;
                }
                let name = field
                    .attrs
                    .iter()
                    .find_map(get_cmp_method_name)
                    .unwrap_or_else(|| METHOD_NAME.to_string());
                Some(format_ident!("{}", name))
            }

            /// How `field` is named in a field access or struct pattern: its
            /// identifier (raw identifiers preserved) for named fields, its
            /// position (`0`, `1`, ...) for tuple fields.
            fn field_member(index: usize, field: &syn::Field) -> syn::Member {
                match &field.ident {
                    Some(ident) => syn::Member::Named(ident.clone()),
                    None => syn::Member::Unnamed(syn::Index::from(index)),
                }
            }

            /// `true && ...` chain comparing every non-ignored field of `self`
            /// with the same field of `other`.
            fn struct_eq(fields: &syn::Fields) -> TokenStream2 {
                let comparisons = fields.iter().enumerate().filter_map(|(index, field)| {
                    let method = cmp_method(field)?;
                    let member = field_member(index, field);
                    // Reference-typed fields are already borrows; pass them as-is
                    // rather than handing the comparison method a `&&T`.
                    let other_field = match field.ty {
                        syn::Type::Reference(_) => quote!(other.#member),
                        _ => quote!(&other.#member),
                    };
                    Some(quote!(self.#member.#method(#other_field)))
                });
                quote!(true #(&& #comparisons)*)
            }

            /// One match arm comparing two values that are both `variant`.
            fn variant_arm(enum_ident: &syn::Ident, variant: &syn::Variant) -> TokenStream2 {
                let variant_ident = &variant.ident;
                let mut members = Vec::new();
                let mut lhs_bindings = Vec::new();
                let mut rhs_bindings = Vec::new();
                let mut comparisons = Vec::new();
                for (index, field) in variant.fields.iter().enumerate() {
                    // Ignored fields are not bound at all; the `..` below covers them.
                    if let Some(method) = cmp_method(field) {
                        let lhs = format_ident!("__self_{}", index);
                        let rhs = format_ident!("__other_{}", index);
                        comparisons.push(quote!(#lhs.#method(#rhs)));
                        members.push(field_member(index, field));
                        lhs_bindings.push(lhs);
                        rhs_bindings.push(rhs);
                    }
                }
                // Braced patterns (`V { a: x, .. }`, `V { 0: x, .. }`, `V { .. }`)
                // work uniformly for named, tuple and unit variants.
                quote! {
                    (
                        #enum_ident::#variant_ident { #(#members: #lhs_bindings,)* .. },
                        #enum_ident::#variant_ident { #(#members: #rhs_bindings,)* .. },
                    ) => true #(&& #comparisons)*,
                }
            }

            // Report malformed input as a compile error instead of panicking.
            let input = match syn::parse::<syn::DeriveInput>(input) {
                Ok(input) => input,
                Err(err) => return err.to_compile_error().into(),
            };
            let syn::DeriveInput { ident, generics, data, .. } = input;

            let body = match data {
                syn::Data::Struct(data) => struct_eq(&data.fields),
                syn::Data::Enum(data) => {
                    let arms = data.variants.iter().map(|variant| variant_arm(&ident, variant));
                    quote! {
                        match (self, other) {
                            #(#arms)*
                            // Needed for zero- and multi-variant enums; unreachable
                            // (hence allowed) when the enum has a single variant.
                            #[allow(unreachable_patterns)]
                            (_, _) => false,
                        }
                    }
                }
                syn::Data::Union(_) => {
                    return syn::Error::new(ident.span(), "unions are not supported")
                        .to_compile_error()
                        .into();
                }
            };

            // `split_for_impl` keeps bounds on the `impl<...>` side only and
            // carries the `where` clause over.
            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
            let expanded = quote! {
                impl #impl_generics $trait_name for #ident #ty_generics #where_clause {
                    fn $method_name(&self, other: &Self) -> bool {
                        #body
                    }
                }
            };
            expanded.into()
        }
    };
}