1use amplify::proc_attr::ParametrizedAttr;
16use proc_macro2::{Span, TokenStream as TokenStream2};
17use quote::{ToTokens, TokenStreamExt};
18use syn::spanned::Spanned;
19use syn::{
20 Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident,
21 ImplGenerics, Index, LitStr, Result, TypeGenerics, WhereClause,
22};
23
24use crate::param::{EncodingDerive, TlvDerive, CRATE, REPR, USE_TLV};
25
26pub fn decode_derive(
32 attr_name: &'static str,
33 crate_name: Ident,
34 trait_name: Ident,
35 decode_name: Ident,
36 deserialize_name: Ident,
37 input: DeriveInput,
38 tlv_encoding: bool,
39) -> Result<TokenStream2> {
40 let (impl_generics, ty_generics, where_clause) =
41 input.generics.split_for_impl();
42 let ident_name = &input.ident;
43
44 let global_param = ParametrizedAttr::with(attr_name, &input.attrs)?;
45
46 match input.data {
47 Data::Struct(data) => decode_struct_impl(
48 attr_name,
49 &crate_name,
50 &trait_name,
51 &decode_name,
52 &deserialize_name,
53 data,
54 ident_name,
55 global_param,
56 impl_generics,
57 ty_generics,
58 where_clause,
59 tlv_encoding,
60 ),
61 Data::Enum(data) => decode_enum_impl(
62 attr_name,
63 &crate_name,
64 &trait_name,
65 &decode_name,
66 &deserialize_name,
67 data,
68 ident_name,
69 global_param,
70 impl_generics,
71 ty_generics,
72 where_clause,
73 ),
74 Data::Union(_) => Err(Error::new_spanned(
75 &input,
76 format!("Deriving `{}` is not supported in unions", trait_name),
77 )),
78 }
79}
80
81#[allow(clippy::too_many_arguments)]
82fn decode_struct_impl(
83 attr_name: &'static str,
84 crate_name: &Ident,
85 trait_name: &Ident,
86 decode_name: &Ident,
87 deserialize_name: &Ident,
88 data: DataStruct,
89 ident_name: &Ident,
90 mut global_param: ParametrizedAttr,
91 impl_generics: ImplGenerics,
92 ty_generics: TypeGenerics,
93 where_clause: Option<&WhereClause>,
94 tlv_encoding: bool,
95) -> Result<TokenStream2> {
96 let encoding = EncodingDerive::with(
97 &mut global_param,
98 crate_name,
99 true,
100 false,
101 false,
102 )?;
103
104 if !tlv_encoding && encoding.tlv.is_some() {
105 return Err(Error::new(
106 ident_name.span(),
107 format!("TLV extensions are not allowed in `{}`", attr_name),
108 ));
109 }
110
111 let inner_impl = match data.fields {
112 Fields::Named(ref fields) => decode_fields_impl(
113 attr_name,
114 crate_name,
115 trait_name,
116 decode_name,
117 deserialize_name,
118 ident_name,
119 &fields.named,
120 global_param,
121 false,
122 tlv_encoding,
123 )?,
124 Fields::Unnamed(ref fields) => decode_fields_impl(
125 attr_name,
126 crate_name,
127 trait_name,
128 decode_name,
129 deserialize_name,
130 ident_name,
131 &fields.unnamed,
132 global_param,
133 false,
134 tlv_encoding,
135 )?,
136 Fields::Unit => quote! {},
137 };
138
139 let import = encoding.use_crate;
140
141 Ok(quote! {
142 impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
143 #[inline]
144 fn #decode_name<D: ::std::io::Read>(mut d: D) -> ::core::result::Result<Self, #import::Error> {
145 use #import::#trait_name;
146 #inner_impl
147 }
148 }
149 })
150}
151
152#[allow(clippy::too_many_arguments)]
153fn decode_enum_impl(
154 attr_name: &'static str,
155 crate_name: &Ident,
156 trait_name: &Ident,
157 decode_name: &Ident,
158 deserialize_name: &Ident,
159 data: DataEnum,
160 ident_name: &Ident,
161 mut global_param: ParametrizedAttr,
162 impl_generics: ImplGenerics,
163 ty_generics: TypeGenerics,
164 where_clause: Option<&WhereClause>,
165) -> Result<TokenStream2> {
166 let encoding =
167 EncodingDerive::with(&mut global_param, crate_name, true, true, false)?;
168 let repr = encoding.repr;
169
170 let mut inner_impl = TokenStream2::new();
171
172 for (order, variant) in data.variants.iter().enumerate() {
173 let mut local_param =
174 ParametrizedAttr::with(attr_name, &variant.attrs)?;
175
176 let _ = EncodingDerive::with(
178 &mut local_param,
179 crate_name,
180 false,
181 true,
182 false,
183 )?;
184 let mut combined = global_param.clone().merged(local_param.clone())?;
186 combined.args.remove(REPR);
187 combined.args.remove(CRATE);
188 let encoding = EncodingDerive::with(
189 &mut combined,
190 crate_name,
191 false,
192 true,
193 false,
194 )?;
195
196 if encoding.skip {
197 continue;
198 }
199
200 let field_impl = match variant.fields {
201 Fields::Named(ref fields) => decode_fields_impl(
202 attr_name,
203 crate_name,
204 trait_name,
205 decode_name,
206 deserialize_name,
207 ident_name,
208 &fields.named,
209 local_param,
210 true,
211 false,
212 )?,
213 Fields::Unnamed(ref fields) => decode_fields_impl(
214 attr_name,
215 crate_name,
216 trait_name,
217 decode_name,
218 deserialize_name,
219 ident_name,
220 &fields.unnamed,
221 local_param,
222 true,
223 false,
224 )?,
225 Fields::Unit => TokenStream2::new(),
226 };
227
228 let ident = &variant.ident;
229 let value = match (encoding.value, encoding.by_order) {
230 (Some(val), _) => val.to_token_stream(),
231 (None, true) => Index::from(order as usize).to_token_stream(),
232 (None, false) => quote! { Self::#ident as #repr },
233 };
234
235 inner_impl.append_all(quote_spanned! { variant.span() =>
236 x if x == #value => {
237 Self::#ident {
238 #field_impl
239 }
240 }
241 });
242 }
243
244 let import = encoding.use_crate;
245 let enum_name = LitStr::new(&ident_name.to_string(), Span::call_site());
246
247 Ok(quote! {
248 impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
249 fn #decode_name<D: ::std::io::Read>(mut d: D) -> ::core::result::Result<Self, #import::Error> {
250 use #import::#trait_name;
251 Ok(match #repr::#decode_name(&mut d)? {
252 #inner_impl
253 unknown => Err(#import::Error::EnumValueNotKnown(#enum_name, unknown as usize))?
254 })
255 }
256 }
257 })
258}
259
260#[allow(clippy::too_many_arguments)]
261fn decode_fields_impl<'a>(
262 attr_name: &'static str,
263 crate_name: &Ident,
264 trait_name: &Ident,
265 decode_name: &Ident,
266 deserialize_name: &Ident,
267 ident_name: &Ident,
268 fields: impl IntoIterator<Item = &'a Field>,
269 mut parent_param: ParametrizedAttr,
270 is_enum: bool,
271 tlv_encoding: bool,
272) -> Result<TokenStream2> {
273 let mut stream = TokenStream2::new();
274
275 let use_tlv = parent_param.args.contains_key(USE_TLV);
276 parent_param.args.remove(CRATE);
277 parent_param.args.remove(USE_TLV);
278
279 if !tlv_encoding && use_tlv {
280 return Err(Error::new(
281 Span::call_site(),
282 format!("TLV extensions are not allowed in `{}`", attr_name),
283 ));
284 }
285
286 let parent_attr = EncodingDerive::with(
287 &mut parent_param.clone(),
288 crate_name,
289 false,
290 is_enum,
291 false,
292 )?;
293 let import = parent_attr.use_crate;
294
295 let mut skipped_fields = vec![];
296 let mut strict_fields = vec![];
297 let mut tlv_fields = bmap! {};
298 let mut tlv_aggregator = None;
299
300 for (index, field) in fields.into_iter().enumerate() {
301 let mut local_param = ParametrizedAttr::with(attr_name, &field.attrs)?;
302
303 let _ = EncodingDerive::with(
305 &mut local_param,
306 crate_name,
307 false,
308 is_enum,
309 use_tlv,
310 )?;
311 let mut combined = parent_param.clone().merged(local_param)?;
313 let encoding = EncodingDerive::with(
314 &mut combined,
315 crate_name,
316 false,
317 is_enum,
318 use_tlv,
319 )?;
320
321 let name = field
322 .ident
323 .as_ref()
324 .map(Ident::to_token_stream)
325 .unwrap_or_else(|| Index::from(index).to_token_stream());
326
327 if encoding.skip {
328 skipped_fields.push(name);
329 continue;
330 }
331
332 encoding.tlv.unwrap_or(TlvDerive::None).process(
333 field,
334 name,
335 &mut strict_fields,
336 &mut tlv_fields,
337 &mut tlv_aggregator,
338 )?;
339 }
340
341 for name in strict_fields {
342 stream.append_all(quote_spanned! { Span::call_site() =>
343 #name: #import::#trait_name::#decode_name(&mut d)?,
344 });
345 }
346
347 let mut default_fields = skipped_fields;
348 default_fields.extend(tlv_fields.values().map(|(n, _)| n).cloned());
349 default_fields.extend(tlv_aggregator.clone());
350 for name in default_fields {
351 stream.append_all(quote_spanned! { Span::call_site() =>
352 #name: Default::default(),
353 });
354 }
355
356 if !is_enum {
357 if use_tlv {
358 let mut inner = TokenStream2::new();
359 for (type_no, (name, optional)) in tlv_fields {
360 if optional {
361 inner.append_all(quote_spanned! { Span::call_site() =>
362 #type_no => s.#name = Some(#import::#trait_name::#deserialize_name(bytes)?),
363 });
364 } else {
365 inner.append_all(quote_spanned! { Span::call_site() =>
366 #type_no => s.#name = #import::#trait_name::#deserialize_name(bytes)?,
367 });
368 }
369 }
370
371 let aggregator = if let Some(ref tlv_aggregator) = tlv_aggregator {
372 quote_spanned! { Span::call_site() =>
373 _ if *type_no % 2 == 0 => return Err(#import::TlvError::UnknownEvenType(*type_no).into()),
374 _ => { s.#tlv_aggregator.insert(type_no, bytes); },
375 }
376 } else {
377 quote_spanned! { Span::call_site() =>
378 _ if *type_no % 2 == 0 => return Err(#import::TlvError::UnknownEvenType(*type_no).into()),
379 _ => {}
380 }
381 };
382
383 stream = quote_spanned! { Span::call_site() =>
384 let mut s = #ident_name { #stream };
385 let tlvs = internet2::tlv::Stream::#decode_name(&mut d)?;
386 };
387
388 stream.append_all(quote_spanned! { Span::call_site() =>
389 for (type_no, bytes) in tlvs {
390 match *type_no as usize {
391 #inner
392
393 #aggregator
394 }
395 }
396 Ok(s)
397 });
398 } else {
399 stream = quote_spanned! { Span::call_site() =>
400 Ok(#ident_name { #stream })
401 };
402 }
403 }
404
405 Ok(stream)
406}