1#![forbid(unsafe_code)]
2#![deny(clippy::unwrap_used)]
3
4use proc_macro::TokenStream;
5
6use darling::*;
7use proc_macro2::Span;
8use quote::quote;
9use syn::{DeriveInput, Ident, Path, Type, Visibility, parse_macro_input};
10
/// Per-field information that `darling` extracts from the struct the derive is applied to.
///
/// Field-level options are read from `#[assertr_eq(...)]` attributes.
#[derive(Debug, FromField)]
#[darling(attributes(assertr_eq))]
struct MyFieldReceiver {
    /// Name of the field; `None` when the field is unnamed.
    ident: Option<Ident>,

    /// Declared type of the field.
    ty: Type,

    /// Declared visibility of the field. `store` only processes `Visibility::Public` fields.
    vis: Visibility,

    /// Optional type override. When present, `store` uses it instead of `ty` for the
    /// generated `::assertr::Eq<..>` field and for the `AssertrPartialEq::<..>` call.
    #[darling(default)]
    map_type: Option<Type>,

    /// Optional path that `store` calls in place of `::assertr::AssertrPartialEq::eq`
    /// when comparing this field.
    #[darling(default)]
    compare_with: Option<Path>,
}
26
/// Item-level information that `darling` extracts from the derive input.
///
/// The `supports(struct_any)` option declares that the derive accepts structs of any shape.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(assertr_eq), supports(struct_any))]
struct MyInputReceiver {
    /// Name of the item the derive is applied to.
    ident: Ident,

    /// Body of the item; every field is parsed into a `MyFieldReceiver`.
    data: ast::Data<(), MyFieldReceiver>,
}
34
35impl MyInputReceiver {
36 pub fn fields(&self) -> &ast::Fields<MyFieldReceiver> {
37 match &self.data {
38 ast::Data::Enum(_) => panic!("Only structs are supported"),
39 ast::Data::Struct(fields) => fields,
40 }
41 }
42}
43
44#[proc_macro_derive(AssertrEq, attributes(assertr_eq))]
45pub fn store(input: TokenStream) -> TokenStream {
46 let ast = parse_macro_input!(input as DeriveInput);
47
48 let input: MyInputReceiver = match FromDeriveInput::from_derive_input(&ast) {
49 Ok(args) => args,
50 Err(err) => return Error::write_errors(err).into(),
51 };
52
53 let original_struct_ident = input.ident.clone();
54
55 let filtered_fields = input.fields().iter().filter(|field| match field.vis {
56 Visibility::Public(_) => true,
57 Visibility::Restricted(_) => false,
58 Visibility::Inherited => false,
59 });
60
61 let eq_struct_ident = Ident::new(
62 format!("{}AssertrEq", input.ident).as_str(),
63 Span::call_site(),
64 );
65
66 let eq_struct_fields = filtered_fields.clone().map(|field| {
67 let vis = &field.vis;
68 let ident = &field.ident;
69 let ty = match &field.map_type {
70 None => &field.ty,
71 Some(ty) => ty,
72 };
73 quote! { #vis #ident: ::assertr::Eq<#ty> }
74 });
75
76 let eq_impls = filtered_fields.map(|field| {
77 let ident = field
78 .ident
79 .as_ref()
80 .expect("only named fields are supported!");
81 let ident_string = ident.to_string();
82 let ty = match &field.map_type {
83 None => &field.ty,
84 Some(ty) => ty,
85 };
86 let eq_args = quote! { &self.#ident, v, ctx.as_deref_mut() };
87 let eq_check = match &field.compare_with {
88 None => quote! { ::assertr::AssertrPartialEq::<#ty>::eq(#eq_args) },
89 Some(eq_check) => {
90 quote! { #eq_check(#eq_args) }
91 }
92 };
93 quote! {
94 && match &other.#ident {
95 ::assertr::Eq::Any => true,
96 ::assertr::Eq::Eq(v) => {
97 let eq = #eq_check;
98 if !eq {
99 if let Some(ctx) = ctx.as_mut() {
100 ctx.add_field_difference(#ident_string, v, &self.#ident);
101 }
102 }
103 eq
104 },
105 }
106 }
107 });
108
109 Into::into(quote! {
110 #[derive(::core::fmt::Debug, ::core::default::Default)]
111 pub struct #eq_struct_ident {
112 #(#eq_struct_fields),*
113 }
114
115 impl ::assertr::AssertrPartialEq<#eq_struct_ident> for &#original_struct_ident {
116 fn eq(&self, other: &#eq_struct_ident, mut ctx: Option<&mut ::assertr::EqContext>) -> bool {
117 true #(#eq_impls)*
118 }
119 }
120
121 impl ::assertr::AssertrPartialEq<#eq_struct_ident> for #original_struct_ident {
122 fn eq(&self, other: &#eq_struct_ident, ctx: Option<&mut ::assertr::EqContext>) -> bool {
123 ::assertr::AssertrPartialEq::eq(&self, other, ctx)
124 }
125 }
126 })
127}