assertr_derive/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(clippy::unwrap_used)]
3
4use proc_macro::TokenStream;
5
6use darling::*;
7use proc_macro2::Span;
8use quote::quote;
9use syn::{parse_macro_input, DeriveInput, Ident, Path, Type, Visibility};
10
/// Per-field data parsed by darling for `#[derive(AssertrEq)]`.
///
/// Options are read from `#[assertr_eq(...)]` attributes placed on individual fields.
#[derive(Debug, FromField)]
#[darling(attributes(assertr_eq))]
struct MyFieldReceiver {
    /// Field name; `None` for tuple-struct fields.
    ident: Option<Ident>,

    /// Declared type of the field.
    ty: Type,

    /// Field visibility. The derive only mirrors fields declared plain `pub` into the
    /// generated `...AssertrEq` struct; restricted (`pub(...)`) and private fields are skipped.
    vis: Visibility,

    /// Optional `map_type` override. When set, this type is used instead of `ty` both for the
    /// generated `::assertr::Eq<...>` field and as the `AssertrPartialEq` comparison target.
    #[darling(default)]
    map_type: Option<Type>,

    /// Optional `compare_with` function path. When set, it is called as
    /// `path(&actual_field, expected, ctx)` (with `ctx` the optional comparison context) and its
    /// result is used as a `bool`, replacing the default `::assertr::AssertrPartialEq::eq` call.
    #[darling(default)]
    compare_with: Option<Path>,
}
26
/// Top-level input parsed by darling for `#[derive(AssertrEq)]`.
///
/// `supports(struct_any)` makes darling reject non-struct input (such as enums), which the
/// derive entry point reports as a compile error. Named, tuple and unit structs pass this check.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(assertr_eq), supports(struct_any))]
struct MyInputReceiver {
    /// Name of the deriving struct; the generated struct is named `<ident>AssertrEq`.
    ident: Ident,

    /// Struct body. Enum variants are not parsed (`()`); each struct field is parsed into a
    /// `MyFieldReceiver`.
    data: ast::Data<(), MyFieldReceiver>,
}
34
35impl MyInputReceiver {
36    pub fn fields(&self) -> &ast::Fields<MyFieldReceiver> {
37        match &self.data {
38            ast::Data::Enum(_) => panic!("Only structs are supported"),
39            ast::Data::Struct(fields) => fields,
40        }
41    }
42}
43
44#[proc_macro_derive(AssertrEq, attributes(assertr_eq))]
45pub fn store(input: TokenStream) -> TokenStream {
46    let ast = parse_macro_input!(input as DeriveInput);
47
48    let input: MyInputReceiver = match FromDeriveInput::from_derive_input(&ast) {
49        Ok(args) => args,
50        Err(err) => return Error::write_errors(err).into(),
51    };
52
53    let original_struct_ident = input.ident.clone();
54
55    let filtered_fields = input.fields().iter().filter(|field| match field.vis {
56        Visibility::Public(_) => true,
57        Visibility::Restricted(_) => false,
58        Visibility::Inherited => false,
59    });
60
61    let eq_struct_ident = Ident::new(
62        format!("{}AssertrEq", input.ident).as_str(),
63        Span::call_site(),
64    );
65
66    let eq_struct_fields = filtered_fields.clone().map(|field| {
67        let vis = &field.vis;
68        let ident = &field.ident;
69        let ty = match &field.map_type {
70            None => &field.ty,
71            Some(ty) => ty,
72        };
73        quote! { #vis #ident: ::assertr::Eq<#ty> }
74    });
75
76    let eq_impls = filtered_fields.map(|field| {
77        let ident = field
78            .ident
79            .as_ref()
80            .expect("only named fields are supported!");
81        let ident_string = ident.to_string();
82        let ty = match &field.map_type {
83            None => &field.ty,
84            Some(ty) => ty,
85        };
86        let eq_args = quote! { &self.#ident, v, ctx.as_deref_mut() };
87        let eq_check = match &field.compare_with {
88            None => quote! { ::assertr::AssertrPartialEq::<#ty>::eq(#eq_args) },
89            Some(eq_check) => {
90                quote! { #eq_check(#eq_args) }
91            }
92        };
93        quote! {
94            && match &other.#ident {
95                ::assertr::Eq::Any => true,
96                ::assertr::Eq::Eq(v) => {
97                    let eq = #eq_check;
98                    if !eq {
99                        if let Some(ctx) = ctx.as_mut() {
100                            ctx.add_field_difference(#ident_string, v, &self.#ident);
101                        }
102                    }
103                    eq
104                },
105            }
106        }
107    });
108
109    Into::into(quote! {
110        #[derive(::core::fmt::Debug)]
111        pub struct #eq_struct_ident {
112            #(#eq_struct_fields),*
113        }
114
115        impl ::assertr::AssertrPartialEq<#eq_struct_ident> for &#original_struct_ident {
116            fn eq(&self, other: &#eq_struct_ident, mut ctx: Option<&mut ::assertr::EqContext>) -> bool {
117                true #(#eq_impls)*
118            }
119        }
120
121        impl ::assertr::AssertrPartialEq<#eq_struct_ident> for #original_struct_ident {
122            fn eq(&self, other: &#eq_struct_ident, ctx: Option<&mut ::assertr::EqContext>) -> bool {
123                ::assertr::AssertrPartialEq::eq(&self, other, ctx)
124            }
125        }
126    })
127}