pulldown_html_ext_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_attribute]
6pub fn html_writer(attr: TokenStream, input: TokenStream) -> TokenStream {
7    let skip_docs = attr.to_string().contains("skip_docs");
8    let input = parse_macro_input!(input as DeriveInput);
9
10    match process_html_writer(&input, skip_docs) {
11        Ok(output) => output.into(),
12        Err(err) => err.to_compile_error().into(),
13    }
14}
15
16fn process_html_writer(
17    input: &DeriveInput,
18    skip_docs: bool,
19) -> syn::Result<proc_macro2::TokenStream> {
20    let name = &input.ident;
21    let generics = &input.generics;
22    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
23
24    // Validate the input is a struct
25    let fields = match &input.data {
26        Data::Struct(data_struct) => match &data_struct.fields {
27            Fields::Named(fields) => fields,
28            _ => {
29                return Err(syn::Error::new_spanned(
30                    &input.ident,
31                    "struct must have named fields",
32                ))
33            }
34        },
35        Data::Enum(_) => {
36            return Err(syn::Error::new_spanned(
37                &input.ident,
38                "html_writer can only be applied to structs, not enums",
39            ))
40        }
41        Data::Union(_) => {
42            return Err(syn::Error::new_spanned(
43                &input.ident,
44                "html_writer can only be applied to structs, not unions",
45            ))
46        }
47    };
48
49    // Validate base field exists and has correct type
50    let base_field = fields
51        .named
52        .iter()
53        .find(|f| f.ident.as_ref().map_or(false, |i| i == "base"))
54        .ok_or_else(|| syn::Error::new_spanned(fields, "struct must have a field named 'base'"))?;
55
56    // Check base field type is HtmlWriterBase<W>
57    let base_type = &base_field.ty;
58    let is_valid_base_type = quote!(#base_type).to_string().contains("HtmlWriterBase");
59    if !is_valid_base_type {
60        return Err(syn::Error::new_spanned(
61            base_type,
62            "base field must be of type HtmlWriterBase<W>",
63        ));
64    }
65
66    // Generate implementation
67    let docs = if !skip_docs {
68        quote! {
69            #[doc = "An HTML writer implementation."]
70            #[doc = "This type implements methods for writing HTML elements."]
71        }
72    } else {
73        quote! {}
74    };
75
76    let expanded = quote! {
77        #docs
78        #input
79
80        impl #impl_generics HtmlWriter<W> for #name #ty_generics #where_clause {
81            fn get_writer(&mut self) -> &mut W {
82                self.base.get_writer()
83            }
84
85            fn get_config(&self) -> &HtmlConfig {
86                self.base.get_config()
87            }
88
89            fn get_state(&mut self) -> &mut HtmlState {
90                self.base.get_state()
91            }
92        }
93    };
94
95    Ok(expanded)
96}