1use amplify::proc_attr::ParametrizedAttr;
16use proc_macro2::{Span, TokenStream as TokenStream2};
17use quote::{ToTokens, TokenStreamExt};
18use syn::spanned::Spanned;
19use syn::{
20 Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident,
21 ImplGenerics, Index, Result, TypeGenerics, WhereClause,
22};
23
24use crate::param::{EncodingDerive, TlvDerive, CRATE, REPR, USE_TLV};
25
26pub fn encode_derive(
32 attr_name: &'static str,
33 crate_name: Ident,
34 trait_name: Ident,
35 encode_name: Ident,
36 serialize_name: Ident,
37 input: DeriveInput,
38 tlv_encoding: bool,
39) -> Result<TokenStream2> {
40 let (impl_generics, ty_generics, where_clause) =
41 input.generics.split_for_impl();
42 let ident_name = &input.ident;
43
44 let global_param = ParametrizedAttr::with(attr_name, &input.attrs)?;
45
46 match input.data {
47 Data::Struct(data) => encode_struct_impl(
48 attr_name,
49 &crate_name,
50 &trait_name,
51 &encode_name,
52 &serialize_name,
53 data,
54 ident_name,
55 global_param,
56 impl_generics,
57 ty_generics,
58 where_clause,
59 tlv_encoding,
60 ),
61 Data::Enum(data) => encode_enum_impl(
62 attr_name,
63 &crate_name,
64 &trait_name,
65 &encode_name,
66 &serialize_name,
67 data,
68 ident_name,
69 global_param,
70 impl_generics,
71 ty_generics,
72 where_clause,
73 ),
74 Data::Union(_) => Err(Error::new_spanned(
75 &input,
76 format!("Deriving `{}` is not supported in unions", trait_name),
77 )),
78 }
79}
80
81#[allow(clippy::too_many_arguments)]
82fn encode_struct_impl(
83 attr_name: &'static str,
84 crate_name: &Ident,
85 trait_name: &Ident,
86 encode_name: &Ident,
87 serialize_name: &Ident,
88 data: DataStruct,
89 ident_name: &Ident,
90 mut global_param: ParametrizedAttr,
91 impl_generics: ImplGenerics,
92 ty_generics: TypeGenerics,
93 where_clause: Option<&WhereClause>,
94 tlv_encoding: bool,
95) -> Result<TokenStream2> {
96 let encoding = EncodingDerive::with(
97 &mut global_param,
98 crate_name,
99 true,
100 false,
101 false,
102 )?;
103
104 if !tlv_encoding && encoding.tlv.is_some() {
105 return Err(Error::new(
106 ident_name.span(),
107 format!("TLV extensions are not allowed in `{}`", attr_name),
108 ));
109 }
110
111 let inner_impl = match data.fields {
112 Fields::Named(ref fields) => encode_fields_impl(
113 attr_name,
114 crate_name,
115 trait_name,
116 encode_name,
117 serialize_name,
118 &fields.named,
119 global_param,
120 false,
121 tlv_encoding,
122 )?,
123 Fields::Unnamed(ref fields) => encode_fields_impl(
124 attr_name,
125 crate_name,
126 trait_name,
127 encode_name,
128 serialize_name,
129 &fields.unnamed,
130 global_param,
131 false,
132 tlv_encoding,
133 )?,
134 Fields::Unit => quote! { Ok(0) },
135 };
136
137 let import = encoding.use_crate;
138
139 Ok(quote! {
140 impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
141 fn #encode_name<E: ::std::io::Write>(&self, mut e: E) -> ::core::result::Result<usize, #import::Error> {
142 use #import::#trait_name;
143 let mut len = 0;
144 let data = self;
145 #inner_impl
146 Ok(len)
147 }
148 }
149 })
150}
151
152#[allow(clippy::too_many_arguments)]
153fn encode_enum_impl(
154 attr_name: &'static str,
155 crate_name: &Ident,
156 trait_name: &Ident,
157 encode_name: &Ident,
158 serialize_name: &Ident,
159 data: DataEnum,
160 ident_name: &Ident,
161 mut global_param: ParametrizedAttr,
162 impl_generics: ImplGenerics,
163 ty_generics: TypeGenerics,
164 where_clause: Option<&WhereClause>,
165) -> Result<TokenStream2> {
166 let encoding =
167 EncodingDerive::with(&mut global_param, crate_name, true, true, false)?;
168 let repr = encoding.repr;
169
170 let mut inner_impl = TokenStream2::new();
171
172 for (order, variant) in data.variants.iter().enumerate() {
173 let mut local_param =
174 ParametrizedAttr::with(attr_name, &variant.attrs)?;
175
176 let _ = EncodingDerive::with(
178 &mut local_param,
179 crate_name,
180 false,
181 true,
182 false,
183 )?;
184 let mut combined = global_param.clone().merged(local_param.clone())?;
186 combined.args.remove(REPR);
187 combined.args.remove(CRATE);
188 let encoding = EncodingDerive::with(
189 &mut combined,
190 crate_name,
191 false,
192 true,
193 false,
194 )?;
195
196 if encoding.skip {
197 continue;
198 }
199
200 let captures = variant
201 .fields
202 .iter()
203 .enumerate()
204 .map(|(i, f)| {
205 f.ident.as_ref().map(Ident::to_token_stream).unwrap_or_else(
206 || {
207 Ident::new(&format!("_{}", i), Span::call_site())
208 .to_token_stream()
209 },
210 )
211 })
212 .collect::<Vec<_>>();
213
214 let (field_impl, bra_captures_ket) = match variant.fields {
215 Fields::Named(ref fields) => (
216 encode_fields_impl(
217 attr_name,
218 crate_name,
219 trait_name,
220 encode_name,
221 serialize_name,
222 &fields.named,
223 local_param,
224 true,
225 false,
226 )?,
227 quote! { { #( #captures ),* } },
228 ),
229 Fields::Unnamed(ref fields) => (
230 encode_fields_impl(
231 attr_name,
232 crate_name,
233 trait_name,
234 encode_name,
235 serialize_name,
236 &fields.unnamed,
237 local_param,
238 true,
239 false,
240 )?,
241 quote! { ( #( #captures ),* ) },
242 ),
243 Fields::Unit => (TokenStream2::new(), TokenStream2::new()),
244 };
245
246 let captures = match captures.len() {
247 0 => quote! {},
248 _ => quote! { let data = ( #( #captures ),* , ); },
249 };
250
251 let ident = &variant.ident;
252 let value = match (encoding.value, encoding.by_order) {
253 (Some(val), _) => val.to_token_stream(),
254 (None, true) => Index::from(order as usize).to_token_stream(),
255 (None, false) => quote! { Self::#ident },
256 };
257
258 inner_impl.append_all(quote_spanned! { variant.span() =>
259 Self::#ident #bra_captures_ket => {
260 len += (#value as #repr).#encode_name(&mut e)?;
261 #captures
262 #field_impl
263 }
264 });
265 }
266
267 let import = encoding.use_crate;
268
269 Ok(quote! {
270 impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
271 #[inline]
272 fn #encode_name<E: ::std::io::Write>(&self, mut e: E) -> ::core::result::Result<usize, #import::Error> {
273 use #import::#trait_name;
274 let mut len = 0;
275 match self {
276 #inner_impl
277 }
278 Ok(len)
279 }
280 }
281 })
282}
283
284#[allow(clippy::too_many_arguments)]
285fn encode_fields_impl<'a>(
286 attr_name: &'static str,
287 crate_name: &Ident,
288 _trait_name: &Ident,
289 encode_name: &Ident,
290 serialize_name: &Ident,
291 fields: impl IntoIterator<Item = &'a Field>,
292 mut parent_param: ParametrizedAttr,
293 is_enum: bool,
294 tlv_encoding: bool,
295) -> Result<TokenStream2> {
296 let mut stream = TokenStream2::new();
297
298 let use_tlv = parent_param.args.contains_key(USE_TLV);
299 parent_param.args.remove(CRATE);
300 parent_param.args.remove(USE_TLV);
301
302 if !tlv_encoding && use_tlv {
303 return Err(Error::new(
304 Span::call_site(),
305 format!("TLV extensions are not allowed in `{}`", attr_name),
306 ));
307 }
308
309 let mut strict_fields = vec![];
310 let mut tlv_fields = bmap! {};
311 let mut tlv_aggregator = None;
312
313 for (index, field) in fields.into_iter().enumerate() {
314 let mut local_param = ParametrizedAttr::with(attr_name, &field.attrs)?;
315
316 let _ = EncodingDerive::with(
318 &mut local_param,
319 crate_name,
320 false,
321 is_enum,
322 use_tlv,
323 )?;
324 let mut combined = parent_param.clone().merged(local_param)?;
326 let encoding = EncodingDerive::with(
327 &mut combined,
328 crate_name,
329 false,
330 is_enum,
331 use_tlv,
332 )?;
333
334 if encoding.skip {
335 continue;
336 }
337
338 let index = Index::from(index).to_token_stream();
339 let name = if is_enum {
340 index
341 } else {
342 field
343 .ident
344 .as_ref()
345 .map(Ident::to_token_stream)
346 .unwrap_or(index)
347 };
348
349 encoding.tlv.unwrap_or(TlvDerive::None).process(
350 field,
351 name,
352 &mut strict_fields,
353 &mut tlv_fields,
354 &mut tlv_aggregator,
355 )?;
356 }
357
358 for name in strict_fields {
359 stream.append_all(quote_spanned! { Span::call_site() =>
360 len += data.#name.#encode_name(&mut e)?;
361 })
362 }
363
364 if use_tlv {
365 stream.append_all(quote_spanned! { Span::call_site() =>
366 let mut tlvs = internet2::tlv::Stream::default();
367 });
368 for (type_no, (name, optional)) in tlv_fields {
369 if optional {
370 stream.append_all(quote_spanned! { Span::call_site() =>
371 if let Some(val) = &data.#name {
372 tlvs.insert(#type_no.into(), val.#serialize_name()?);
373 }
374 });
375 } else {
376 stream.append_all(quote_spanned! { Span::call_site() =>
377 if data.#name.iter().count() > 0 {
378 tlvs.insert(#type_no.into(), data.#name.#serialize_name()?);
379 }
380 });
381 }
382 }
383 if let Some(name) = tlv_aggregator {
384 stream.append_all(quote_spanned! { Span::call_site() =>
385 for (type_no, val) in &data.#name {
386 tlvs.insert(*type_no, val);
387 }
388 });
389 }
390
391 stream.append_all(quote_spanned! { Span::call_site() =>
392 len += tlvs.#encode_name(&mut e)?;
393 })
394 }
395
396 Ok(stream)
397}