alloy_tx_macros/
serde.rs

1use crate::parse::{GroupedVariants, ProcessedVariant};
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{Ident, Path};
5
6/// Generate serde implementations for the transaction envelope.
7pub(crate) struct SerdeGenerator<'a> {
8    input_type_name: &'a Ident,
9    generics: &'a syn::Generics,
10    variants: &'a GroupedVariants,
11    alloy_consensus: &'a Path,
12    serde: TokenStream,
13    serde_json: TokenStream,
14    serde_cfg: &'a TokenStream,
15}
16
17impl<'a> SerdeGenerator<'a> {
18    pub(crate) fn new(
19        input_type_name: &'a Ident,
20        generics: &'a syn::Generics,
21        variants: &'a GroupedVariants,
22        alloy_consensus: &'a Path,
23        serde_cfg: &'a TokenStream,
24    ) -> Self {
25        let serde = quote! { #alloy_consensus::private::serde };
26        let serde_json = quote! { #alloy_consensus::private::serde_json };
27        Self { input_type_name, generics, variants, alloy_consensus, serde, serde_json, serde_cfg }
28    }
29
30    /// Generate all serde-related code.
31    pub(crate) fn generate(&self) -> TokenStream {
32        let serde_bounds = self.generate_serde_bounds();
33        let serde_bounds_str = serde_bounds.to_string();
34
35        let tagged_enum = self.generate_tagged_enum(&serde_bounds_str);
36        let untagged_enum = self.generate_untagged_enum(&serde_bounds_str);
37        let impls = self.generate_serde_impls(&serde_bounds);
38
39        let serde_cfg = self.serde_cfg;
40
41        quote! {
42            #[cfg(#serde_cfg)]
43            const _: () = {
44                #tagged_enum
45                #untagged_enum
46                #impls
47            };
48        }
49    }
50
51    /// Generate serde bounds.
52    fn generate_serde_bounds(&self) -> TokenStream {
53        let input_type_name = self.input_type_name;
54        let (_, ty_generics, _) = self.generics.split_for_impl();
55        let variant_types = self.variants.all.iter().map(|v| &v.ty);
56        let serde = &self.serde;
57
58        quote! {
59            #input_type_name #ty_generics: Clone,
60            #(#variant_types: #serde::Serialize + #serde::de::DeserializeOwned),*
61        }
62    }
63
64    /// Generate the tagged transaction types enum.
65    fn generate_tagged_enum(&self, serde_bounds_str: &str) -> TokenStream {
66        let generics = self.generics;
67        let serde = &self.serde;
68        let serde_str = serde.to_string();
69
70        let tagged_variants = self.generate_tagged_variants();
71        let from_tagged_impl = self.generate_from_tagged_impl();
72
73        quote! {
74            #[derive(Debug, #serde::Serialize, #serde::Deserialize)]
75            #[serde(tag = "type", bound = #serde_bounds_str, crate = #serde_str)]
76            enum TaggedTxTypes #generics {
77                #(#tagged_variants),*
78            }
79
80            #from_tagged_impl
81        }
82    }
83
84    /// Generate tagged variants for serde.
85    fn generate_tagged_variants(&self) -> Vec<TokenStream> {
86        self.variants
87            .typed
88            .iter()
89            .map(|v| {
90                let ProcessedVariant { name, ty, kind, serde_attrs, typed: _, doc_attrs: _ } = v;
91
92                let (rename, aliases) = kind.serde_tag_and_aliases();
93
94                // Special handling for legacy transactions
95                let maybe_with = if v.is_legacy() {
96                    let alloy_consensus = &self.alloy_consensus;
97                    let path = quote! {
98                        #alloy_consensus::transaction::signed_legacy_serde
99                    }
100                    .to_string();
101                    quote! { with = #path, }
102                } else {
103                    quote! {}
104                };
105
106                let maybe_other = serde_attrs.clone().unwrap_or_default();
107
108                quote! {
109                    #[serde(rename = #rename, #(alias = #aliases,)* #maybe_with #maybe_other)]
110                    #name(#ty)
111                }
112            })
113            .collect()
114    }
115
116    /// Generate From implementation for tagged types.
117    fn generate_from_tagged_impl(&self) -> TokenStream {
118        let input_type_name = self.input_type_name;
119        let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
120        let unwrapped_generics = &self.generics.params;
121
122        let typed_names = self.variants.typed.iter().map(|v| &v.name).collect::<Vec<_>>();
123
124        quote! {
125            impl #impl_generics From<TaggedTxTypes #ty_generics> for #input_type_name #ty_generics {
126                fn from(value: TaggedTxTypes #ty_generics) -> Self {
127                    match value {
128                        #(
129                            TaggedTxTypes::<#unwrapped_generics>::#typed_names(value) => Self::#typed_names(value),
130                        )*
131                    }
132                }
133            }
134        }
135    }
136
137    /// Generate the untagged transaction types enum.
138    fn generate_untagged_enum(&self, serde_bounds_str: &str) -> TokenStream {
139        let generics = self.generics;
140        let serde = &self.serde;
141        let serde_str = serde.to_string();
142
143        let (legacy_variant, legacy_arm, legacy_deserialize) = self.generate_legacy_handling();
144        let untagged_variants = self.generate_untagged_variants(&legacy_variant);
145        let untagged_conversions = self.generate_untagged_conversions(&legacy_arm);
146        let deserialize_impl = self.generate_untagged_deserialize(&legacy_deserialize);
147
148        quote! {
149            #[derive(#serde::Serialize)]
150            #[serde(untagged, bound = #serde_bounds_str, crate = #serde_str)]
151            pub(crate) enum UntaggedTxTypes #generics {
152                Tagged(TaggedTxTypes #generics),
153                #untagged_variants
154            }
155
156            #deserialize_impl
157            #untagged_conversions
158        }
159    }
160
161    /// Generate untagged variants. This includes flattened envelopes and legacy transactions.
162    fn generate_untagged_variants(&self, legacy_variant: &TokenStream) -> TokenStream {
163        let flattened_variants = self.variants.flattened.iter().map(|v| {
164            let name = &v.name;
165            let ty = &v.ty;
166
167            let maybe_attributes = if let Some(attrs) = &v.serde_attrs {
168                quote! { #[serde(#attrs)] }
169            } else {
170                quote! {}
171            };
172
173            quote! { #maybe_attributes #name(#ty) }
174        });
175
176        quote! {
177            #(#flattened_variants,)*
178            #legacy_variant
179        }
180    }
181
182    /// Generate legacy transaction handling for serde.
183    fn generate_legacy_handling(&self) -> (TokenStream, TokenStream, TokenStream) {
184        if let Some(legacy) = self.variants.legacy_variant() {
185            let ty = &legacy.ty;
186            let name = &legacy.name;
187            let alloy_consensus = self.alloy_consensus;
188
189            let variant = quote! { UntaggedLegacy(#ty) };
190            let arm = quote! { UntaggedTxTypes::UntaggedLegacy(tx) => Self::#name(tx), };
191            let deserialize = quote! {
192                if let Ok(val) = #alloy_consensus::transaction::untagged_legacy_serde::deserialize(deserializer).map(Self::UntaggedLegacy) {
193                    return Ok(val);
194                }
195            };
196
197            (variant, arm, deserialize)
198        } else {
199            (quote! {}, quote! {}, quote! {})
200        }
201    }
202
203    /// Generate custom deserialize implementation for untagged types.
204    fn generate_untagged_deserialize(&self, legacy_deserialize: &TokenStream) -> TokenStream {
205        let generics = self.generics;
206        let unwrapped_generics = &generics.params;
207        let serde = &self.serde;
208        let serde_json = &self.serde_json;
209        let serde_bounds = self.generate_serde_bounds();
210
211        let flattened_names = self.variants.flattened.iter().map(|v| &v.name);
212
213        quote! {
214            // Manually modified derived serde(untagged) to preserve the error of the TaggedTxEnvelope
215            // attempt. Note: This uses private serde API
216            impl<'de, #unwrapped_generics> #serde::Deserialize<'de> for UntaggedTxTypes #generics where #serde_bounds {
217                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
218                where
219                    D: #serde::Deserializer<'de>,
220                {
221                    let value = #serde_json::Value::deserialize(deserializer)?;
222                    let deserializer = &value;
223
224                    let tagged_res =
225                        TaggedTxTypes::<#unwrapped_generics>::deserialize(deserializer).map(Self::Tagged).map_err(#serde::de::Error::custom);
226
227                    if tagged_res.is_ok() {
228                        // return tagged if successful
229                        return tagged_res;
230                    }
231
232                    // proceed with flattened variants
233                    #(
234                        if let Ok(val) = #serde::Deserialize::deserialize(deserializer).map(Self::#flattened_names) {
235                            return Ok(val);
236                        }
237                    )*
238
239                    #legacy_deserialize
240
241                    // return the original error, which is more useful than the untagged error
242                    //  > "data did not match any variant of untagged enum MaybeTaggedTxEnvelope"
243                    tagged_res
244                }
245            }
246        }
247    }
248
249    /// Generate conversion implementations for untagged types.
250    fn generate_untagged_conversions(&self, legacy_arm: &TokenStream) -> TokenStream {
251        let input_type_name = self.input_type_name;
252        let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
253        let unwrapped_generics = &self.generics.params;
254        let flattened_names = self.variants.flattened.iter().map(|v| &v.name).collect::<Vec<_>>();
255        let typed_names = self.variants.typed.iter().map(|v| &v.name).collect::<Vec<_>>();
256
257        quote! {
258            impl #impl_generics From<UntaggedTxTypes #ty_generics> for #input_type_name #ty_generics {
259                fn from(value: UntaggedTxTypes #ty_generics) -> Self {
260                    match value {
261                        UntaggedTxTypes::Tagged(value) => value.into(),
262                        #(
263                            UntaggedTxTypes::#flattened_names(value) => Self::#flattened_names(value),
264                        )*
265                        #legacy_arm
266                    }
267                }
268            }
269
270            impl #impl_generics From<#input_type_name #ty_generics> for UntaggedTxTypes #ty_generics {
271                fn from(value: #input_type_name #ty_generics) -> Self {
272                    match value {
273                        #(
274                            #input_type_name::<#unwrapped_generics>::#flattened_names(value) => Self::#flattened_names(value),
275                        )*
276                        #(
277                            #input_type_name::<#unwrapped_generics>::#typed_names(value) => Self::Tagged(TaggedTxTypes::#typed_names(value)),
278                        )*
279                    }
280                }
281            }
282        }
283    }
284
285    /// Generate Deserialize implementation.
286    fn generate_serde_impls(&self, serde_bounds: &TokenStream) -> TokenStream {
287        let input_type_name = self.input_type_name;
288        let serde = &self.serde;
289        let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
290        let unwrapped_generics = &self.generics.params;
291
292        quote! {
293            impl #impl_generics #serde::Serialize for #input_type_name #ty_generics where #serde_bounds {
294                fn serialize<S: #serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
295                    UntaggedTxTypes::<#unwrapped_generics>::from(self.clone()).serialize(serializer)
296                }
297            }
298
299            impl <'de, #unwrapped_generics> #serde::Deserialize<'de> for #input_type_name #ty_generics where #serde_bounds {
300                fn deserialize<D: #serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
301                    UntaggedTxTypes::<#unwrapped_generics>::deserialize(deserializer).map(Into::into)
302                }
303            }
304        }
305    }
306}