Skip to main content

dory_derive/
lib.rs

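//! Procedural macros for deriving serialization traits in Dory.
//!
//! This crate provides derive macros for the `DorySerialize`, `DoryDeserialize`, and `Valid`
//! traits. Each macro implements its trait field by field, in declaration order, for named,
//! tuple, and unit structs; deriving on an enum or a union is a compile error.
//!
//! The generated code refers to `crate::primitives::serialization`, so these derives can
//! only be used inside the crate that defines that module.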
5
6#![allow(missing_docs)]
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields};
11
12#[proc_macro_derive(DorySerialize)]
13pub fn derive_dory_serialize(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = &input.ident;
16    let mut generics = input.generics.clone();
17
18    // Add DorySerialize bounds to all field types
19    if let Data::Struct(data) = &input.data {
20        if let Fields::Named(fields) = &data.fields {
21            for field in &fields.named {
22                let ty = &field.ty;
23                generics
24                    .make_where_clause()
25                    .predicates
26                    .push(syn::parse_quote! { #ty: DorySerialize });
27            }
28        }
29    }
30
31    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
32
33    let serialize_fields = match &input.data {
34        Data::Struct(data) => match &data.fields {
35            Fields::Named(fields) => {
36                let field_serialize = fields.named.iter().map(|f| {
37                    let field_name = &f.ident;
38                    quote! {
39                        self.#field_name.serialize_with_mode(&mut writer, compress)?;
40                    }
41                });
42                quote! { #(#field_serialize)* }
43            }
44            Fields::Unnamed(fields) => {
45                let field_serialize = fields.unnamed.iter().enumerate().map(|(i, _)| {
46                    let index = syn::Index::from(i);
47                    quote! {
48                        self.#index.serialize_with_mode(&mut writer, compress)?;
49                    }
50                });
51                quote! { #(#field_serialize)* }
52            }
53            Fields::Unit => quote! {},
54        },
55        Data::Enum(_) => {
56            return syn::Error::new_spanned(input, "DorySerialize cannot be derived for enums yet")
57                .to_compile_error()
58                .into();
59        }
60        Data::Union(_) => {
61            return syn::Error::new_spanned(input, "DorySerialize cannot be derived for unions")
62                .to_compile_error()
63                .into();
64        }
65    };
66
67    let size_fields = match &input.data {
68        Data::Struct(data) => match &data.fields {
69            Fields::Named(fields) => {
70                let field_size = fields.named.iter().map(|f| {
71                    let field_name = &f.ident;
72                    quote! {
73                        size += self.#field_name.serialized_size(compress);
74                    }
75                });
76                quote! { #(#field_size)* }
77            }
78            Fields::Unnamed(fields) => {
79                let field_size = fields.unnamed.iter().enumerate().map(|(i, _)| {
80                    let index = syn::Index::from(i);
81                    quote! {
82                        size += self.#index.serialized_size(compress);
83                    }
84                });
85                quote! { #(#field_size)* }
86            }
87            Fields::Unit => quote! {},
88        },
89        _ => unreachable!(),
90    };
91
92    let expanded = quote! {
93        impl #impl_generics DorySerialize for #name #ty_generics #where_clause {
94            fn serialize_with_mode<W: std::io::Write>(
95                &self,
96                mut writer: W,
97                compress: crate::primitives::serialization::Compress,
98            ) -> Result<(), crate::primitives::serialization::SerializationError> {
99                use crate::primitives::serialization::DorySerialize;
100                #serialize_fields
101                Ok(())
102            }
103
104            fn serialized_size(&self, compress: crate::primitives::serialization::Compress) -> usize {
105                use crate::primitives::serialization::DorySerialize;
106                let mut size = 0;
107                #size_fields
108                size
109            }
110        }
111    };
112
113    TokenStream::from(expanded)
114}
115
116#[proc_macro_derive(DoryDeserialize)]
117pub fn derive_dory_deserialize(input: TokenStream) -> TokenStream {
118    let input = parse_macro_input!(input as DeriveInput);
119    let name = &input.ident;
120    let mut generics = input.generics.clone();
121
122    // Add DoryDeserialize bounds to all field types
123    if let Data::Struct(data) = &input.data {
124        if let Fields::Named(fields) = &data.fields {
125            for field in &fields.named {
126                let ty = &field.ty;
127                generics
128                    .make_where_clause()
129                    .predicates
130                    .push(syn::parse_quote! { #ty: DoryDeserialize });
131            }
132        }
133    }
134
135    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136
137    let deserialize_fields = match &input.data {
138        Data::Struct(data) => match &data.fields {
139            Fields::Named(fields) => {
140                let field_names = fields.named.iter().map(|f| &f.ident);
141                let field_deserialize = fields.named.iter().map(|f| {
142                    let field_name = &f.ident;
143                    let field_ty = &f.ty;
144                    quote! {
145                        let #field_name = <#field_ty>::deserialize_with_mode(&mut reader, compress, validate)?;
146                    }
147                });
148                quote! {
149                    #(#field_deserialize)*
150                    Ok(Self { #(#field_names),* })
151                }
152            }
153            Fields::Unnamed(fields) => {
154                let field_deserialize = fields.unnamed.iter().enumerate().map(|(i, f)| {
155                    let field_name = syn::Ident::new(&format!("field_{i}"), f.ty.span());
156                    let field_ty = &f.ty;
157                    quote! {
158                        let #field_name = <#field_ty>::deserialize_with_mode(&mut reader, compress, validate)?;
159                    }
160                });
161                let field_names = (0..fields.unnamed.len())
162                    .map(|i| syn::Ident::new(&format!("field_{i}"), fields.unnamed.span()));
163                quote! {
164                    #(#field_deserialize)*
165                    Ok(Self(#(#field_names),*))
166                }
167            }
168            Fields::Unit => quote! { Ok(Self) },
169        },
170        Data::Enum(_) => {
171            return syn::Error::new_spanned(
172                input,
173                "DoryDeserialize cannot be derived for enums yet",
174            )
175            .to_compile_error()
176            .into();
177        }
178        Data::Union(_) => {
179            return syn::Error::new_spanned(input, "DoryDeserialize cannot be derived for unions")
180                .to_compile_error()
181                .into();
182        }
183    };
184
185    let expanded = quote! {
186        impl #impl_generics DoryDeserialize for #name #ty_generics #where_clause {
187            fn deserialize_with_mode<R: std::io::Read>(
188                mut reader: R,
189                compress: crate::primitives::serialization::Compress,
190                validate: crate::primitives::serialization::Validate,
191            ) -> Result<Self, crate::primitives::serialization::SerializationError> {
192                use crate::primitives::serialization::DoryDeserialize;
193                #deserialize_fields
194            }
195        }
196    };
197
198    TokenStream::from(expanded)
199}
200
201#[proc_macro_derive(Valid)]
202pub fn derive_valid(input: TokenStream) -> TokenStream {
203    let input = parse_macro_input!(input as DeriveInput);
204    let name = &input.ident;
205    let mut generics = input.generics.clone();
206
207    // Add Valid bounds to all field types
208    if let Data::Struct(data) = &input.data {
209        if let Fields::Named(fields) = &data.fields {
210            for field in &fields.named {
211                let ty = &field.ty;
212                generics
213                    .make_where_clause()
214                    .predicates
215                    .push(syn::parse_quote! { #ty: Valid });
216            }
217        }
218    }
219
220    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
221
222    let check_fields = match &input.data {
223        Data::Struct(data) => match &data.fields {
224            Fields::Named(fields) => {
225                let field_checks = fields.named.iter().map(|f| {
226                    let field_name = &f.ident;
227                    quote! {
228                        self.#field_name.check()?;
229                    }
230                });
231                quote! { #(#field_checks)* }
232            }
233            Fields::Unnamed(fields) => {
234                let field_checks = fields.unnamed.iter().enumerate().map(|(i, _)| {
235                    let index = syn::Index::from(i);
236                    quote! {
237                        self.#index.check()?;
238                    }
239                });
240                quote! { #(#field_checks)* }
241            }
242            Fields::Unit => quote! {},
243        },
244        Data::Enum(_) => {
245            return syn::Error::new_spanned(input, "Valid cannot be derived for enums yet")
246                .to_compile_error()
247                .into();
248        }
249        Data::Union(_) => {
250            return syn::Error::new_spanned(input, "Valid cannot be derived for unions")
251                .to_compile_error()
252                .into();
253        }
254    };
255
256    let expanded = quote! {
257        impl #impl_generics Valid for #name #ty_generics #where_clause {
258            fn check(&self) -> Result<(), crate::primitives::serialization::SerializationError> {
259                use crate::primitives::serialization::Valid;
260                #check_fields
261                Ok(())
262            }
263        }
264    };
265
266    TokenStream::from(expanded)
267}