// comparable_derive 0.5.5
//
// A library for comparing data structures in Rust, oriented toward testing.
// Documentation
use proc_macro2::TokenStream;
use quote::quote;
use std::collections::BTreeMap;
use std::iter::FromIterator;

/// Builds the unit type `()` as a `syn::Type`.
///
/// The result is a tuple type with no elements whose parentheses carry the
/// call-site span.
pub fn unit_type() -> syn::Type {
	let paren_token = syn::token::Paren { span: proc_macro2::Span::call_site() };
	let elems = syn::punctuated::Punctuated::new();
	syn::Type::Tuple(syn::TypeTuple { paren_token, elems })
}

/// Converts an identifier into a `syn::Type` by parsing its tokens as type syntax.
///
/// # Panics
///
/// Panics with "Failed to parse type" if the identifier does not parse as a type.
pub fn ident_to_type(ident: &syn::Ident) -> syn::Type {
	let tokens = quote!(#ident);
	match syn::parse2(tokens) {
		Ok(ty) => ty,
		Err(_) => panic!("Failed to parse type"),
	}
}

/// Wraps `ty` in `Vec<...>` and returns the result as a `syn::Type`.
///
/// # Panics
///
/// Panics with "Failed to parse Vec type" if the generated tokens do not parse as a type.
#[allow(dead_code)]
pub fn vec_type(ty: &syn::Type) -> syn::Type {
	let tokens = quote!(Vec<#ty>);
	match syn::parse2(tokens) {
		Ok(wrapped) => wrapped,
		Err(_) => panic!("Failed to parse Vec type"),
	}
}

/// Returns the first attribute in `attrs` whose path is the single identifier
/// `attr_name`, or `None` when no attribute matches.
pub fn has_attr<'a>(attrs: &'a [syn::Attribute], attr_name: &str) -> Option<&'a syn::Attribute> {
	for candidate in attrs {
		if candidate.path.is_ident(attr_name) {
			return Some(candidate);
		}
	}
	None
}

/// Re-packages the fields of an enum variant as struct data.
///
/// The variant's fields are cloned; the `struct` keyword token and the optional
/// trailing semicolon token are filled in with their default values.
#[allow(dead_code)]
pub fn data_from_variant(variant: &syn::Variant) -> syn::Data {
	let fields = variant.fields.clone();
	let as_struct = syn::DataStruct { struct_token: Default::default(), fields, semi_token: Default::default() };
	syn::Data::Struct(as_struct)
}

/// Rebuilds `data` with every field replaced by the result of `f`.
///
/// * Structs are delegated to [`map_on_fields_over_datastruct`].
/// * Enums have the fields of each variant rewritten; everything else about
///   the variant is cloned unchanged.
/// * Unions have their named fields rewritten.
///
/// `inject_synthetics` is forwarded to [`map_fields`]. `f` must be `Copy`
/// because it is handed to the field mapper once per enum variant.
pub fn map_on_fields_over_data(
	inject_synthetics: bool,
	data: &syn::Data,
	f: impl FnMut(&FieldRef) -> syn::Field + Copy,
) -> syn::Data {
	match data {
		syn::Data::Struct(st) => map_on_fields_over_datastruct(inject_synthetics, st, f),
		syn::Data::Enum(en) => {
			let rewritten = map_variants(&en.variants, move |v| {
				let fields = map_on_fields(inject_synthetics, &v.fields, f);
				syn::Variant { fields, ..v.clone() }
			});
			syn::Data::Enum(syn::DataEnum {
				enum_token: en.enum_token,
				brace_token: en.brace_token,
				variants: rewritten.into_iter().collect(),
			})
		}
		syn::Data::Union(un) => {
			let rewritten = map_fields(inject_synthetics, un.fields.named.iter(), true, f);
			syn::Data::Union(syn::DataUnion {
				union_token: un.union_token,
				// Only the brace token is carried over from the original field list.
				fields: syn::FieldsNamed {
					brace_token: un.fields.brace_token,
					named: rewritten.into_iter().collect(),
				},
			})
		}
	}
}

/// Rebuilds a struct's data with its fields rewritten by `f`, keeping the
/// original `struct` and semicolon tokens.
pub fn map_on_fields_over_datastruct(
	inject_synthetics: bool,
	st: &syn::DataStruct,
	f: impl FnMut(&FieldRef) -> syn::Field,
) -> syn::Data {
	let fields = map_on_fields(inject_synthetics, &st.fields, f);
	syn::Data::Struct(syn::DataStruct { struct_token: st.struct_token, fields, semi_token: st.semi_token })
}

/// Rebuilds an enum's data with every variant replaced by the result of `f`,
/// keeping the original `enum` and brace tokens.
pub fn _map_on_variants_over_dataenum(en: &syn::DataEnum, f: impl FnMut(&syn::Variant) -> syn::Variant) -> syn::Data {
	let rewritten = map_variants(en.variants.iter(), f);
	syn::Data::Enum(syn::DataEnum {
		enum_token: en.enum_token,
		brace_token: en.brace_token,
		variants: rewritten.into_iter().collect(),
	})
}

/// Rewrites a field list with `f`, preserving its shape (named, unnamed or unit).
///
/// Named and unnamed lists go through [`map_fields`] with ignoring enabled, so
/// fields carrying `comparable_ignore` are left out of the result. Unit field
/// lists are returned as-is.
pub fn map_on_fields(
	inject_synthetics: bool,
	fields: &syn::Fields,
	f: impl FnMut(&FieldRef) -> syn::Field,
) -> syn::Fields {
	match fields {
		syn::Fields::Unit => syn::Fields::Unit,
		syn::Fields::Named(named) => {
			let rewritten = map_fields(inject_synthetics, named.named.iter(), true, f);
			syn::Fields::Named(syn::FieldsNamed {
				brace_token: named.brace_token,
				named: rewritten.into_iter().collect(),
			})
		}
		syn::Fields::Unnamed(unnamed) => {
			let rewritten = map_fields(inject_synthetics, unnamed.unnamed.iter(), true, f);
			syn::Fields::Unnamed(syn::FieldsUnnamed {
				paren_token: unnamed.paren_token,
				unnamed: rewritten.into_iter().collect(),
			})
		}
	}
}

/// A field handed to the callbacks of the field-mapping helpers, together with
/// its position and a way to build an expression that reads it.
pub struct FieldRef<'a> {
	/// Running position assigned by `map_fields`; the counter advances for every
	/// injected synthetic field and for every original field, including ignored ones.
	pub index: usize,
	/// The field definition being visited.
	pub field: &'a syn::Field,
	/// Given the identifier of a value, builds the expression that reads this
	/// field from that value.
	pub accessor: Box<dyn Fn(&syn::Ident) -> syn::Expr>,
}

/// Builds the accessor for a regular field: given an identifier `x`, it yields
/// the expression `x.<name>` for named fields or `x.<index>` for unnamed ones.
///
/// The returned closure panics with "Could not create standard accessor!" if
/// the generated tokens do not parse as an expression.
fn standard_accessor(index: usize, field: &syn::Field) -> Box<dyn Fn(&syn::Ident) -> syn::Expr> {
	let member = if let Some(name) = &field.ident {
		quote!(#name)
	} else {
		// Unnamed field: address it by position.
		let position = syn::Index::from(index);
		quote!(#position)
	};
	Box::new(move |x| {
		let access = quote!(#x.#member);
		syn::parse2(access).expect("Could not create standard accessor!")
	})
}

/// Applies `f` to every field in `fields` and collects the results in order.
///
/// * When `inject_synthetics` is set and a field carries a
///   `comparable_synthetic` attribute, one synthetic field per entry of that
///   attribute is visited *before* the field itself. Its type is the closure's
///   declared return type (or `()` when none is given) and its accessor calls
///   the closure on a reference to the value.
/// * When `allow_ignore` is set, fields carrying `comparable_ignore` are not
///   passed to `f`.
///
/// The index given to `f` advances for every synthetic field and for every
/// original field, whether or not the latter was ignored.
///
/// # Panics
///
/// Panics if the argument of a `comparable_synthetic` attribute cannot be parsed.
pub fn map_fields<'a, 'b: 'a, R>(
	inject_synthetics: bool,
	fields: impl IntoIterator<Item = &'a syn::Field>,
	allow_ignore: bool,
	mut f: impl FnMut(&FieldRef) -> R,
) -> Vec<R> {
	let mut mapped = Vec::new();
	let mut next_index = 0;
	for field in fields {
		if inject_synthetics {
			if let Some(attr) = has_attr(&field.attrs, "comparable_synthetic") {
				let synthetics = parse_synthetics(&attr.tokens)
					.expect("Argument to comparable_synthetic must be a set of field values in braces");
				for (ident, closure) in synthetics {
					let ty = match &closure.output {
						syn::ReturnType::Type(_, declared) => declared.as_ref().clone(),
						syn::ReturnType::Default => unit_type(),
					};
					let synthetic_field = syn::Field {
						attrs: Default::default(),
						vis: syn::Visibility::Inherited,
						ident: Some(ident),
						colon_token: Default::default(),
						ty,
					};
					let accessor: Box<dyn Fn(&syn::Ident) -> syn::Expr> = Box::new(move |x| {
						syn::parse2(quote!((#closure)(&#x))).expect("Could not create synthetic accessor!")
					});
					mapped.push(f(&FieldRef { index: next_index, field: &synthetic_field, accessor }));
					next_index += 1;
				}
			}
		}
		let ignored = allow_ignore && has_attr(&field.attrs, "comparable_ignore").is_some();
		if !ignored {
			let accessor = standard_accessor(next_index, field);
			mapped.push(f(&FieldRef { index: next_index, field, accessor }));
		}
		// The position advances even for ignored fields.
		next_index += 1;
	}
	mapped
}

/// Counts the fields that [`map_fields`] would visit with ignoring enabled:
/// injected synthetic fields (when `inject_synthetics` is set) plus every field
/// not marked `comparable_ignore`.
pub fn field_count<'a>(inject_synthetics: bool, fields: impl IntoIterator<Item = &'a syn::Field>) -> usize {
	let visited: Vec<()> = map_fields(inject_synthetics, fields, true, |_field_ref| ());
	visited.len()
}

/// Applies `f` to each variant in order and returns the collected results.
pub fn map_variants<'a, R>(
	variants: impl IntoIterator<Item = &'a syn::Variant>,
	mut f: impl FnMut(&syn::Variant) -> R,
) -> Vec<R> {
	let mut mapped = Vec::new();
	for variant in variants {
		mapped.push(f(variant));
	}
	mapped
}

/// Parses the argument of a `comparable_synthetic` attribute.
///
/// The tokens must form a block whose statements are all of the shape
/// `let <ident> = <closure>;`. Each statement becomes one map entry from the
/// bound identifier to its closure expression.
///
/// # Errors
///
/// Returns the `syn` error if the tokens do not parse as a block.
///
/// # Panics
///
/// Panics if a statement is not a `let` binding of a plain identifier with an
/// initializer, or if the initializer is not a closure expression.
fn parse_synthetics(tokens: &TokenStream) -> Result<BTreeMap<syn::Ident, syn::ExprClosure>, syn::Error> {
	let block: syn::Block = syn::parse2(tokens.clone())?;
	let entries = block.stmts.into_iter().map(|stmt| {
		let (ident, init) = match stmt {
			syn::Stmt::Local(syn::Local {
				pat: syn::Pat::Ident(syn::PatIdent { ident, .. }),
				init: Some((_, init)),
				..
			}) => (ident, init),
			_ => panic!("Syntax error parsing argument to comparable_synthetic"),
		};
		match *init {
			syn::Expr::Closure(closure) => (ident, closure),
			_ => panic!("Let values in comparable_synthetic must be fully typed closures"),
		}
	});
	Ok(entries.collect())
}

/// Renders each named field that is not marked `comparable_ignore` as
/// `<vis> <ident>: <ty>` tokens.
fn render_named_fields(named: &syn::FieldsNamed) -> Vec<TokenStream> {
	map_fields(false, named.named.iter(), true, |r| {
		let vis = &r.field.vis;
		let ident = r.field.ident.as_ref().expect("Found unnamed field in named struct");
		let ty = &r.field.ty;
		quote!(#vis #ident: #ty)
	})
}

/// Renders each positional field that is not marked `comparable_ignore` as
/// `<vis> <ty>` tokens.
fn render_unnamed_fields(unnamed: &syn::FieldsUnnamed) -> Vec<TokenStream> {
	map_fields(false, unnamed.unnamed.iter(), true, |r| {
		let vis = &r.field.vis;
		let ty = &r.field.ty;
		quote!(#vis #ty)
	})
}

/// Generates the source tokens of a struct or enum definition named
/// `type_name` with the given `visibility` and the shape described by `data`.
///
/// Fields marked `comparable_ignore` are omitted and synthetic fields are not
/// injected. The definition derives `PartialEq` and `Debug`; when this crate
/// is built with its `serde` feature it additionally derives
/// `serde::Serialize` and `serde::Deserialize`.
///
/// # Panics
///
/// Panics if `data` describes a union, or if a field inside a named field list
/// has no identifier.
pub fn generate_type_definition(visibility: &syn::Visibility, type_name: &syn::Ident, data: &syn::Data) -> TokenStream {
	let (keyword, body) = match data {
		syn::Data::Struct(st) => {
			let body = match &st.fields {
				syn::Fields::Named(named) => {
					let fields = render_named_fields(named);
					quote! {
						{
							#(#fields),*
						}
					}
				}
				syn::Fields::Unnamed(unnamed) => {
					let fields = render_unnamed_fields(unnamed);
					// Tuple structs need the trailing semicolon.
					quote! {
						(#(#fields),*);
					}
				}
				syn::Fields::Unit => {
					quote! { ; }
				}
			};
			(quote!(struct), body)
		}
		syn::Data::Enum(en) => {
			let variants = map_variants(en.variants.iter(), |variant| {
				let variant_name = &variant.ident;
				match &variant.fields {
					syn::Fields::Named(named) => {
						let fields = render_named_fields(named);
						quote! {
							#variant_name { #(#fields),* }
						}
					}
					syn::Fields::Unnamed(unnamed) => {
						let fields = render_unnamed_fields(unnamed);
						quote! {
							#variant_name(#(#fields),*)
						}
					}
					syn::Fields::Unit => {
						quote! {
							#variant_name
						}
					}
				}
			});
			let body = quote! {
				{
					#(#variants),*
				}
			};
			(quote!(enum), body)
		}
		syn::Data::Union(_un) => {
			panic!("comparable_derive::generate_type_definition not implemented for unions")
		}
	};
	// `cfg!` is evaluated when this crate is compiled, so the serde derives
	// follow this crate's own `serde` feature.
	let derive_serde = if cfg!(feature = "serde") {
		quote! {
			#[derive(serde::Serialize, serde::Deserialize)]
		}
	} else {
		quote! {}
	};
	quote! {
		#derive_serde
		#[derive(PartialEq, Debug)]
		#visibility #keyword #type_name#body
	}
}