1use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7#[proc_macro_derive(Encode)]
9pub fn derive_encode(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 let name = &input.ident;
12 let generics = &input.generics;
13 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
14
15 let encode_body = match &input.data {
16 Data::Struct(data_struct) => match &data_struct.fields {
17 Fields::Named(fields) => {
18 let field_encodes = fields.named.iter().map(|field| {
19 let field_name = &field.ident;
20 quote! {
21 self.#field_name.encode(writer)?;
22 }
23 });
24 quote! {
25 #(#field_encodes)*
26 }
27 }
28 Fields::Unnamed(fields) => {
29 let field_encodes = fields.unnamed.iter().enumerate().map(|(i, _)| {
30 let index = syn::Index::from(i);
31 quote! {
32 self.#index.encode(writer)?;
33 }
34 });
35 quote! {
36 #(#field_encodes)*
37 }
38 }
39 Fields::Unit => {
40 quote! {}
41 }
42 },
43 Data::Enum(data_enum) => {
44 let variant_encodes = data_enum.variants.iter().enumerate().map(|(idx, variant)| {
46 let variant_idx = idx as u32;
47 let variant_name = &variant.ident;
48 match &variant.fields {
49 Fields::Named(fields) => {
50 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
51 let field_encodes = field_names.iter().map(|field_name| {
52 quote! {
53 #field_name.encode(writer)?;
54 }
55 });
56 quote! {
57 #name::#variant_name { #(#field_names,)* } => {
58 use justcode_core::varint::encode_length;
59 encode_length(writer, #variant_idx as usize, writer.config())?;
60 #(#field_encodes)*
61 }
62 }
63 }
64 Fields::Unnamed(fields) => {
65 let field_count = fields.unnamed.len();
66 let field_indices: Vec<_> = (0..field_count)
67 .map(|i| syn::Index::from(i))
68 .collect();
69 let field_encodes = field_indices.iter().map(|index| {
70 quote! {
71 #index.encode(writer)?;
72 }
73 });
74 let field_patterns = field_indices.clone();
75 quote! {
76 #name::#variant_name(#(#field_patterns,)*) => {
77 use justcode_core::varint::encode_length;
78 encode_length(writer, #variant_idx as usize, writer.config())?;
79 #(#field_encodes)*
80 }
81 }
82 }
83 Fields::Unit => {
84 quote! {
85 #name::#variant_name => {
86 use justcode_core::varint::encode_length;
87 encode_length(writer, #variant_idx as usize, writer.config())?;
88 }
89 }
90 }
91 }
92 });
93 quote! {
94 match self {
95 #(#variant_encodes)*
96 }
97 }
98 }
99 Data::Union(_) => {
100 return syn::Error::new_spanned(
101 name,
102 "justcode does not support encoding unions",
103 )
104 .to_compile_error()
105 .into();
106 }
107 };
108
109 let expanded = quote! {
110 impl #impl_generics justcode_core::Encode for #name #ty_generics #where_clause {
111 fn encode(&self, writer: &mut justcode_core::writer::Writer) -> justcode_core::Result<()> {
112 #encode_body
113 Ok(())
114 }
115 }
116 };
117
118 TokenStream::from(expanded)
119}
120
121#[proc_macro_derive(Decode)]
123pub fn derive_decode(input: TokenStream) -> TokenStream {
124 let input = parse_macro_input!(input as DeriveInput);
125 let name = &input.ident;
126 let generics = &input.generics;
127 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
128
129 let decode_body = match &input.data {
130 Data::Struct(data_struct) => match &data_struct.fields {
131 Fields::Named(fields) => {
132 let field_decodes = fields.named.iter().map(|field| {
133 let field_name = &field.ident;
134 let field_ty = &field.ty;
135 quote! {
136 #field_name: <#field_ty as justcode_core::Decode>::decode(reader)?
137 }
138 });
139 quote! {
140 Ok(#name {
141 #(#field_decodes,)*
142 })
143 }
144 }
145 Fields::Unnamed(fields) => {
146 let field_decodes = fields.unnamed.iter().map(|field| {
147 let field_ty = &field.ty;
148 quote! {
149 <#field_ty as justcode_core::Decode>::decode(reader)?
150 }
151 });
152 quote! {
153 Ok(#name(#(#field_decodes,)*))
154 }
155 }
156 Fields::Unit => {
157 quote! {
158 Ok(#name)
159 }
160 }
161 },
162 Data::Enum(data_enum) => {
163 let variant_decodes = data_enum.variants.iter().enumerate().map(|(idx, variant)| {
165 let variant_idx = idx as u32;
166 let variant_name = &variant.ident;
167 match &variant.fields {
168 Fields::Named(fields) => {
169 let field_decodes = fields.named.iter().map(|field| {
170 let field_name = &field.ident;
171 let field_ty = &field.ty;
172 quote! {
173 #field_name: <#field_ty as justcode_core::Decode>::decode(reader)?
174 }
175 });
176 quote! {
177 #variant_idx => Ok(#name::#variant_name {
178 #(#field_decodes,)*
179 })
180 }
181 }
182 Fields::Unnamed(fields) => {
183 let field_decodes = fields.unnamed.iter().map(|field| {
184 let field_ty = &field.ty;
185 quote! {
186 <#field_ty as justcode_core::Decode>::decode(reader)?
187 }
188 });
189 quote! {
190 #variant_idx => Ok(#name::#variant_name(
191 #(#field_decodes,)*
192 ))
193 }
194 }
195 Fields::Unit => {
196 quote! {
197 #variant_idx => Ok(#name::#variant_name)
198 }
199 }
200 }
201 });
202 quote! {
203 use justcode_core::varint::decode_length;
204 let variant_idx = decode_length(reader, reader.config())? as u32;
205 match variant_idx {
206 #(#variant_decodes,)*
207 _ => {
208 #[cfg(feature = "std")]
209 {
210 Err(justcode_core::error::JustcodeError::custom(format!("invalid variant index: {}", variant_idx)))
211 }
212 #[cfg(not(feature = "std"))]
213 {
214 extern crate alloc;
215 use alloc::format;
216 Err(justcode_core::error::JustcodeError::custom(format!("invalid variant index: {}", variant_idx)))
217 }
218 }
219 }
220 }
221 }
222 Data::Union(_) => {
223 return syn::Error::new_spanned(
224 name,
225 "justcode does not support decoding unions",
226 )
227 .to_compile_error()
228 .into();
229 }
230 };
231
232 let expanded = quote! {
233 impl #impl_generics justcode_core::Decode for #name #ty_generics #where_clause {
234 fn decode(reader: &mut justcode_core::reader::Reader) -> justcode_core::Result<Self> {
235 #decode_body
236 }
237 }
238 };
239
240 TokenStream::from(expanded)
241}
242