paperless_api_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};
4
5fn is_option_type(ty: &syn::Type) -> bool {
6 if let syn::Type::Path(type_path) = ty {
7 if let Some(segment) = type_path.path.segments.last() {
8 return segment.ident == "Option";
9 }
10 }
11 false
12}
13
/// Options collected from the `#[dto(...)]` attributes placed on a single field.
#[allow(dead_code)]
struct DtoFieldAttributes {
    /// Set when `#[dto(skip)]` is present; a skipped field is left out of the
    /// generated DTO struct.
    skip: bool,
}
20
21struct BaseStruct<'a> {
22 fields: Vec<&'a syn::Field>,
23}
24
25impl DtoFieldAttributes {
26 fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
27 let mut skip = false;
28
29 for attr in attrs {
30 if attr.path().is_ident("dto") {
31 attr.parse_nested_meta(|meta| {
32 if meta.path.is_ident("skip") {
33 skip = true;
34 }
35 Ok(())
36 })?;
37 }
38 }
39
40 Ok(Self { skip })
41 }
42}
43
44fn non_dto_attrs(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
45 attrs.iter().filter(|a| !a.path().is_ident("dto")).collect()
46}
47
48fn new_struct(
49 base_struct: &BaseStruct,
50 new_name: &Ident,
51 all_optional: bool,
52) -> proc_macro2::TokenStream {
53 let mut field_defs = Vec::new();
54 for field in &base_struct.fields {
55 let dto = match DtoFieldAttributes::parse(&field.attrs) {
57 Ok(dto) => dto,
58 Err(e) => return e.to_compile_error(),
59 };
60 if dto.skip {
61 continue;
62 }
63
64 let ident = field.ident.as_ref().unwrap();
65 let ty = &field.ty;
66 let vis = &field.vis;
67 let attrs = non_dto_attrs(&field.attrs);
68
69 let def = if all_optional && !is_option_type(ty) {
70 quote! {
71 #(#attrs)*
72 #[serde(skip_serializing_if = "Option::is_none")]
73 #vis #ident: Option<#ty>,
74 }
75 } else {
76 quote! {
77 #(#attrs)*
78 #vis #ident: #ty,
79 }
80 };
81 field_defs.push(def);
82 }
83
84 quote! {
86 #[derive(Debug, Default, Clone, serde::Serialize)]
87 pub struct #new_name {
88 #(#field_defs)*
89 }
90 }
91}
92
93fn derive_create_or_update(input: TokenStream, update: bool) -> TokenStream {
94 let input = parse_macro_input!(input as DeriveInput);
95 let name = &input.ident;
96 let dto_name = if update {
97 format_ident!("Update{}", name)
98 } else {
99 format_ident!("Create{}", name)
100 };
101
102 let fields = match &input.data {
103 Data::Struct(data) => match &data.fields {
104 Fields::Named(fields) => &fields.named,
105 _ => {
106 return syn::Error::new_spanned(
107 &input.ident,
108 "DTO derive only supports structs with named fields",
109 )
110 .to_compile_error()
111 .into();
112 }
113 },
114 _ => {
115 return syn::Error::new_spanned(&input.ident, "DTO derive only supports structs")
116 .to_compile_error()
117 .into();
118 }
119 };
120
121 let mut field_defs = Vec::new();
122 for f in fields {
123 let dto = match DtoFieldAttributes::parse(&f.attrs) {
124 Ok(dto) => dto,
125 Err(e) => return e.to_compile_error().into(),
126 };
127 if dto.skip {
128 continue;
129 }
130
131 let ident = f.ident.as_ref().unwrap();
132 let ty = &f.ty;
133 let vis = &f.vis;
134 let attrs = non_dto_attrs(&f.attrs);
135
136 let def = if update && !is_option_type(ty) {
137 quote! {
138 #(#attrs)*
139 #[serde(skip_serializing_if = "Option::is_none")]
140 #vis #ident: Option<#ty>,
141 }
142 } else {
143 quote! {
144 #(#attrs)*
145 #vis #ident: #ty,
146 }
147 };
148 field_defs.push(def);
149 }
150
151 let trait_path = if update {
152 quote!(crate::dto::UpdateDto)
153 } else {
154 quote!(crate::dto::CreateDtoObject)
155 };
156
157 let expanded = quote! {
158 #[derive(Debug, Default, Clone, serde::Serialize)]
159 pub struct #dto_name {
160 #(#field_defs)*
161 }
162
163 impl #trait_path for #dto_name {}
164 };
165
166 TokenStream::from(expanded)
167}
168
169#[proc_macro_derive(UpdateDto, attributes(dto))]
170pub fn derive_update_dto(input: TokenStream) -> TokenStream {
171 derive_create_or_update(input, true)
172}
173
174#[proc_macro_derive(CreateDto, attributes(dto, api_info))]
175pub fn derive_create_dto(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177 let name = input.ident.clone();
178
179 let fields = match &input.data {
180 Data::Struct(data) => match &data.fields {
181 Fields::Named(fields) => &fields.named,
182 _ => {
183 return syn::Error::new_spanned(
184 &input.ident,
185 "DTO derive only supports structs with named fields",
186 )
187 .to_compile_error()
188 .into();
189 }
190 },
191 _ => {
192 return syn::Error::new_spanned(&input.ident, "DTO derive only supports structs")
193 .to_compile_error()
194 .into();
195 }
196 };
197
198 let mut endpoint = None;
200 for attr in &input.attrs {
201 if attr.path().is_ident("api_info") {
202 attr.parse_nested_meta(|meta| {
203 if meta.path.is_ident("endpoint") {
204 let value = meta.value()?;
205 let lit: syn::LitStr = value.parse()?;
206 endpoint = Some(lit.value());
207 }
208 Ok(())
209 })
210 .unwrap();
211 }
212 }
213
214 let Some(endpoint) = endpoint else {
215 return syn::Error::new_spanned(
216 &input.ident,
217 "CreateDtoObject requires a #[api_info(endpoint = \"...\")] attribute",
218 )
219 .to_compile_error()
220 .into();
221 };
222
223 let new_struct_name = format_ident!("Create{}", name);
224
225 let new_struct = new_struct(
226 &BaseStruct {
227 fields: fields.iter().collect(),
228 },
229 &new_struct_name,
230 false,
231 );
232
233 TokenStream::from(quote! {
234 #new_struct
235
236 impl crate::dto::CreateDtoObject for #new_struct_name {
237 type BaseType = #name;
238
239 fn endpoint() -> &'static str {
240 #endpoint
241 }
242 }
243 })
244}