partialdebug_derive/
lib.rs

1use proc_macro::TokenStream;
2
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::{quote, ToTokens};
5use syn::parse::{Parse, ParseStream};
6use syn::spanned::Spanned;
7use syn::*;
8
9/// The non exhaustive version of `PartialDebug`
10///
11/// Requires the `unstable` feature.
12/// Only available for structs with named fields.
13#[cfg(feature = "unstable")]
14#[proc_macro_derive(NonExhaustivePartialDebug)]
15pub fn derive_non_exhaustive(input: TokenStream) -> TokenStream {
16    let input = parse_macro_input!(input as ItemStruct);
17    let name = input.ident;
18    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
19
20    let fields = match input.fields {
21        Fields::Named(FieldsNamed { named, .. }) => named,
22        Fields::Unit => punctuated::Punctuated::new(),
23        Fields::Unnamed(_) => {
24            return Error::new(Span::call_site(), "non_exhaustive currently is only available on structs with named fields. See https://github.com/rust-lang/rust/issues/67364")
25                .to_compile_error()
26                .into();
27        }
28    };
29
30    let as_debug_all_fields = fields.iter().map(|field| {
31        let name = &field.ident;
32        quote! {
33            match ::partialdebug::specialization::AsDebug::as_debug(&self. #name) {
34                None => {
35                    __exhaustive = false;
36                }
37                Some(field) => {
38                    __s.field(stringify!(#name), field);
39                }
40            }
41        }
42    });
43
44    let expanded = quote! {
45        impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause{
46            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
47                let mut __s = f.debug_struct(stringify!(#name));
48                let mut __exhaustive = false;
49
50                #(#as_debug_all_fields)*
51
52                if __exhaustive {
53                    __s.finish()
54                } else {
55                    __s.finish_non_exhaustive()
56                }
57            }
58        }
59    };
60
61    TokenStream::from(expanded)
62}
63
64/// The placeholder version of `PartialDebug`
65#[proc_macro_derive(PlaceholderPartialDebug, attributes(debug_placeholder))]
66pub fn derive_placeholder(input: TokenStream) -> TokenStream {
67    let input = parse_macro_input!(input as DeriveInput);
68    let placeholder = match get_placeholder(&input) {
69        Ok(placeholder) => placeholder,
70        Err(err) => {
71            return err.to_compile_error().into();
72        }
73    };
74
75    let name = input.ident;
76    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
77
78    let implementation = match input.data {
79        Data::Struct(DataStruct { fields, .. }) => gen_variant_debug(
80            &fields,
81            &name,
82            struct_field_conversions(&fields, &placeholder),
83        ),
84        Data::Enum(data_enum) => gen_enum_debug(&data_enum, &name, &placeholder),
85        Data::Union(_) => {
86            return Error::new(
87                Span::call_site(),
88                "PartialDebug can not be derived for unions",
89            )
90            .to_compile_error()
91            .into();
92        }
93    };
94
95    let expanded = quote! {
96        impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause{
97            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
98                #implementation
99            }
100        }
101    };
102
103    TokenStream::from(expanded)
104}
105
106fn gen_variant_debug(
107    fields: &Fields,
108    variant_name: &Ident,
109    field_conversions: impl Iterator<Item = TokenStream2>,
110) -> TokenStream2 {
111    let constructor = match fields {
112        Fields::Named(_) => quote! {debug_struct},
113        Fields::Unnamed(_) | Fields::Unit => quote! {debug_tuple},
114    };
115
116    quote! {
117        f.#constructor(stringify!(#variant_name))
118        #(#field_conversions)*
119        .finish()
120    }
121}
122
123fn gen_enum_debug(
124    data_enum: &DataEnum,
125    enum_name: &Ident,
126    placeholder: &Option<String>,
127) -> TokenStream2 {
128    let all_variants = data_enum.variants.iter().map(|variant| {
129        let variant_name = &variant.ident;
130        let match_content = gen_variant_debug(
131            &variant.fields,
132            variant_name,
133            enum_field_conversions(&variant.fields, placeholder),
134        );
135        let match_pattern = gen_match_pattern(enum_name, variant);
136        quote! {
137            #match_pattern => {
138                #match_content
139            }
140        }
141    });
142
143    quote! {
144        match self {
145            #(#all_variants)*
146        }
147    }
148}
149
150fn struct_field_conversions<'a>(
151    fields: &'a Fields,
152    placeholder: &'a Option<String>,
153) -> impl Iterator<Item = TokenStream2> + 'a {
154    fields.iter().enumerate().map(move |(idx, field)| {
155        let (field_handle, name_arg) = match &field.ident {
156            None => {
157                let index = Index::from(idx);
158                (quote! {self.#index}, None)
159            }
160            Some(name) => (quote! {self.#name}, Some(quote! {stringify!(#name),})),
161        };
162        gen_field_as_debug(field, placeholder, field_handle, name_arg)
163    })
164}
165
166fn enum_field_conversions<'a>(
167    fields: &'a Fields,
168    placeholder: &'a Option<String>,
169) -> impl Iterator<Item = TokenStream2> + 'a {
170    fields.iter().enumerate().map(move |(idx, field)| {
171        let (field_handle, name_arg) = match &field.ident {
172            None => {
173                let ident = Ident::new(&format!("__{}", idx), field.span());
174                (quote! {#ident}, None)
175            }
176            Some(name) => (quote! {#name}, Some(quote! {stringify!(#name),})),
177        };
178        gen_field_as_debug(field, placeholder, field_handle, name_arg)
179    })
180}
181
182#[cfg(feature = "unstable")]
183fn gen_field_as_debug(
184    field: &Field,
185    placeholder: &Option<String>,
186    field_handle: TokenStream2,
187    name_arg: Option<TokenStream2>,
188) -> TokenStream2 {
189    let type_name = get_type_name(&field.ty);
190
191    // type name or given placeholder string
192    let placeholder_string = placeholder.as_ref().unwrap_or(&type_name);
193
194    quote! {
195        .field(
196            #name_arg
197            match ::partialdebug::specialization::AsDebug::as_debug(&#field_handle){
198                None => &::partialdebug::Placeholder(#placeholder_string),
199                Some(__field) => __field,
200            },
201        )
202    }
203}
204
205#[cfg(not(feature = "unstable"))]
206fn gen_field_as_debug(
207    field: &Field,
208    placeholder: &Option<String>,
209    field_handle: TokenStream2,
210    name_arg: Option<TokenStream2>,
211) -> TokenStream2 {
212    let type_name = get_type_name(&field.ty);
213    let field_type = &field.ty;
214
215    // type name or given placeholder string
216    let placeholder_string = placeholder.as_ref().unwrap_or(&type_name);
217
218    quote! {
219        .field(
220            #name_arg
221            match ::partialdebug::no_specialization::DebugDetector::<#field_type>::as_debug(&#field_handle){
222                None => &::partialdebug::Placeholder(#placeholder_string),
223                Some(__field) => __field,
224            },
225        )
226    }
227}
228
229fn gen_match_pattern(enum_name: &Ident, variant: &Variant) -> TokenStream2 {
230    let variant_name = &variant.ident;
231    let destructuring_pattern = match &variant.fields {
232        Fields::Named(FieldsNamed { named, .. }) => {
233            let patterns = named.iter().map(|field| &field.ident);
234            quote! {
235                {#(#patterns),*}
236            }
237        }
238        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
239            let patterns = unnamed
240                .iter()
241                .enumerate()
242                .map(|(idx, field)| Ident::new(&format!("__{}", idx), field.span()));
243            quote! {
244                (#(#patterns),*)
245            }
246        }
247        Fields::Unit => TokenStream2::new(),
248    };
249
250    quote! {#enum_name::#variant_name #destructuring_pattern}
251}
252
253struct Placeholder(String);
254
255impl Parse for Placeholder {
256    fn parse(input: ParseStream) -> Result<Self> {
257        input.parse::<Token![=]>()?;
258        Ok(Placeholder(input.parse::<LitStr>()?.value()))
259    }
260}
261
262/// Tries to parse a placeholder string if there is one
263fn get_placeholder(input: &DeriveInput) -> Result<Option<String>> {
264    let placeholders: Vec<_> = input
265        .attrs
266        .iter()
267        .filter(|attribute| attribute.path.is_ident("debug_placeholder"))
268        .collect();
269
270    if placeholders.len() > 1 {
271        return Err(Error::new_spanned(
272            placeholders[1],
273            "More than one debug_placeholder attribute",
274        ));
275    }
276
277    placeholders
278        .first()
279        .map(|attribute| {
280            parse2::<Placeholder>(attribute.tokens.clone()).map(|placeholder| placeholder.0)
281        })
282        .transpose()
283}
284
285/// returns the type as a string with unnecessary whitespace removed
286fn get_type_name(ty: &Type) -> String {
287    let mut type_name = String::new();
288    let chars: Vec<char> = ty.to_token_stream().to_string().trim().chars().collect();
289
290    for (i, char) in chars.iter().enumerate() {
291        if char.is_whitespace() {
292            // remove whitespace surrounding punctuation
293            // exceptions are:
294            //      - whitespace surrounding `->`
295            //      - whitespace following `,` or `;`
296            let (before, after) = (chars[i - 1], chars[i + 1]); // always valid because string was trimmed before
297            let before_wide = chars.get(i.saturating_sub(2)..i);
298            let after_wide = chars.get(i + 1..=i + 2);
299
300            if (before.is_ascii_punctuation() || after.is_ascii_punctuation())
301                && !matches!(before, ';' | ',')
302                && !matches!(before_wide, Some(['-', '>']))
303                && !matches!(after_wide, Some(['-', '>']))
304            {
305                continue;
306            }
307        }
308
309        type_name.push(*char);
310    }
311
312    type_name
313}
314
315#[cfg(test)]
316mod tests {
317    use super::*;
318
319    fn test_type_name_formatting(type_str: &str) {
320        let ty: Type = parse_str(type_str).unwrap();
321        assert_eq!(get_type_name(&ty), type_str)
322    }
323
324    #[test]
325    fn test_no_spaces() {
326        test_type_name_formatting("u8");
327        test_type_name_formatting("Option<u8>");
328        test_type_name_formatting("[u8]");
329        test_type_name_formatting("()");
330        test_type_name_formatting("std::fmt::Formatter<'_>");
331    }
332    #[test]
333    fn test_array() {
334        test_type_name_formatting("[u8; 4]");
335    }
336    #[test]
337    fn test_lifetime() {
338        test_type_name_formatting("&'a u8");
339    }
340    #[test]
341    fn test_function() {
342        test_type_name_formatting("fn(u8) -> u8");
343    }
344    #[test]
345    fn test_trait_object() {
346        test_type_name_formatting("Box<dyn Send>");
347    }
348    #[test]
349    fn test_tuple() {
350        test_type_name_formatting("(Option<u8>, u8)");
351    }
352}