locspan_derive/
lib.rs

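//! This library provides the `StrippedPartialEq` derive macro
//! used to automatically implement the `StrippedPartialEq` comparison
//! trait defined in the `locspan` library. It also provides the
//! `StrippedEq`, `StrippedPartialOrd`, `StrippedOrd` and `StrippedHash`
//! derive macros for the corresponding traits.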
4//!
5//! ## Usage
6//!
7//! ```
8//! use locspan::Loc;
9//! use locspan_derive::StrippedPartialEq;
10//!
11//! // Implement `StrippedPartialEq` for the `Foo` type.
12//! // Type parameters will be required to implement
13//! // `StrippedPartialEq` themselves unless they are marked
14//! // with `#[locspan(ignore(...))]`.
15//! #[derive(StrippedPartialEq)]
16//! #[locspan(ignore(S, P))]
17//! struct Foo<T, S, P> {
18//!   a: Loc<T, S, P>,
19//!
//!   // Fields are compared using `StrippedPartialEq`
//!   // unless they are marked with `#[locspan(stripped)]`, in
//!   // which case `PartialEq` is used.
23//!   #[locspan(stripped)]
24//!   b: std::path::PathBuf
25//! }
26//! ```
27use proc_macro::{Span, TokenStream};
28use proc_macro_error::{emit_error, proc_macro_error};
29use quote::quote;
30use std::fmt;
31use syn::{parse_macro_input, DeriveInput};
32
33mod eq;
34mod hash;
35mod ord;
36mod partial_eq;
37mod partial_ord;
38pub(crate) mod syntax;
39pub(crate) mod util;
40
/// Options attached to a single generic parameter, presumably gathered from
/// `#[locspan(...)]` attributes such as `#[locspan(ignore(...))]`.
///
/// Every flag defaults to `false`.
///
/// NOTE(review): the flags are interpreted by the derive modules, not in this
/// file; the per-field notes below should be confirmed there.
#[derive(Clone, Copy, Default)]
pub(crate) struct ParamConfig {
	/// Presumably set for parameters listed in `#[locspan(ignore(...))]`,
	/// which then need not implement the stripped traits themselves.
	ignore: bool,
	/// Presumably marks the parameter as compared with the non-stripped
	/// trait — TODO confirm against the consuming modules.
	stripped: bool,
	/// Meaning not visible from this file — TODO document.
	fixed: bool,
}
47
48enum Access {
49	Direct(proc_macro2::TokenStream),
50	Reference(proc_macro2::TokenStream),
51}
52
53impl Access {
54	pub fn by_ref(&self) -> ByRef {
55		ByRef(self)
56	}
57
58	pub fn by_deref(&self) -> ByDeref {
59		ByDeref(self)
60	}
61}
62
63impl quote::ToTokens for Access {
64	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
65		match self {
66			Access::Direct(t) => t.to_tokens(tokens),
67			Access::Reference(t) => t.to_tokens(tokens),
68		}
69	}
70}
71
72struct ByRef<'a>(&'a Access);
73
74impl<'a> quote::ToTokens for ByRef<'a> {
75	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
76		match &self.0 {
77			Access::Direct(t) => {
78				tokens.extend(quote! { & });
79				t.to_tokens(tokens)
80			}
81			Access::Reference(t) => t.to_tokens(tokens),
82		}
83	}
84}
85
86struct ByDeref<'a>(&'a Access);
87
88impl<'a> quote::ToTokens for ByDeref<'a> {
89	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
90		match &self.0 {
91			Access::Direct(t) => {
92				tokens.extend(quote! { * });
93				t.to_tokens(tokens)
94			}
95			Access::Reference(t) => {
96				tokens.extend(quote! { ** });
97				t.to_tokens(tokens)
98			}
99		}
100	}
101}
102
/// How a field is handled by the generated code.
///
/// NOTE(review): this enum is not used in this file. The names suggest a
/// sequence of steps applied before the final handling: `Deref`/`Deref2`
/// presumably dereference once/twice, and `Unwrap` presumably unwraps an
/// optional value first. Confirm against the derive modules that consume it.
/// Variant order is kept as declared.
pub(crate) enum Method {
	/// Default handling.
	Normal,
	/// Presumably the field is skipped entirely.
	Ignore,
	/// Stripped handling applied directly to the field.
	Stripped,
	/// One dereference, then stripped handling.
	DerefThenStripped,
	/// Two dereferences, then stripped handling.
	Deref2ThenStripped,
	/// Unwrap, then stripped handling.
	UnwrapThenStripped,
	/// Unwrap, one dereference, then stripped handling.
	UnwrapThenDerefThenStripped,
	/// Unwrap, two dereferences, then stripped handling.
	UnwrapThenDeref2ThenStripped,
}
113
114/// Returns an iterator over the fields and access methods for `self`.
115fn fields_access(
116	fields: &syn::Fields,
117) -> impl std::iter::DoubleEndedIterator<Item = (&syn::Field, Access)> {
118	fields.iter().enumerate().map(|(i, field)| {
119		let id = match &field.ident {
120			Some(ident) => quote! { #ident },
121			None => {
122				let index = syn::Index::from(i);
123				quote! { #index }
124			}
125		};
126
127		(field, Access::Direct(quote! { self.#id }))
128	})
129}
130
131/// Returns an iterator over the fields and access methods for `self` and `other`.
132fn fields_access_pairs(
133	fields: &syn::Fields,
134) -> impl std::iter::DoubleEndedIterator<Item = (&syn::Field, (Access, Access))> {
135	fields.iter().enumerate().map(|(i, field)| {
136		let id = match &field.ident {
137			Some(ident) => quote! { #ident },
138			None => {
139				let index = syn::Index::from(i);
140				quote! { #index }
141			}
142		};
143
144		(
145			field,
146			(
147				Access::Direct(quote! { self.#id }),
148				Access::Direct(quote! { other.#id }),
149			),
150		)
151	})
152}
153
154enum VariantArg<'a> {
155	Named(&'a syn::Ident),
156	Unnamed(usize),
157}
158
159impl<'a> fmt::Display for VariantArg<'a> {
160	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161		match self {
162			Self::Named(ident) => ident.fmt(f),
163			Self::Unnamed(i) => i.fmt(f),
164		}
165	}
166}
167
168fn variant_pattern<B>(
169	variant: &syn::Variant,
170	binding: B,
171) -> (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>)
172where
173	B: Fn(VariantArg) -> Option<syn::Ident>,
174{
175	let ident = &variant.ident;
176	let args: Vec<_> = variant
177		.fields
178		.iter()
179		.enumerate()
180		.map(|(i, field)| {
181			let arg = match field.ident.as_ref() {
182				Some(ident) => VariantArg::Named(ident),
183				None => VariantArg::Unnamed(i),
184			};
185
186			match binding(arg) {
187				Some(ident) => quote! { #ident },
188				None => quote! { _ },
189			}
190		})
191		.collect();
192
193	if args.is_empty() {
194		(quote! { #ident }, args)
195	} else {
196		(quote! { #ident ( #(#args),* ) }, args)
197	}
198}
199
200#[proc_macro_error]
201#[proc_macro_derive(StrippedPartialEq, attributes(locspan))]
202pub fn derive_stripped_partial_eq(input: TokenStream) -> TokenStream {
203	let input = parse_macro_input!(input as DeriveInput);
204	match partial_eq::derive(input) {
205		Ok(output) => output.into(),
206		Err(e) => match e {
207			partial_eq::Error::Syntax(e) => e.to_compile_error().into(),
208			partial_eq::Error::Union => {
209				emit_error!(
210					Span::call_site(),
211					"`StrippedPartialEq` derive for unions is not supported"
212				);
213				quote! { false }.into()
214			}
215		},
216	}
217}
218
219#[proc_macro_error]
220#[proc_macro_derive(StrippedEq, attributes(locspan))]
221pub fn derive_stripped_eq(input: TokenStream) -> TokenStream {
222	let input = parse_macro_input!(input as DeriveInput);
223	eq::derive(input).into()
224}
225
226#[proc_macro_error]
227#[proc_macro_derive(StrippedPartialOrd, attributes(locspan))]
228pub fn derive_stripped_partial_ord(input: TokenStream) -> TokenStream {
229	let input = parse_macro_input!(input as DeriveInput);
230	match partial_ord::derive(input) {
231		Ok(output) => output.into(),
232		Err(e) => match e {
233			partial_ord::Error::Syntax(e) => e.to_compile_error().into(),
234			partial_ord::Error::Union => {
235				emit_error!(
236					Span::call_site(),
237					"`StrippedPartialOrd` derive for unions is not supported"
238				);
239				quote! { false }.into()
240			}
241		},
242	}
243}
244
245#[proc_macro_error]
246#[proc_macro_derive(StrippedOrd, attributes(locspan))]
247pub fn derive_stripped_ord(input: TokenStream) -> TokenStream {
248	let input = parse_macro_input!(input as DeriveInput);
249	match ord::derive(input) {
250		Ok(output) => output.into(),
251		Err(e) => match e {
252			ord::Error::Syntax(e) => e.to_compile_error().into(),
253			ord::Error::Union => {
254				emit_error!(
255					Span::call_site(),
256					"`StrippedOrd` derive for unions is not supported"
257				);
258				quote! { false }.into()
259			}
260		},
261	}
262}
263
264#[proc_macro_error]
265#[proc_macro_derive(StrippedHash, attributes(locspan))]
266pub fn derive_stripped_hash(input: TokenStream) -> TokenStream {
267	let input = parse_macro_input!(input as DeriveInput);
268	match hash::derive(input) {
269		Ok(output) => output.into(),
270		Err(e) => match e {
271			hash::Error::Syntax(e) => e.to_compile_error().into(),
272			hash::Error::Union => {
273				emit_error!(
274					Span::call_site(),
275					"`StrippedHash` derive for unions is not supported"
276				);
277				quote! { false }.into()
278			}
279		},
280	}
281}