multi_eq/lib.rs
1//! # `multi_eq`
2//! `multi_eq` is a macro library for creating custom equality derives.
3//!
4//! ## Description
5//! This crate exports two macros:
6//! [`multi_eq_make_trait!()`](multi_eq_make_trait), and
7//! [`multi_eq_make_derive!()`](multi_eq_make_derive). The first is for creating
8//! custom equality traits. The second is for creating a derive macro for a
9//! custom equality trait. Since derive macros can only be exported by a crate
10//! with the `proc-macro` crate type, a typical usage of this library is in
11//! multi-crate projects: a `proc-macro` crate for the derive macros, and a main
12//! crate importing the derive macros.
13//!
14//! ## Example
15//!
16//! ### File tree
17//! ```text
//! custom-eq-example
//! ├── Cargo.lock
//! ├── Cargo.toml
//! ├── custom-eq-derive
//! │   ├── Cargo.lock
//! │   ├── Cargo.toml
//! │   └── src
//! │       └── lib.rs
//! └── src
//!     └── lib.rs
28//! ```
29//!
30//! #### `custom-eq-example/custom-eq-derive/Cargo.toml`
31//! ```toml
32//! # ...
33//!
34//! [lib]
35//! proc-macro = true
36//!
37//! # ...
38//! ```
39//!
40//! #### `custom-eq-example/custom-eq-derive/src/lib.rs`
41//! ```ignore
42//! use multi_eq::*;
43//!
44//! /// Derive macro for a comparison trait `CustomEq` with a method `custom_eq`
45//! multi_eq_make_derive!(pub, CustomEq, custom_eq);
46//! ```
47//!
48//! #### `custom-eq-example/Cargo.toml`
49//! ```toml
50//! # ...
51//!
52//! [dependencies.custom-eq-derive]
53//! path = "custom-eq-derive"
54//!
55//! # ...
56//! ```
57//!
58//! #### `custom-eq-example/src/lib.rs`
59//! ```ignore
60//! use multi_eq::*;
61//! use custom_eq_derive::*;
62//!
63//! /// Custom comparison trait `CustomEq` with a method `custom_eq`
64//! multi_eq_make_trait!(CustomEq, custom_eq);
65//!
//! #[derive(CustomEq)]
//! struct MyStruct {
//!     // Use `PartialEq` to compare this field
//!     #[custom_eq(cmp = "eq")]
//!     a: u32,
//!
//!     // Ignore value of this field when checking equality
//!     #[custom_eq(ignore)]
//!     b: bool,
//! }
76//! ```
77
78pub extern crate proc_macro as multi_eq_proc_macro;
79pub extern crate proc_macro2 as multi_eq_proc_macro2;
80pub extern crate quote as multi_eq_quote;
81pub extern crate syn as multi_eq_syn;
82
/// Defines a single-method comparison trait.
///
/// The generated trait has one method with the same signature shape as
/// [`PartialEq::eq`](std::cmp::PartialEq::eq): it takes `&self` and
/// `other: &Self` and returns `bool`. The caller chooses both the trait name
/// and the method name.
///
/// ## Parameters:
/// * `vis` - optional visibility specifier; leave it out (together with its
///   trailing comma) to get a private trait
/// * `trait_name` - name of the trait being defined
/// * `method_name` - name of the method in the trait
///
/// ## Example:
/// ```rust
/// use multi_eq::*;
///
/// multi_eq_make_trait!(pub, PublicCustomEq, custom_eq);
/// multi_eq_make_trait!(PrivateCustomEq, eq);
/// ```
///
/// ## Generated code:
/// ```rust
/// pub trait PublicCustomEq {
///     fn custom_eq(&self, other: &Self) -> bool;
/// }
///
/// trait PrivateCustomEq {
///     fn eq(&self, other: &Self) -> bool;
/// }
/// ```
#[macro_export]
macro_rules! multi_eq_make_trait {
    // Main form: an explicit (possibly empty) visibility followed by the names.
    ($vis:vis, $trait_name:ident, $method_name:ident) => {
        $vis trait $trait_name {
            fn $method_name(&self, other: &Self) -> bool;
        }
    };
    // Short form without a visibility: re-enter the main form with an empty
    // visibility, which yields a trait with inherited (private) visibility.
    ($trait_name:ident, $method_name:ident) => {
        $crate::multi_eq_make_trait!(, $trait_name, $method_name);
    };
}
123}
124
/// Macro to define a derive macro for a comparison trait
///
/// (Yes, this macro generates another macro that generates code.) The derived
/// method has the same signature shape as
/// [`PartialEq::eq`](std::cmp::PartialEq::eq), but the trait and method names
/// are chosen by the caller.
///
/// ## Note:
/// This macro can only be used in crates with the `proc-macro` crate type.
/// The generated code reaches `proc_macro`, `proc_macro2`, `quote` and `syn`
/// through this crate's re-exports (via `$crate`). The call site therefore
/// needs nothing in scope beyond the macro itself.
///
/// ## Parameters:
/// * `vis` - visibility specifier of the generated derive macro
/// * `trait_name` - name of the trait to derive
/// * `method_name` - name of the method in the trait; also used as the name
///   of the proc macro and of its helper attribute
///
/// ## Field attributes:
/// Note that `method_name` refers to the `method_name` parameter supplied to
/// the macro.
/// * `#[method_name(cmp = "custom_comparison_method")]`
///   * Instead of using the derived trait's method to compare this field,
///     use `custom_comparison_method`.
/// * `#[method_name(ignore)]`
///   * When doing equality checking, ignore this field.
///
/// ## Generated implementation:
/// * Structs (named or tuple fields) are compared field by field, combining
///   the results with `&&`. A unit struct always compares equal.
/// * Enums compare equal only when both values are the same variant and every
///   non-ignored field of that variant compares equal.
/// * The deriving type's generic parameters and `where` clause are copied
///   onto the generated `impl` unchanged. No extra bounds are added.
/// * Unions are rejected with a compile error.
///
/// ## Example:
/// ```ignore
/// use multi_eq::*; // brings `multi_eq_make_derive!` into scope
///
/// multi_eq_make_derive!(pub, CustomEq, custom_eq);
/// ```
///
/// ## Derive usage example:
/// ```ignore
/// #[derive(CustomEq)]
/// struct MyStruct {
///     // Use `PartialEq` to compare this field
///     #[custom_eq(cmp = "eq")]
///     a: u32,
///
///     // Ignore value of this field when checking equality
///     #[custom_eq(ignore)]
///     b: bool,
/// }
/// ```
#[macro_export]
macro_rules! multi_eq_make_derive {
    ($vis:vis, $trait_name:ident, $method_name:ident) => {
        #[proc_macro_derive($trait_name, attributes($method_name))]
        $vis fn $method_name(
            input: $crate::multi_eq_proc_macro::TokenStream,
        ) -> $crate::multi_eq_proc_macro::TokenStream {
            // Resolve dependencies through `$crate` so the expansion works no
            // matter what the calling crate has imported.
            use $crate::multi_eq_proc_macro2::TokenStream as TokenStream2;
            use $crate::multi_eq_quote::{format_ident, quote};
            use $crate::multi_eq_syn as syn;

            /// Name of the helper attribute; identical to the trait method name.
            const ATTR_NAME: &str = stringify!($method_name);

            /// Returns `true` if `path` is exactly the single identifier `s`.
            fn path_is(path: &syn::Path, s: &str) -> bool {
                let segs = &path.segments;
                segs.len() == 1 && {
                    let seg = &segs[0];
                    seg.arguments.is_empty() && seg.ident.to_string() == s
                }
            }

            /// Extracts `name` from `#[<method_name>(cmp = "name")]`, if present.
            fn get_cmp_method_name(attr: &syn::Attribute) -> Option<String> {
                match attr.parse_meta() {
                    Ok(syn::Meta::List(meta_list)) if path_is(&meta_list.path, ATTR_NAME) => {
                        meta_list.nested.iter().find_map(|nested_meta| match nested_meta {
                            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
                                path,
                                lit: syn::Lit::Str(lit_str),
                                ..
                            })) if path_is(path, "cmp") => Some(lit_str.value()),
                            _ => None,
                        })
                    }
                    _ => None,
                }
            }

            /// Returns `true` for an attribute of the form `#[<method_name>(ignore)]`.
            fn is_ignore(attr: &syn::Attribute) -> bool {
                match attr.parse_meta() {
                    Ok(syn::Meta::List(meta_list)) if path_is(&meta_list.path, ATTR_NAME) => {
                        meta_list.nested.iter().any(|nested_meta| match nested_meta {
                            syn::NestedMeta::Meta(syn::Meta::Path(path)) => path_is(path, "ignore"),
                            _ => false,
                        })
                    }
                    _ => false,
                }
            }

            /// Method used to compare `field`: the `cmp` override if one is
            /// given, otherwise the derived trait's own method.
            fn cmp_method(field: &syn::Field) -> syn::Ident {
                match field.attrs.iter().find_map(get_cmp_method_name) {
                    Some(name) => format_ident!("{}", name),
                    None => format_ident!("{}", ATTR_NAME),
                }
            }

            /// How a struct field is accessed: `self.name` for named fields,
            /// `self.0`, `self.1`, ... for tuple fields.
            fn field_member(index: usize, field: &syn::Field) -> syn::Member {
                match &field.ident {
                    Some(ident) => syn::Member::Named(ident.clone()),
                    None => syn::Member::Unnamed(syn::Index::from(index)),
                }
            }

            /// Builds `true && self.f.cmp(&other.f) && ...` over the
            /// non-ignored fields of a struct.
            fn fields_eq<'a, I: Iterator<Item = &'a syn::Field>>(fields: I) -> TokenStream2 {
                fields
                    .enumerate()
                    // Filter only after `enumerate` so tuple indices stay aligned.
                    .filter(|(_, field)| !field.attrs.iter().any(is_ignore))
                    .fold(quote!(true), |acc, (i, field)| {
                        let member = field_member(i, field);
                        let method = cmp_method(field);
                        // A reference-typed field is already a reference, so
                        // no extra `&` is added for it.
                        let refr = if let syn::Type::Reference(_) = field.ty {
                            quote!()
                        } else {
                            quote!(&)
                        };
                        quote!(#acc && self.#member.#method(#refr other.#member))
                    })
            }

            /// Left/right sub-patterns and the comparison body for one enum variant.
            struct ArmParts {
                pattern_left: TokenStream2,
                pattern_right: TokenStream2,
                body: TokenStream2,
            }

            /// Binds every field of a variant on both sides of the match
            /// (`a: a_1` / `a: a_2` for named fields, `v0_1` / `v0_2` for tuple
            /// fields) and chains their comparisons with `&&`. Ignored fields
            /// are matched with `_` so they introduce no unused bindings.
            fn gen_match_arm<'a, I: Iterator<Item = &'a syn::Field>>(fields: I) -> ArmParts {
                let mut left = Vec::new();
                let mut right = Vec::new();
                let mut body = quote!(true);
                for (i, field) in fields.enumerate() {
                    // Struct-like variants need an explicit `field:` in the pattern.
                    let label = match &field.ident {
                        Some(ident) => quote!(#ident:),
                        None => quote!(),
                    };
                    if field.attrs.iter().any(is_ignore) {
                        left.push(quote!(#label _));
                        right.push(quote!(#label _));
                        continue;
                    }
                    let (name_1, name_2) = match &field.ident {
                        Some(ident) => (format_ident!("{}_1", ident), format_ident!("{}_2", ident)),
                        None => (format_ident!("v{}_1", i), format_ident!("v{}_2", i)),
                    };
                    let method = cmp_method(field);
                    left.push(quote!(#label #name_1));
                    right.push(quote!(#label #name_2));
                    body = quote!(#body && #name_1.#method(#name_2));
                }
                ArmParts {
                    pattern_left: quote!(#(#left),*),
                    pattern_right: quote!(#(#right),*),
                    body,
                }
            }

            let input = match syn::parse::<syn::DeriveInput>(input) {
                Ok(input) => input,
                Err(err) => return err.to_compile_error().into(),
            };
            let input_ident = &input.ident;

            let expr = match &input.data {
                syn::Data::Struct(data) => match &data.fields {
                    syn::Fields::Named(fields) => fields_eq(fields.named.iter()),
                    syn::Fields::Unnamed(fields) => fields_eq(fields.unnamed.iter()),
                    syn::Fields::Unit => quote!(true),
                },
                syn::Data::Enum(data) => {
                    let arms = data.variants.iter().map(|variant| {
                        let ident = &variant.ident;
                        match &variant.fields {
                            syn::Fields::Named(fields) => {
                                let ArmParts { pattern_left, pattern_right, body } =
                                    gen_match_arm(fields.named.iter());
                                quote!((#input_ident::#ident { #pattern_left },
                                        #input_ident::#ident { #pattern_right }) => #body,)
                            }
                            syn::Fields::Unnamed(fields) => {
                                let ArmParts { pattern_left, pattern_right, body } =
                                    gen_match_arm(fields.unnamed.iter());
                                quote!((#input_ident::#ident(#pattern_left),
                                        #input_ident::#ident(#pattern_right)) => #body,)
                            }
                            syn::Fields::Unit => {
                                quote!((#input_ident::#ident, #input_ident::#ident) => true,)
                            }
                        }
                    });
                    // The trailing arm rejects mismatched variants. It is
                    // unreachable for single-variant enums, hence the allow.
                    quote! {
                        match (self, other) {
                            #(#arms)*
                            #[allow(unreachable_patterns)]
                            (_, _) => false,
                        }
                    }
                }
                syn::Data::Union(_) => {
                    return syn::Error::new_spanned(input_ident, "unions are not supported")
                        .to_compile_error()
                        .into();
                }
            };

            // `split_for_impl` yields `<T: Bound>` for the impl, `<T>` for the
            // type, and the `where` clause, so generic types expand correctly.
            let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

            // `other` is unused for unit structs and when every field is
            // ignored; the allow keeps that from warning in user crates.
            let ret = quote! {
                impl #impl_generics $trait_name for #input_ident #ty_generics #where_clause {
                    #[allow(unused_variables)]
                    fn $method_name(&self, other: &Self) -> bool {
                        #expr
                    }
                }
            };
            ret.into()
        }
    };
}
359}