//! paperless_api_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};
4
5fn is_option_type(ty: &syn::Type) -> bool {
6    if let syn::Type::Path(type_path) = ty {
7        if let Some(segment) = type_path.path.segments.last() {
8            return segment.ident == "Option";
9        }
10    }
11    false
12}
13
/// Options parsed from the `#[dto(...)]` helper attribute on a single field.
struct DtoFieldAttributes {
    /// If true, this field can not be used for creating or updating the DTO
    /// (e.g. the id of the entity), so it is left out of the generated struct.
    skip: bool,
}
20
21struct BaseStruct<'a> {
22    fields: Vec<&'a syn::Field>,
23}
24
25impl DtoFieldAttributes {
26    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
27        let mut skip = false;
28
29        for attr in attrs {
30            if attr.path().is_ident("dto") {
31                attr.parse_nested_meta(|meta| {
32                    if meta.path.is_ident("skip") {
33                        skip = true;
34                    }
35                    Ok(())
36                })?;
37            }
38        }
39
40        Ok(Self { skip })
41    }
42}
43
44fn non_dto_attrs(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
45    attrs.iter().filter(|a| !a.path().is_ident("dto")).collect()
46}
47
48fn new_struct(
49    base_struct: &BaseStruct,
50    new_name: &Ident,
51    all_optional: bool,
52) -> proc_macro2::TokenStream {
53    let mut field_defs = Vec::new();
54    for field in &base_struct.fields {
55        // Check if the field should be skipped
56        let dto = match DtoFieldAttributes::parse(&field.attrs) {
57            Ok(dto) => dto,
58            Err(e) => return e.to_compile_error(),
59        };
60        if dto.skip {
61            continue;
62        }
63
64        let ident = field.ident.as_ref().unwrap();
65        let ty = &field.ty;
66        let vis = &field.vis;
67        let attrs = non_dto_attrs(&field.attrs);
68
69        let def = if all_optional && !is_option_type(ty) {
70            quote! {
71                #(#attrs)*
72                #[serde(skip_serializing_if = "Option::is_none")]
73                #vis #ident: Option<#ty>,
74            }
75        } else {
76            quote! {
77                #(#attrs)*
78                #vis #ident: #ty,
79            }
80        };
81        field_defs.push(def);
82    }
83
84    // Generate the struct
85    quote! {
86        #[derive(Debug, Default, Clone, serde::Serialize)]
87        pub struct #new_name {
88            #(#field_defs)*
89        }
90    }
91}
92
93fn derive_create_or_update(input: TokenStream, update: bool) -> TokenStream {
94    let input = parse_macro_input!(input as DeriveInput);
95    let name = &input.ident;
96    let dto_name = if update {
97        format_ident!("Update{}", name)
98    } else {
99        format_ident!("Create{}", name)
100    };
101
102    let fields = match &input.data {
103        Data::Struct(data) => match &data.fields {
104            Fields::Named(fields) => &fields.named,
105            _ => {
106                return syn::Error::new_spanned(
107                    &input.ident,
108                    "DTO derive only supports structs with named fields",
109                )
110                .to_compile_error()
111                .into();
112            }
113        },
114        _ => {
115            return syn::Error::new_spanned(&input.ident, "DTO derive only supports structs")
116                .to_compile_error()
117                .into();
118        }
119    };
120
121    let mut field_defs = Vec::new();
122    for f in fields {
123        let dto = match DtoFieldAttributes::parse(&f.attrs) {
124            Ok(dto) => dto,
125            Err(e) => return e.to_compile_error().into(),
126        };
127        if dto.skip {
128            continue;
129        }
130
131        let ident = f.ident.as_ref().unwrap();
132        let ty = &f.ty;
133        let vis = &f.vis;
134        let attrs = non_dto_attrs(&f.attrs);
135
136        let def = if update && !is_option_type(ty) {
137            quote! {
138                #(#attrs)*
139                #[serde(skip_serializing_if = "Option::is_none")]
140                #vis #ident: Option<#ty>,
141            }
142        } else {
143            quote! {
144                #(#attrs)*
145                #vis #ident: #ty,
146            }
147        };
148        field_defs.push(def);
149    }
150
151    let trait_path = if update {
152        quote!(crate::dto::UpdateDto)
153    } else {
154        quote!(crate::dto::CreateDtoObject)
155    };
156
157    let expanded = quote! {
158        #[derive(Debug, Default, Clone, serde::Serialize)]
159        pub struct #dto_name {
160            #(#field_defs)*
161        }
162
163        impl #trait_path for #dto_name {}
164    };
165
166    TokenStream::from(expanded)
167}
168
169#[proc_macro_derive(UpdateDto, attributes(dto))]
170pub fn derive_update_dto(input: TokenStream) -> TokenStream {
171    derive_create_or_update(input, true)
172}
173
174#[proc_macro_derive(CreateDto, attributes(dto, api_info))]
175pub fn derive_create_dto(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177    let name = input.ident.clone();
178
179    let fields = match &input.data {
180        Data::Struct(data) => match &data.fields {
181            Fields::Named(fields) => &fields.named,
182            _ => {
183                return syn::Error::new_spanned(
184                    &input.ident,
185                    "DTO derive only supports structs with named fields",
186                )
187                .to_compile_error()
188                .into();
189            }
190        },
191        _ => {
192            return syn::Error::new_spanned(&input.ident, "DTO derive only supports structs")
193                .to_compile_error()
194                .into();
195        }
196    };
197
198    // Parse #[api_info(endpoint = "...")] attribute
199    let mut endpoint = None;
200    for attr in &input.attrs {
201        if attr.path().is_ident("api_info") {
202            attr.parse_nested_meta(|meta| {
203                if meta.path.is_ident("endpoint") {
204                    let value = meta.value()?;
205                    let lit: syn::LitStr = value.parse()?;
206                    endpoint = Some(lit.value());
207                }
208                Ok(())
209            })
210            .unwrap();
211        }
212    }
213
214    let Some(endpoint) = endpoint else {
215        return syn::Error::new_spanned(
216            &input.ident,
217            "CreateDtoObject requires a #[api_info(endpoint = \"...\")] attribute",
218        )
219        .to_compile_error()
220        .into();
221    };
222
223    let new_struct_name = format_ident!("Create{}", name);
224
225    let new_struct = new_struct(
226        &BaseStruct {
227            fields: fields.iter().collect(),
228        },
229        &new_struct_name,
230        false,
231    );
232
233    TokenStream::from(quote! {
234        #new_struct
235
236        impl crate::dto::CreateDtoObject for #new_struct_name {
237            type BaseType = #name;
238
239            fn endpoint() -> &'static str {
240                #endpoint
241            }
242        }
243    })
244}