1use crate::parse::{GroupedVariants, ProcessedVariant};
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{Ident, Path};
5
6pub(crate) struct SerdeGenerator<'a> {
8 input_type_name: &'a Ident,
9 generics: &'a syn::Generics,
10 variants: &'a GroupedVariants,
11 alloy_consensus: &'a Path,
12 serde: TokenStream,
13 serde_json: TokenStream,
14 serde_cfg: &'a TokenStream,
15}
16
17impl<'a> SerdeGenerator<'a> {
18 pub(crate) fn new(
19 input_type_name: &'a Ident,
20 generics: &'a syn::Generics,
21 variants: &'a GroupedVariants,
22 alloy_consensus: &'a Path,
23 serde_cfg: &'a TokenStream,
24 ) -> Self {
25 let serde = quote! { #alloy_consensus::private::serde };
26 let serde_json = quote! { #alloy_consensus::private::serde_json };
27 Self { input_type_name, generics, variants, alloy_consensus, serde, serde_json, serde_cfg }
28 }
29
30 pub(crate) fn generate(&self) -> TokenStream {
32 let serde_bounds = self.generate_serde_bounds();
33 let serde_bounds_str = serde_bounds.to_string();
34
35 let tagged_enum = self.generate_tagged_enum(&serde_bounds_str);
36 let untagged_enum = self.generate_untagged_enum(&serde_bounds_str);
37 let impls = self.generate_serde_impls(&serde_bounds);
38
39 let serde_cfg = self.serde_cfg;
40
41 quote! {
42 #[cfg(#serde_cfg)]
43 const _: () = {
44 #tagged_enum
45 #untagged_enum
46 #impls
47 };
48 }
49 }
50
51 fn generate_serde_bounds(&self) -> TokenStream {
53 let input_type_name = self.input_type_name;
54 let (_, ty_generics, _) = self.generics.split_for_impl();
55 let variant_types = self.variants.all.iter().map(|v| &v.ty);
56 let serde = &self.serde;
57
58 quote! {
59 #input_type_name #ty_generics: Clone,
60 #(#variant_types: #serde::Serialize + #serde::de::DeserializeOwned),*
61 }
62 }
63
64 fn generate_tagged_enum(&self, serde_bounds_str: &str) -> TokenStream {
66 let generics = self.generics;
67 let serde = &self.serde;
68 let serde_str = serde.to_string();
69
70 let tagged_variants = self.generate_tagged_variants();
71 let from_tagged_impl = self.generate_from_tagged_impl();
72
73 quote! {
74 #[derive(Debug, #serde::Serialize, #serde::Deserialize)]
75 #[serde(tag = "type", bound = #serde_bounds_str, crate = #serde_str)]
76 enum TaggedTxTypes #generics {
77 #(#tagged_variants),*
78 }
79
80 #from_tagged_impl
81 }
82 }
83
84 fn generate_tagged_variants(&self) -> Vec<TokenStream> {
86 self.variants
87 .typed
88 .iter()
89 .map(|v| {
90 let ProcessedVariant { name, ty, kind, serde_attrs, typed: _, doc_attrs: _ } = v;
91
92 let (rename, aliases) = kind.serde_tag_and_aliases();
93
94 let maybe_with = if v.is_legacy() {
96 let alloy_consensus = &self.alloy_consensus;
97 let path = quote! {
98 #alloy_consensus::transaction::signed_legacy_serde
99 }
100 .to_string();
101 quote! { with = #path, }
102 } else {
103 quote! {}
104 };
105
106 let maybe_other = serde_attrs.clone().unwrap_or_default();
107
108 quote! {
109 #[serde(rename = #rename, #(alias = #aliases,)* #maybe_with #maybe_other)]
110 #name(#ty)
111 }
112 })
113 .collect()
114 }
115
116 fn generate_from_tagged_impl(&self) -> TokenStream {
118 let input_type_name = self.input_type_name;
119 let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
120 let unwrapped_generics = &self.generics.params;
121
122 let typed_names = self.variants.typed.iter().map(|v| &v.name).collect::<Vec<_>>();
123
124 quote! {
125 impl #impl_generics From<TaggedTxTypes #ty_generics> for #input_type_name #ty_generics {
126 fn from(value: TaggedTxTypes #ty_generics) -> Self {
127 match value {
128 #(
129 TaggedTxTypes::<#unwrapped_generics>::#typed_names(value) => Self::#typed_names(value),
130 )*
131 }
132 }
133 }
134 }
135 }
136
137 fn generate_untagged_enum(&self, serde_bounds_str: &str) -> TokenStream {
139 let generics = self.generics;
140 let serde = &self.serde;
141 let serde_str = serde.to_string();
142
143 let (legacy_variant, legacy_arm, legacy_deserialize) = self.generate_legacy_handling();
144 let untagged_variants = self.generate_untagged_variants(&legacy_variant);
145 let untagged_conversions = self.generate_untagged_conversions(&legacy_arm);
146 let deserialize_impl = self.generate_untagged_deserialize(&legacy_deserialize);
147
148 quote! {
149 #[derive(#serde::Serialize)]
150 #[serde(untagged, bound = #serde_bounds_str, crate = #serde_str)]
151 pub(crate) enum UntaggedTxTypes #generics {
152 Tagged(TaggedTxTypes #generics),
153 #untagged_variants
154 }
155
156 #deserialize_impl
157 #untagged_conversions
158 }
159 }
160
161 fn generate_untagged_variants(&self, legacy_variant: &TokenStream) -> TokenStream {
163 let flattened_variants = self.variants.flattened.iter().map(|v| {
164 let name = &v.name;
165 let ty = &v.ty;
166
167 let maybe_attributes = if let Some(attrs) = &v.serde_attrs {
168 quote! { #[serde(#attrs)] }
169 } else {
170 quote! {}
171 };
172
173 quote! { #maybe_attributes #name(#ty) }
174 });
175
176 quote! {
177 #(#flattened_variants,)*
178 #legacy_variant
179 }
180 }
181
182 fn generate_legacy_handling(&self) -> (TokenStream, TokenStream, TokenStream) {
184 if let Some(legacy) = self.variants.legacy_variant() {
185 let ty = &legacy.ty;
186 let name = &legacy.name;
187 let alloy_consensus = self.alloy_consensus;
188
189 let variant = quote! { UntaggedLegacy(#ty) };
190 let arm = quote! { UntaggedTxTypes::UntaggedLegacy(tx) => Self::#name(tx), };
191 let deserialize = quote! {
192 if let Ok(val) = #alloy_consensus::transaction::untagged_legacy_serde::deserialize(deserializer).map(Self::UntaggedLegacy) {
193 return Ok(val);
194 }
195 };
196
197 (variant, arm, deserialize)
198 } else {
199 (quote! {}, quote! {}, quote! {})
200 }
201 }
202
203 fn generate_untagged_deserialize(&self, legacy_deserialize: &TokenStream) -> TokenStream {
205 let generics = self.generics;
206 let unwrapped_generics = &generics.params;
207 let serde = &self.serde;
208 let serde_json = &self.serde_json;
209 let serde_bounds = self.generate_serde_bounds();
210
211 let flattened_names = self.variants.flattened.iter().map(|v| &v.name);
212
213 quote! {
214 impl<'de, #unwrapped_generics> #serde::Deserialize<'de> for UntaggedTxTypes #generics where #serde_bounds {
217 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
218 where
219 D: #serde::Deserializer<'de>,
220 {
221 let value = #serde_json::Value::deserialize(deserializer)?;
222 let deserializer = &value;
223
224 let tagged_res =
225 TaggedTxTypes::<#unwrapped_generics>::deserialize(deserializer).map(Self::Tagged).map_err(#serde::de::Error::custom);
226
227 if tagged_res.is_ok() {
228 return tagged_res;
230 }
231
232 #(
234 if let Ok(val) = #serde::Deserialize::deserialize(deserializer).map(Self::#flattened_names) {
235 return Ok(val);
236 }
237 )*
238
239 #legacy_deserialize
240
241 tagged_res
244 }
245 }
246 }
247 }
248
249 fn generate_untagged_conversions(&self, legacy_arm: &TokenStream) -> TokenStream {
251 let input_type_name = self.input_type_name;
252 let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
253 let unwrapped_generics = &self.generics.params;
254 let flattened_names = self.variants.flattened.iter().map(|v| &v.name).collect::<Vec<_>>();
255 let typed_names = self.variants.typed.iter().map(|v| &v.name).collect::<Vec<_>>();
256
257 quote! {
258 impl #impl_generics From<UntaggedTxTypes #ty_generics> for #input_type_name #ty_generics {
259 fn from(value: UntaggedTxTypes #ty_generics) -> Self {
260 match value {
261 UntaggedTxTypes::Tagged(value) => value.into(),
262 #(
263 UntaggedTxTypes::#flattened_names(value) => Self::#flattened_names(value),
264 )*
265 #legacy_arm
266 }
267 }
268 }
269
270 impl #impl_generics From<#input_type_name #ty_generics> for UntaggedTxTypes #ty_generics {
271 fn from(value: #input_type_name #ty_generics) -> Self {
272 match value {
273 #(
274 #input_type_name::<#unwrapped_generics>::#flattened_names(value) => Self::#flattened_names(value),
275 )*
276 #(
277 #input_type_name::<#unwrapped_generics>::#typed_names(value) => Self::Tagged(TaggedTxTypes::#typed_names(value)),
278 )*
279 }
280 }
281 }
282 }
283 }
284
285 fn generate_serde_impls(&self, serde_bounds: &TokenStream) -> TokenStream {
287 let input_type_name = self.input_type_name;
288 let serde = &self.serde;
289 let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
290 let unwrapped_generics = &self.generics.params;
291
292 quote! {
293 impl #impl_generics #serde::Serialize for #input_type_name #ty_generics where #serde_bounds {
294 fn serialize<S: #serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
295 UntaggedTxTypes::<#unwrapped_generics>::from(self.clone()).serialize(serializer)
296 }
297 }
298
299 impl <'de, #unwrapped_generics> #serde::Deserialize<'de> for #input_type_name #ty_generics where #serde_bounds {
300 fn deserialize<D: #serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
301 UntaggedTxTypes::<#unwrapped_generics>::deserialize(deserializer).map(Into::into)
302 }
303 }
304 }
305 }
306}