justcode_derive/
lib.rs

1//! Derive macros for justcode Encode and Decode traits.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7/// Derive macro for the `Encode` trait.
8#[proc_macro_derive(Encode)]
9pub fn derive_encode(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = &input.ident;
12    let generics = &input.generics;
13    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
14
15    let encode_body = match &input.data {
16        Data::Struct(data_struct) => match &data_struct.fields {
17                    Fields::Named(fields) => {
18                        let field_encodes = fields.named.iter().map(|field| {
19                            let field_name = &field.ident;
20                            quote! {
21                                self.#field_name.encode(writer)?;
22                            }
23                        });
24                        quote! {
25                            #(#field_encodes)*
26                        }
27                    }
28                    Fields::Unnamed(fields) => {
29                        let field_encodes = fields.unnamed.iter().enumerate().map(|(i, _)| {
30                            let index = syn::Index::from(i);
31                            quote! {
32                                self.#index.encode(writer)?;
33                            }
34                        });
35                        quote! {
36                            #(#field_encodes)*
37                        }
38                    }
39                    Fields::Unit => {
40                        quote! {}
41                    }
42        },
43        Data::Enum(data_enum) => {
44            // Encode variant index first, then variant data
45            let variant_encodes = data_enum.variants.iter().enumerate().map(|(idx, variant)| {
46                let variant_idx = idx as u32;
47                let variant_name = &variant.ident;
48                match &variant.fields {
49                    Fields::Named(fields) => {
50                        let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
51                        let field_encodes = field_names.iter().map(|field_name| {
52                            quote! {
53                                #field_name.encode(writer)?;
54                            }
55                        });
56                        quote! {
57                            #name::#variant_name { #(#field_names,)* } => {
58                                use justcode_core::varint::encode_length;
59                                encode_length(writer, #variant_idx as usize, writer.config())?;
60                                #(#field_encodes)*
61                            }
62                        }
63                    }
64                    Fields::Unnamed(fields) => {
65                        let field_count = fields.unnamed.len();
66                        let field_indices: Vec<_> = (0..field_count)
67                            .map(|i| syn::Index::from(i))
68                            .collect();
69                        let field_encodes = field_indices.iter().map(|index| {
70                            quote! {
71                                #index.encode(writer)?;
72                            }
73                        });
74                        let field_patterns = field_indices.clone();
75                        quote! {
76                            #name::#variant_name(#(#field_patterns,)*) => {
77                                use justcode_core::varint::encode_length;
78                                encode_length(writer, #variant_idx as usize, writer.config())?;
79                                #(#field_encodes)*
80                            }
81                        }
82                    }
83                    Fields::Unit => {
84                        quote! {
85                            #name::#variant_name => {
86                                use justcode_core::varint::encode_length;
87                                encode_length(writer, #variant_idx as usize, writer.config())?;
88                            }
89                        }
90                    }
91                }
92            });
93            quote! {
94                match self {
95                    #(#variant_encodes)*
96                }
97            }
98        }
99        Data::Union(_) => {
100            return syn::Error::new_spanned(
101                name,
102                "justcode does not support encoding unions",
103            )
104            .to_compile_error()
105            .into();
106        }
107    };
108
109    let expanded = quote! {
110        impl #impl_generics justcode_core::Encode for #name #ty_generics #where_clause {
111            fn encode(&self, writer: &mut justcode_core::writer::Writer) -> justcode_core::Result<()> {
112                #encode_body
113                Ok(())
114            }
115        }
116    };
117
118    TokenStream::from(expanded)
119}
120
121/// Derive macro for the `Decode` trait.
122#[proc_macro_derive(Decode)]
123pub fn derive_decode(input: TokenStream) -> TokenStream {
124    let input = parse_macro_input!(input as DeriveInput);
125    let name = &input.ident;
126    let generics = &input.generics;
127    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
128
129    let decode_body = match &input.data {
130        Data::Struct(data_struct) => match &data_struct.fields {
131            Fields::Named(fields) => {
132                let field_decodes = fields.named.iter().map(|field| {
133                    let field_name = &field.ident;
134                    let field_ty = &field.ty;
135                    quote! {
136                        #field_name: <#field_ty as justcode_core::Decode>::decode(reader)?
137                    }
138                });
139                quote! {
140                    Ok(#name {
141                        #(#field_decodes,)*
142                    })
143                }
144            }
145            Fields::Unnamed(fields) => {
146                let field_decodes = fields.unnamed.iter().map(|field| {
147                    let field_ty = &field.ty;
148                    quote! {
149                        <#field_ty as justcode_core::Decode>::decode(reader)?
150                    }
151                });
152                quote! {
153                    Ok(#name(#(#field_decodes,)*))
154                }
155            }
156            Fields::Unit => {
157                quote! {
158                    Ok(#name)
159                }
160            }
161        },
162        Data::Enum(data_enum) => {
163            // Decode variant index first, then variant data
164            let variant_decodes = data_enum.variants.iter().enumerate().map(|(idx, variant)| {
165                let variant_idx = idx as u32;
166                let variant_name = &variant.ident;
167                match &variant.fields {
168                    Fields::Named(fields) => {
169                        let field_decodes = fields.named.iter().map(|field| {
170                            let field_name = &field.ident;
171                            let field_ty = &field.ty;
172                            quote! {
173                                #field_name: <#field_ty as justcode_core::Decode>::decode(reader)?
174                            }
175                        });
176                        quote! {
177                            #variant_idx => Ok(#name::#variant_name {
178                                #(#field_decodes,)*
179                            })
180                        }
181                    }
182                    Fields::Unnamed(fields) => {
183                        let field_decodes = fields.unnamed.iter().map(|field| {
184                            let field_ty = &field.ty;
185                            quote! {
186                                <#field_ty as justcode_core::Decode>::decode(reader)?
187                            }
188                        });
189                        quote! {
190                            #variant_idx => Ok(#name::#variant_name(
191                                #(#field_decodes,)*
192                            ))
193                        }
194                    }
195                    Fields::Unit => {
196                        quote! {
197                            #variant_idx => Ok(#name::#variant_name)
198                        }
199                    }
200                }
201            });
202            quote! {
203                use justcode_core::varint::decode_length;
204                let variant_idx = decode_length(reader, reader.config())? as u32;
205                match variant_idx {
206                    #(#variant_decodes,)*
207                    _ => {
208                        #[cfg(feature = "std")]
209                        {
210                            Err(justcode_core::error::JustcodeError::custom(format!("invalid variant index: {}", variant_idx)))
211                        }
212                        #[cfg(not(feature = "std"))]
213                        {
214                            extern crate alloc;
215                            use alloc::format;
216                            Err(justcode_core::error::JustcodeError::custom(format!("invalid variant index: {}", variant_idx)))
217                        }
218                    }
219                }
220            }
221        }
222        Data::Union(_) => {
223            return syn::Error::new_spanned(
224                name,
225                "justcode does not support decoding unions",
226            )
227            .to_compile_error()
228            .into();
229        }
230    };
231
232    let expanded = quote! {
233        impl #impl_generics justcode_core::Decode for #name #ty_generics #where_clause {
234            fn decode(reader: &mut justcode_core::reader::Reader) -> justcode_core::Result<Self> {
235                #decode_body
236            }
237        }
238    };
239
240    TokenStream::from(expanded)
241}
242