// deku_derive/lib.rs

1/*!
2Procedural macros that implement `DekuRead` and `DekuWrite` traits
3 */
4#![warn(missing_docs)]
5
6extern crate alloc;
7
8use alloc::borrow::Cow;
9use std::convert::TryFrom;
10use std::fmt::Display;
11
12use darling::{ast, FromDeriveInput, FromField, FromMeta, FromVariant, ToTokens};
13use proc_macro2::TokenStream;
14use quote::quote;
15use syn::punctuated::Punctuated;
16use syn::spanned::Spanned;
17
18use crate::macros::deku_read::emit_deku_read;
19use crate::macros::deku_write::emit_deku_write;
20
21mod macros;
22
/// A `deku(id = ...)` attribute value, stored in whichever literal form the
/// user supplied so it can be re-emitted verbatim into the generated code.
#[derive(Debug)]
enum Id {
    /// An arbitrary expression, parsed from a string literal after
    /// `deku::`-prefix replacements are applied
    TokenStream(TokenStream),
    /// A byte-string literal, e.g. `b"\x01\x02"`
    LitByteStr(syn::LitByteStr),
    /// An integer literal, e.g. `0x01`
    Int(syn::LitInt),
    /// A boolean literal
    Bool(syn::LitBool),
}
30
31impl Display for Id {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.write_str(&self.to_token_stream().to_string())
34    }
35}
36
37impl ToTokens for Id {
38    fn to_tokens(&self, tokens: &mut TokenStream) {
39        match self {
40            Id::TokenStream(v) => v.to_tokens(tokens),
41            Id::LitByteStr(v) => v.to_tokens(tokens),
42            Id::Int(v) => v.to_tokens(tokens),
43            Id::Bool(v) => v.to_tokens(tokens),
44        }
45    }
46}
47
48impl FromMeta for Id {
49    fn from_value(value: &syn::Lit) -> darling::Result<Self> {
50        (match *value {
51            syn::Lit::Str(ref s) => Ok(Id::TokenStream(
52                apply_replacements(s)
53                    .map_err(darling::Error::custom)?
54                    .parse::<TokenStream>()
55                    .expect("could not parse token stream"),
56            )),
57            syn::Lit::Int(ref s) => Ok(Id::Int(s.clone())),
58            syn::Lit::Bool(ref s) => Ok(Id::Bool(s.clone())),
59            syn::Lit::ByteStr(ref s) => Ok(Id::LitByteStr(s.clone())),
60            _ => Err(darling::Error::unexpected_lit_type(value)),
61        })
62        .map_err(|e| e.with_span(value))
63    }
64
65    fn from_string(value: &str) -> darling::Result<Self> {
66        Ok(Id::TokenStream(
67            value.parse().expect("Failed to parse tokens"),
68        ))
69    }
70}
71
/// A numeric attribute value (`bits`/`bytes` sizes), normalized to an
/// integer literal regardless of whether the user wrote `2` or `"2"`
#[derive(Debug)]
struct Num(syn::LitInt);
74
75impl Num {
76    fn new(n: syn::LitInt) -> Self {
77        Self(n)
78    }
79}
80
81impl Display for Num {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.write_str(&self.0.to_token_stream().to_string())
84    }
85}
86
87impl ToTokens for Num {
88    fn to_tokens(&self, tokens: &mut TokenStream) {
89        self.0.to_tokens(tokens)
90    }
91}
92
impl FromMeta for Num {
    /// Accept either an integer literal (`bytes = 2`) or a string literal
    /// containing an integer (`bytes = "2"`)
    fn from_value(value: &syn::Lit) -> darling::Result<Self> {
        (match *value {
            // String form: validate it parses as `usize`, then rebuild an
            // integer literal carrying the original literal's span
            syn::Lit::Str(ref s) => Ok(Num::new(syn::LitInt::new(
                s.value()
                    .as_str()
                    .parse::<usize>()
                    .map_err(|_| darling::Error::unknown_value(&s.value()))?
                    .to_string()
                    .as_str(),
                s.span(),
            ))),
            syn::Lit::Int(ref s) => Ok(Num::new(s.clone())),
            _ => Err(darling::Error::unexpected_lit_type(value)),
        })
        // Attach the literal's span so the error points at the attribute
        .map_err(|e| e.with_span(value))
    }
}
111
112fn cerror(span: proc_macro2::Span, msg: &str) -> TokenStream {
113    syn::Error::new(span, msg).to_compile_error()
114}
115
/// A post-processed version of `DekuReceiver`
///
/// All `Result`-wrapped receiver attributes have been unwrapped and the
/// attribute combination has been validated (see `DekuData::validate`)
#[derive(Debug)]
struct DekuData {
    /// name of the struct/enum deriving the Deku traits
    ident: syn::Ident,
    /// generics of the deriving type, re-emitted on the generated impls
    generics: syn::Generics,
    /// parsed body: enum variants or struct fields
    data: ast::Data<VariantData, FieldData>,

    /// Endianness for all fields
    endian: Option<syn::LitStr>,

    /// top-level context, argument list
    ctx: Option<syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>>,

    /// default context passed to the field
    ctx_default: Option<Punctuated<syn::Expr, syn::token::Comma>>,

    /// A magic value that must appear at the start of this struct/enum's data
    magic: Option<syn::LitByteStr>,

    /// enum only: `id` value
    id: Option<Id>,

    /// enum only: type of the enum `id`
    id_type: Option<TokenStream>,

    /// enum only: endianness of the enum `id`
    id_endian: Option<syn::LitStr>,

    /// enum only: bit size of the enum `id`
    #[cfg(feature = "bits")]
    bits: Option<Num>,

    /// enum only: byte size of the enum `id`
    bytes: Option<Num>,

    /// struct only: rewind (seek back to the start) before reading/writing
    /// NOTE(review): original comment said "seek from current position",
    /// which describes `seek_from_current` — the name says rewind
    seek_rewind: bool,

    /// struct only: seek from current position
    seek_from_current: Option<TokenStream>,

    /// struct only: seek from end position
    seek_from_end: Option<TokenStream>,

    /// struct only: seek from start position
    seek_from_start: Option<TokenStream>,
}
163
164impl DekuData {
165    fn from_input(input: TokenStream) -> Result<Self, TokenStream> {
166        let input = match syn::parse2(input) {
167            Ok(input) => input,
168            Err(err) => return Err(err.to_compile_error()),
169        };
170
171        let receiver = match DekuReceiver::from_derive_input(&input) {
172            Ok(receiver) => receiver,
173            Err(err) => return Err(err.write_errors()),
174        };
175
176        DekuData::from_receiver(receiver)
177    }
178
179    /// Map a `DekuReceiver` to `DekuData`
180    fn from_receiver(receiver: DekuReceiver) -> Result<Self, TokenStream> {
181        let data = match receiver.data {
182            ast::Data::Struct(fields) => ast::Data::Struct(ast::Fields::new(
183                fields.style,
184                fields
185                    .fields
186                    .into_iter()
187                    .map(FieldData::from_receiver)
188                    .collect::<Result<Vec<_>, _>>()?,
189            )),
190            ast::Data::Enum(variants) => ast::Data::Enum(
191                variants
192                    .into_iter()
193                    .map(VariantData::from_receiver)
194                    .collect::<Result<Vec<_>, _>>()?,
195            ),
196        };
197
198        let data = Self {
199            ident: receiver.ident,
200            generics: receiver.generics,
201            data,
202            endian: receiver.endian,
203            ctx: receiver.ctx,
204            ctx_default: receiver.ctx_default,
205            magic: receiver.magic,
206            id: receiver.id,
207            id_type: receiver.id_type?,
208            id_endian: receiver.id_endian,
209            #[cfg(feature = "bits")]
210            bits: receiver.bits,
211            bytes: receiver.bytes,
212            seek_rewind: receiver.seek_rewind,
213            seek_from_current: receiver.seek_from_current?,
214            seek_from_end: receiver.seek_from_end?,
215            seek_from_start: receiver.seek_from_start?,
216        };
217
218        DekuData::validate(&data)?;
219
220        Ok(data)
221    }
222
223    fn validate(data: &DekuData) -> Result<(), TokenStream> {
224        // Validate `ctx_default`
225        if data.ctx_default.is_some() && data.ctx.is_none() {
226            // FIXME: Use `Span::join` once out of nightly
227            return Err(cerror(
228                data.ctx_default.span(),
229                "`ctx_default` must be used with `ctx`",
230            ));
231        }
232
233        match data.data {
234            ast::Data::Struct(_) => {
235                // Validate id_* attributes are being used on an enum
236                let ret = if data.id_type.is_some() {
237                    Err(cerror(
238                        data.id_type.span(),
239                        "`id_type` only supported on enum",
240                    ))
241                } else if data.id.is_some() {
242                    Err(cerror(data.id.span(), "`id` only supported on enum"))
243                } else if data.id_endian.is_some() {
244                    Err(cerror(data.id.span(), "`id_endian` only supported on enum"))
245                } else if data.bytes.is_some() {
246                    Err(cerror(data.bytes.span(), "`bytes` only supported on enum"))
247                } else {
248                    Ok(())
249                };
250
251                #[cfg(feature = "bits")]
252                if ret.is_ok() && data.bits.is_some() {
253                    return Err(cerror(data.bits.span(), "`bits` only supported on enum"));
254                }
255
256                ret
257            }
258            ast::Data::Enum(_) => {
259                // Validate `id_type` or `id` is specified
260                if data.id_type.is_none() && data.id.is_none() {
261                    return Err(cerror(
262                        data.ident.span(),
263                        "`id_type` or `id` must be specified on enum",
264                    ));
265                }
266
267                // Validate either `id_type` or `id` is specified
268                if data.id_type.is_some() && data.id.is_some() {
269                    return Err(cerror(
270                        data.ident.span(),
271                        "conflicting: both `id_type` and `id` specified on enum",
272                    ));
273                }
274
275                // Validate `id_*` used correctly
276                #[cfg(feature = "bits")]
277                if data.id.is_some() && data.bits.is_some() {
278                    return Err(cerror(
279                        data.ident.span(),
280                        "error: cannot use `bits` with `id`",
281                    ));
282                }
283                if data.id.is_some() && data.bytes.is_some() {
284                    return Err(cerror(
285                        data.ident.span(),
286                        "error: cannot use `bytes` with `id`",
287                    ));
288                }
289
290                // Validate either `bits` or `bytes` is specified
291                #[cfg(feature = "bits")]
292                if data.bits.is_some() && data.bytes.is_some() {
293                    return Err(cerror(
294                        data.bits.span(),
295                        "conflicting: both `bits` and `bytes` specified on enum",
296                    ));
297                }
298
299                Ok(())
300            }
301        }
302    }
303
304    /// Emit a reader. On error, a compiler error is emitted
305    fn emit_reader(&self) -> TokenStream {
306        self.emit_reader_checked()
307            .unwrap_or_else(|e| e.to_compile_error())
308    }
309
310    /// Emit a writer. On error, a compiler error is emitted
311    fn emit_writer(&self) -> TokenStream {
312        self.emit_writer_checked()
313            .unwrap_or_else(|e| e.to_compile_error())
314    }
315
316    /// Same as `emit_reader`, but won't auto convert error to compile error
317    fn emit_reader_checked(&self) -> Result<TokenStream, syn::Error> {
318        emit_deku_read(self)
319    }
320
321    /// Same as `emit_writer`, but won't auto convert error to compile error
322    fn emit_writer_checked(&self) -> Result<TokenStream, syn::Error> {
323        emit_deku_write(self)
324    }
325}
326
/// Common variables from `DekuData` for `emit_enum` read/write functions
#[derive(Debug)]
struct DekuDataEnum<'a> {
    /// impl-position generics from `split_for_impl`
    imp: syn::ImplGenerics<'a>,
    /// optional `where` clause of the deriving enum
    wher: Option<&'a syn::WhereClause>,
    /// all variants of the enum
    variants: Vec<&'a VariantData>,
    /// enum name with type generics applied (`Name<T>`)
    ident: TokenStream,
    /// top-level `id` attribute, if any
    id: Option<&'a Id>,
    /// top-level `id_type` attribute, if any
    id_type: Option<&'a TokenStream>,
    /// argument tokens used when reading/writing the id (endian/bits/bytes)
    id_args: TokenStream,
}
338
impl<'a> TryFrom<&'a DekuData> for DekuDataEnum<'a> {
    type Error = syn::Error;

    /// Create common initializer variables for `emit_enum` read/write functions
    fn try_from(deku_data: &'a DekuData) -> Result<Self, Self::Error> {
        let (imp, ty, wher) = deku_data.generics.split_for_impl();

        // checked in `emit_deku_{read/write}`
        let variants = deku_data.data.as_ref().take_enum().unwrap();

        // Name with type generics applied, e.g. `Foo<T>`
        let ident = &deku_data.ident;
        let ident = quote! { #ident #ty };

        let id = deku_data.id.as_ref();
        let id_type = deku_data.id_type.as_ref();

        // Build the (endian, bits, bytes) argument tokens for reading/writing
        // the id; `bits` only exists when the `bits` feature is enabled, so a
        // literal `None` is substituted otherwise
        let id_args = crate::macros::gen_id_args(
            deku_data.endian.as_ref(),
            deku_data.id_endian.as_ref(),
            #[cfg(feature = "bits")]
            deku_data.bits.as_ref(),
            #[cfg(not(feature = "bits"))]
            None,
            deku_data.bytes.as_ref(),
        )?;

        Ok(Self {
            imp,
            wher,
            variants,
            ident,
            id,
            id_type,
            id_args,
        })
    }
}
376
/// Common variables from `DekuData` for `emit_struct` read/write functions
#[derive(Debug)]
struct DekuDataStruct<'a> {
    /// impl-position generics from `split_for_impl`
    imp: syn::ImplGenerics<'a>,
    /// optional `where` clause of the deriving struct
    wher: Option<&'a syn::WhereClause>,
    /// struct name with type generics applied (`Name<T>`)
    ident: TokenStream,
    /// all fields of the struct
    fields: darling::ast::Fields<&'a FieldData>,
}
385
386impl<'a> TryFrom<&'a DekuData> for DekuDataStruct<'a> {
387    type Error = syn::Error;
388
389    /// Create common initializer variables for `emit_struct` read/write functions
390    fn try_from(deku_data: &'a DekuData) -> Result<Self, Self::Error> {
391        let (imp, ty, wher) = deku_data.generics.split_for_impl();
392
393        let ident = &deku_data.ident;
394        let ident = quote! { #ident #ty };
395
396        // Checked in `emit_deku_{read/write}`.
397        let fields = deku_data.data.as_ref().take_struct().unwrap();
398
399        Ok(Self {
400            imp,
401            wher,
402            ident,
403            fields,
404        })
405    }
406}
407
/// A post-processed version of `FieldReceiver`
///
/// All `Result`-wrapped receiver attributes have been unwrapped and the
/// attribute combination validated (see `FieldData::validate`)
#[derive(Debug)]
struct FieldData {
    /// field name (`None` for unnamed/tuple fields)
    ident: Option<syn::Ident>,
    /// field type
    ty: syn::Type,

    /// endianness for the field
    endian: Option<syn::LitStr>,

    /// field bit size
    #[cfg(feature = "bits")]
    bits: Option<Num>,

    /// field byte size
    bytes: Option<Num>,

    /// tokens providing the length of the container
    count: Option<TokenStream>,

    /// tokens providing the number of bits for the length of the container
    #[cfg(feature = "bits")]
    bits_read: Option<TokenStream>,

    /// tokens providing the number of bytes for the length of the container
    bytes_read: Option<TokenStream>,

    /// a predicate to decide when to stop reading elements into the container
    until: Option<TokenStream>,

    /// read until `reader.end()`
    read_all: bool,

    /// apply a function to the field after it's read
    map: Option<TokenStream>,

    /// context passed to the field
    ctx: Option<Punctuated<syn::Expr, syn::token::Comma>>,

    /// map field when updating struct
    update: Option<TokenStream>,

    /// custom field reader code
    reader: Option<TokenStream>,

    /// custom field writer code
    writer: Option<TokenStream>,

    /// skip field reading/writing
    skip: bool,

    /// pad a number of bits before
    #[cfg(feature = "bits")]
    pad_bits_before: Option<TokenStream>,

    /// pad a number of bytes before
    pad_bytes_before: Option<TokenStream>,

    /// pad a number of bits after
    #[cfg(feature = "bits")]
    pad_bits_after: Option<TokenStream>,

    /// pad a number of bytes after
    pad_bytes_after: Option<TokenStream>,

    /// read field as temporary value, isn't stored
    temp: bool,

    /// write given value of temp field
    temp_value: Option<TokenStream>,

    /// default value code when used with skip or cond
    default: Option<TokenStream>,

    /// condition to parse field
    cond: Option<TokenStream>,

    /// assertion on field
    assert: Option<TokenStream>,

    /// assert value of field
    assert_eq: Option<TokenStream>,

    /// rewind (seek back to the start) before this field
    /// NOTE(review): original comment said "seek from current position",
    /// which describes `seek_from_current` — the name says rewind
    seek_rewind: bool,

    /// seek from current position
    seek_from_current: Option<TokenStream>,

    /// seek from end position
    seek_from_end: Option<TokenStream>,

    /// seek from start position
    seek_from_start: Option<TokenStream>,
}
502
503impl FieldData {
504    fn from_receiver(receiver: DekuFieldReceiver) -> Result<Self, TokenStream> {
505        let ctx = receiver
506            .ctx?
507            .map(|s| s.parse_with(Punctuated::parse_terminated))
508            .transpose()
509            .map_err(|e| e.to_compile_error())?;
510
511        let data = Self {
512            ident: receiver.ident,
513            ty: receiver.ty,
514            endian: receiver.endian,
515            #[cfg(feature = "bits")]
516            bits: receiver.bits,
517            bytes: receiver.bytes,
518            count: receiver.count?,
519            #[cfg(feature = "bits")]
520            bits_read: receiver.bits_read?,
521            bytes_read: receiver.bytes_read?,
522            until: receiver.until?,
523            read_all: receiver.read_all,
524            map: receiver.map?,
525            ctx,
526            update: receiver.update?,
527            reader: receiver.reader?,
528            writer: receiver.writer?,
529            skip: receiver.skip,
530            #[cfg(feature = "bits")]
531            pad_bits_before: receiver.pad_bits_before?,
532            pad_bytes_before: receiver.pad_bytes_before?,
533            #[cfg(feature = "bits")]
534            pad_bits_after: receiver.pad_bits_after?,
535            pad_bytes_after: receiver.pad_bytes_after?,
536            temp: receiver.temp,
537            temp_value: receiver.temp_value?,
538            default: receiver.default?,
539            cond: receiver.cond?,
540            assert: receiver.assert?,
541            assert_eq: receiver.assert_eq?,
542            seek_rewind: receiver.seek_rewind,
543            seek_from_current: receiver.seek_from_current?,
544            seek_from_end: receiver.seek_from_end?,
545            seek_from_start: receiver.seek_from_start?,
546        };
547
548        FieldData::validate(&data)?;
549
550        let default = data.default.or_else(|| Some(quote! { Default::default() }));
551
552        Ok(Self { default, ..data })
553    }
554
555    fn validate(data: &FieldData) -> Result<(), TokenStream> {
556        // Validate either `read_bytes` or `read_bits` is specified
557        #[cfg(feature = "bits")]
558        if data.bits_read.is_some() && data.bytes_read.is_some() {
559            return Err(cerror(
560                data.bits_read.span(),
561                "conflicting: both `bits_read` and `bytes_read` specified on field",
562            ));
563        }
564
565        // Validate either `count` or `bits_read`/`bytes_read` is specified
566        #[cfg(feature = "bits")]
567        if data.count.is_some() && (data.bits_read.is_some() || data.bytes_read.is_some()) {
568            if data.bits_read.is_some() {
569                return Err(cerror(
570                    data.count.span(),
571                    "conflicting: both `count` and `bits_read` specified on field",
572                ));
573            } else {
574                return Err(cerror(
575                    data.count.span(),
576                    "conflicting: both `count` and `bytes_read` specified on field",
577                ));
578            }
579        }
580
581        #[cfg(not(feature = "bits"))]
582        if data.count.is_some() && data.bytes_read.is_some() {
583            return Err(cerror(
584                data.count.span(),
585                "conflicting: both `count` and `bytes_read` specified on field",
586            ));
587        }
588
589        // Validate either `bits` or `bytes` is specified
590        #[cfg(feature = "bits")]
591        if data.bits.is_some() && data.bytes.is_some() {
592            // FIXME: Use `Span::join` once out of nightly
593            return Err(cerror(
594                data.bits.span(),
595                "conflicting: both `bits` and `bytes` specified on field",
596            ));
597        }
598
599        // Validate usage of `default` attribute
600        if data.default.is_some() && (!data.skip && data.cond.is_none()) {
601            // FIXME: Use `Span::join` once out of nightly
602            return Err(cerror(
603                data.default.span(),
604                "`default` attribute cannot be used here",
605            ));
606        }
607
608        // Validate usage of read_all
609        #[cfg(feature = "bits")]
610        if data.read_all
611            && (data.until.is_some()
612                || data.count.is_some()
613                || (data.bits_read.is_some() || data.bytes_read.is_some()))
614        {
615            return Err(cerror(
616                data.read_all.span(),
617                "conflicting: `read_all` cannot be used with `until`, `count`, `bits_read`, or `bytes_read`",
618            ));
619        }
620
621        // Validate usage of seek_*
622        if (data.seek_from_current.is_some() as u8
623            + data.seek_from_end.is_some() as u8
624            + data.seek_from_start.is_some() as u8
625            + data.seek_rewind as u8)
626            > 1
627        {
628            return Err(cerror(
629                data.ty.span(),
630                "conflicting: only one `seek` attribute can be used at one time",
631            ));
632        }
633
634        Ok(())
635    }
636
637    /// Get ident of the field
638    /// `index` is provided in the case of un-named structs
639    /// `prefix` is true in the case of variable declarations, false if original field is desired
640    fn get_ident(&self, index: usize, prefix: bool) -> TokenStream {
641        let field_ident = gen_field_ident(self.ident.as_ref(), index, prefix);
642        quote! { #field_ident }
643    }
644}
645
/// A post-processed version of `VariantReceiver`
#[derive(Debug)]
struct VariantData {
    /// variant name
    ident: syn::Ident,
    /// fields of the variant (named, unnamed, or unit)
    fields: ast::Fields<FieldData>,
    /// explicit discriminant expression (`Variant = expr`), if any
    discriminant: Option<syn::Expr>,

    /// custom variant reader code
    reader: Option<TokenStream>,

    /// custom variant writer code
    /// NOTE(review): original comment said "reader code" — copy-paste slip
    writer: Option<TokenStream>,

    /// variant `id` value
    id: Option<Id>,

    /// variant `id_pat` value
    id_pat: Option<TokenStream>,

    /// variant `default` option
    default: Option<bool>,
}
668
669impl VariantData {
670    fn from_receiver(receiver: DekuVariantReceiver) -> Result<Self, TokenStream> {
671        let fields = ast::Fields::new(
672            receiver.fields.style,
673            receiver
674                .fields
675                .fields
676                .into_iter()
677                .map(FieldData::from_receiver)
678                .collect::<Result<Vec<_>, _>>()?,
679        );
680
681        let ret = Self {
682            ident: receiver.ident,
683            fields,
684            discriminant: receiver.discriminant,
685            reader: receiver.reader?,
686            writer: receiver.writer?,
687            id: receiver.id,
688            id_pat: receiver.id_pat?,
689            default: receiver.default,
690        };
691
692        VariantData::validate(&ret)?;
693
694        Ok(ret)
695    }
696
697    fn validate(data: &VariantData) -> Result<(), TokenStream> {
698        if data.id.is_some() && data.id_pat.is_some() {
699            // FIXME: Use `Span::join` once out of nightly
700            return Err(cerror(
701                data.id.span(),
702                "conflicting: both `id` and `id_pat` specified on variant",
703            ));
704        }
705
706        if let Some(id) = &data.id {
707            if id.to_string() == "_" {
708                return Err(cerror(
709                    data.ident.span(),
710                    "error: `id_pat` should be used for `_`",
711                ));
712            }
713        }
714
715        Ok(())
716    }
717}
718
/// Receiver for the top-level struct or enum
///
/// String-valued attributes that get parsed into token streams are stored as
/// `Result<Option<_>, ReplacementError>` so that `deku::` replacement errors
/// can be surfaced later, during `DekuData::from_receiver`
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(deku), supports(struct_any, enum_any))]
struct DekuReceiver {
    ident: syn::Ident,
    generics: syn::Generics,
    data: ast::Data<DekuVariantReceiver, DekuFieldReceiver>,

    /// Endianness for all fields
    #[darling(default)]
    endian: Option<syn::LitStr>,

    /// top-level context, argument list
    #[darling(default)]
    ctx: Option<syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>>,

    /// default context passed to the field
    #[darling(default)]
    ctx_default: Option<syn::punctuated::Punctuated<syn::Expr, syn::token::Comma>>,

    /// A magic value that must appear at the start of this struct/enum's data
    #[darling(default)]
    magic: Option<syn::LitByteStr>,

    /// enum only: `id` value
    #[darling(default)]
    id: Option<Id>,

    /// enum only: type of the enum `id`
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    id_type: Result<Option<TokenStream>, ReplacementError>,

    /// enum only: endianness of the enum `id`
    #[darling(default)]
    id_endian: Option<syn::LitStr>,

    /// enum only: bit size of the enum `id`
    #[cfg(feature = "bits")]
    #[darling(default)]
    bits: Option<Num>,

    /// enum only: byte size of the enum `id`
    #[darling(default)]
    bytes: Option<Num>,

    /// struct only: rewind (seek back to the start) before reading/writing
    /// NOTE(review): original comment said "seek from current position",
    /// which describes `seek_from_current` — the name says rewind
    #[darling(default)]
    seek_rewind: bool,

    /// struct only: seek from current position
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    seek_from_current: Result<Option<TokenStream>, ReplacementError>,

    /// struct only: seek from end position
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    seek_from_end: Result<Option<TokenStream>, ReplacementError>,

    /// struct only: seek from start position
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    seek_from_start: Result<Option<TokenStream>, ReplacementError>,
}
780
781type ReplacementError = TokenStream;
782
783fn apply_replacements(input: &syn::LitStr) -> Result<Cow<'_, syn::LitStr>, ReplacementError> {
784    let input_value = input.value();
785
786    if !input_value.contains("deku") {
787        return Ok(Cow::Borrowed(input));
788    }
789
790    if input_value.contains("__deku_") {
791        return Err(darling::Error::unsupported_format(
792            "attribute cannot contain `__deku_` these are internal variables. Please use the `deku::` instead."
793        )).map_err(|e| e.with_span(&input).write_errors());
794    }
795
796    let input_str = input_value
797        .replace("deku::reader", "__deku_reader")
798        .replace("deku::writer", "__deku_writer")
799        .replace("deku::bit_offset", "__deku_bit_offset")
800        .replace("deku::byte_offset", "__deku_byte_offset");
801
802    Ok(Cow::Owned(syn::LitStr::new(&input_str, input.span())))
803}
804
805/// Calls apply replacements on Option<LitStr>
806fn map_option_litstr(input: Option<syn::LitStr>) -> Result<Option<syn::LitStr>, ReplacementError> {
807    Ok(match input {
808        Some(v) => Some(apply_replacements(&v)?.into_owned()),
809        None => None,
810    })
811}
812
813/// Parse a TokenStream from an Option<LitStr>
814/// Also replaces any namespaced variables to internal variables found in `input`
815fn map_litstr_as_tokenstream(
816    input: Option<syn::LitStr>,
817) -> Result<Option<TokenStream>, ReplacementError> {
818    Ok(match input {
819        Some(v) => {
820            let v = apply_replacements(&v)?;
821            Some(
822                v.parse::<TokenStream>()
823                    .expect("could not parse token stream"),
824            )
825        }
826        None => None,
827    })
828}
829
830/// Generate field name which supports both un-named/named structs/enums
831/// `ident` is Some if the container has named fields
832/// `index` is the numerical index of the current field used in un-named containers
833/// `prefix` is true in the case of variable declarations and match arms,
834/// false when the raw field is required, for example a field access
835fn gen_field_ident<T: ToString>(ident: Option<T>, index: usize, prefix: bool) -> TokenStream {
836    let field_name = match ident {
837        Some(field_name) => field_name.to_string(),
838        None => {
839            let index = syn::Index::from(index);
840            let prefix = if prefix { "field_" } else { "" };
841            format!("{}{}", prefix, quote! { #index })
842        }
843    };
844
845    field_name.parse().unwrap()
846}
847
/// Provided default when an attribute is not available
/// (used as `darling(default = "default_res_opt")` so a missing attribute
/// becomes `Ok(None)` rather than an error)
#[allow(clippy::unnecessary_wraps)]
fn default_res_opt<T, E>() -> Result<Option<T>, E> {
    Ok(None)
}
853
854/// Receiver for the field-level attributes inside a struct/enum variant
855#[derive(Debug, FromField)]
856#[darling(attributes(deku))]
857struct DekuFieldReceiver {
858    ident: Option<syn::Ident>,
859    ty: syn::Type,
860
861    /// Endianness for the field
862    #[darling(default)]
863    endian: Option<syn::LitStr>,
864
865    /// field bit size
866    #[cfg(feature = "bits")]
867    #[darling(default)]
868    bits: Option<Num>,
869
870    /// field byte size
871    #[darling(default)]
872    bytes: Option<Num>,
873
874    /// tokens providing the length of the container
875    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
876    count: Result<Option<TokenStream>, ReplacementError>,
877
878    /// tokens providing the number of bits for the length of the container
879    #[cfg(feature = "bits")]
880    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
881    bits_read: Result<Option<TokenStream>, ReplacementError>,
882
883    /// tokens providing the number of bytes for the length of the container
884    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
885    bytes_read: Result<Option<TokenStream>, ReplacementError>,
886
887    /// a predicate to decide when to stop reading elements into the container
888    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
889    until: Result<Option<TokenStream>, ReplacementError>,
890
891    /// read until `reader.end()`
892    #[darling(default)]
893    read_all: bool,
894
895    /// apply a function to the field after it's read
896    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
897    map: Result<Option<TokenStream>, ReplacementError>,
898
899    /// context passed to the field.
900    /// A comma separated argument list.
901    // TODO: The type of it should be `Punctuated<Expr, Comma>`
902    //       https://github.com/TedDriggs/darling/pull/98
903    #[darling(default = "default_res_opt", map = "map_option_litstr")]
904    ctx: Result<Option<syn::LitStr>, ReplacementError>,
905
906    /// map field when updating struct
907    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
908    update: Result<Option<TokenStream>, ReplacementError>,
909
910    /// custom field reader code
911    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
912    reader: Result<Option<TokenStream>, ReplacementError>,
913
914    /// custom field writer code
915    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
916    writer: Result<Option<TokenStream>, ReplacementError>,
917
918    /// skip field reading/writing
919    #[darling(default)]
920    skip: bool,
921
922    /// pad a number of bits before
923    #[cfg(feature = "bits")]
924    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
925    pad_bits_before: Result<Option<TokenStream>, ReplacementError>,
926
927    /// pad a number of bytes before
928    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
929    pad_bytes_before: Result<Option<TokenStream>, ReplacementError>,
930
931    /// pad a number of bits after
932    #[cfg(feature = "bits")]
933    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
934    pad_bits_after: Result<Option<TokenStream>, ReplacementError>,
935
936    /// pad a number of bytes after
937    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
938    pad_bytes_after: Result<Option<TokenStream>, ReplacementError>,
939
940    /// read field as temporary value, isn't stored
941    #[darling(default)]
942    temp: bool,
943
944    /// write given value of temp field
945    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
946    temp_value: Result<Option<TokenStream>, ReplacementError>,
947
948    /// default value code when used with skip
949    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
950    default: Result<Option<TokenStream>, ReplacementError>,
951
952    /// condition to parse field
953    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
954    cond: Result<Option<TokenStream>, ReplacementError>,
955
956    // assertion on field
957    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
958    assert: Result<Option<TokenStream>, ReplacementError>,
959
960    // assert value of field
961    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
962    assert_eq: Result<Option<TokenStream>, ReplacementError>,
963
964    /// seek from current position
965    #[darling(default)]
966    seek_rewind: bool,
967
968    /// seek from current position
969    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
970    seek_from_current: Result<Option<TokenStream>, ReplacementError>,
971
972    /// seek from end position
973    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
974    seek_from_end: Result<Option<TokenStream>, ReplacementError>,
975
976    /// seek from start position
977    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
978    seek_from_start: Result<Option<TokenStream>, ReplacementError>,
979}
980
/// Receiver for the variant-level attributes inside an enum
#[derive(Debug, FromVariant)]
#[darling(attributes(deku))]
struct DekuVariantReceiver {
    // variant name
    ident: syn::Ident,
    // the variant's fields, each carrying its own deku field attributes
    fields: ast::Fields<DekuFieldReceiver>,
    // explicit discriminant (`Variant = expr`), if the variant declares one
    discriminant: Option<syn::Expr>,

    /// custom variant reader code
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    reader: Result<Option<TokenStream>, ReplacementError>,

    /// custom variant writer code
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    writer: Result<Option<TokenStream>, ReplacementError>,

    /// variant `id` value
    #[darling(default)]
    id: Option<Id>,

    /// variant `id_pat` value
    #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
    id_pat: Result<Option<TokenStream>, ReplacementError>,

    /// marks this variant as the default variant
    /// NOTE(review): presumably the fallback when no `id` matches — verify in emit_deku_read
    #[darling(default)]
    default: Option<bool>,
}
1009
1010/// Entry function for `DekuRead` proc-macro
1011#[proc_macro_derive(DekuRead, attributes(deku))]
1012pub fn proc_deku_read(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1013    match DekuData::from_input(input.into()) {
1014        Ok(data) => data.emit_reader().into(),
1015        Err(err) => err.into(),
1016    }
1017}
1018
1019/// Entry function for `DekuWrite` proc-macro
1020#[proc_macro_derive(DekuWrite, attributes(deku))]
1021pub fn proc_deku_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
1022    match DekuData::from_input(input.into()) {
1023        Ok(data) => data.emit_writer().into(),
1024        Err(err) => err.into(),
1025    }
1026}
1027
1028fn is_not_deku(attr: &syn::Attribute) -> bool {
1029    attr.path()
1030        .get_ident()
1031        .map(|ident| ident != "deku" && ident != "deku_derive")
1032        .unwrap_or(true)
1033}
1034
1035fn is_temp(field: &syn::Field) -> bool {
1036    DekuFieldReceiver::from_field(field)
1037        .map(|attrs| attrs.temp)
1038        .unwrap_or(false)
1039}
1040
1041fn remove_deku_temp_fields(fields: &mut syn::punctuated::Punctuated<syn::Field, syn::Token![,]>) {
1042    *fields = fields
1043        .clone()
1044        .into_pairs()
1045        .filter(|x| !is_temp(x.value()))
1046        .collect()
1047}
1048
1049fn remove_deku_field_attrs(fields: &mut syn::punctuated::Punctuated<syn::Field, syn::Token![,]>) {
1050    *fields = fields
1051        .clone()
1052        .into_pairs()
1053        .map(|mut field| {
1054            field.value_mut().attrs.retain(is_not_deku);
1055            field
1056        })
1057        .collect()
1058}
1059
1060fn remove_deku_attrs(fields: &mut syn::Fields) {
1061    match fields {
1062        syn::Fields::Named(ref mut fields) => remove_deku_field_attrs(&mut fields.named),
1063        syn::Fields::Unnamed(ref mut fields) => remove_deku_field_attrs(&mut fields.unnamed),
1064        syn::Fields::Unit => {}
1065    }
1066}
1067
1068fn remove_temp_fields(fields: &mut syn::Fields) {
1069    match fields {
1070        syn::Fields::Named(ref mut fields) => remove_deku_temp_fields(&mut fields.named),
1071        syn::Fields::Unnamed(ref mut fields) => remove_deku_temp_fields(&mut fields.unnamed),
1072        syn::Fields::Unit => {}
1073    }
1074}
1075
/// Arguments accepted by the `deku_derive` attribute macro, selecting
/// which trait impls to generate.
#[derive(Debug, FromMeta)]
struct DekuDerive {
    // `deku_derive(DekuRead)` — generate the `DekuRead` impl
    #[darling(default, rename = "DekuRead")]
    read: bool,
    // `deku_derive(DekuWrite)` — generate the `DekuWrite` impl
    #[darling(default, rename = "DekuWrite")]
    write: bool,
}
1083
1084/// Entry function for `deku_derive` proc-macro
1085/// This attribute macro is used to derive `DekuRead` and `DekuWrite`
1086/// while removing temporary variables.
1087#[proc_macro_attribute]
1088pub fn deku_derive(
1089    attr: proc_macro::TokenStream,
1090    item: proc_macro::TokenStream,
1091) -> proc_macro::TokenStream {
1092    // Parse `deku_derive` attribute
1093    let nested_meta = darling::ast::NestedMeta::parse_meta_list(attr.into()).unwrap();
1094    let args = match DekuDerive::from_list(&nested_meta) {
1095        Ok(v) => v,
1096        Err(e) => {
1097            return proc_macro::TokenStream::from(e.write_errors());
1098        }
1099    };
1100
1101    // Parse item
1102    let data = match DekuData::from_input(item.clone().into()) {
1103        Ok(data) => data,
1104        Err(err) => return err.into(),
1105    };
1106
1107    // Generate `DekuRead` impl
1108    let read_impl = if args.read {
1109        data.emit_reader()
1110    } else {
1111        TokenStream::new()
1112    };
1113
1114    // Remove the temp fields
1115    let mut input = syn::parse_macro_input!(item as syn::DeriveInput);
1116
1117    match input.data {
1118        syn::Data::Struct(ref mut input_struct) => remove_temp_fields(&mut input_struct.fields),
1119        syn::Data::Enum(ref mut input_enum) => {
1120            for variant in input_enum.variants.iter_mut() {
1121                remove_temp_fields(&mut variant.fields)
1122            }
1123        }
1124        _ => unimplemented!(),
1125    }
1126
1127    // Generate `DekuWrite` impl
1128    let write_impl = if args.write {
1129        data.emit_writer()
1130    } else {
1131        TokenStream::new()
1132    };
1133
1134    // Remove attributes
1135    match input.data {
1136        syn::Data::Struct(ref mut input_struct) => {
1137            input.attrs.retain(is_not_deku);
1138            remove_deku_attrs(&mut input_struct.fields)
1139        }
1140        syn::Data::Enum(ref mut input_enum) => {
1141            for variant in input_enum.variants.iter_mut() {
1142                variant.attrs.retain(is_not_deku);
1143                remove_deku_attrs(&mut variant.fields)
1144            }
1145        }
1146        _ => unimplemented!(),
1147    }
1148
1149    input.attrs.retain(is_not_deku);
1150
1151    quote!(
1152        #read_impl
1153
1154        #write_impl
1155
1156        #input
1157    )
1158    .into()
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163    use rstest::rstest;
1164    use syn::parse_str;
1165
1166    use super::*;
1167
1168    #[rstest(input,
1169        // Valid struct
1170        case::struct_empty(r#"struct Test {}"#),
1171        case::struct_unnamed(r#"struct Test(u8, u8);"#),
1172        case::struct_unnamed_attrs(r#"struct Test(#[deku(bits=4)] u8, u8);"#),
1173        case::struct_all_attrs(r#"
1174        struct Test {
1175            #[deku(bits = 4)]
1176            field_a: u8,
1177            #[deku(bytes = 4)]
1178            field_b: u64,
1179            #[deku(endian = "little")]
1180            field_c: u32,
1181            #[deku(endian = "big")]
1182            field_d: u32,
1183            #[deku(skip, default = "5")]
1184            field_e: u32,
1185        }"#),
1186        case::struct_internal_var(r#"
1187        struct Test {
1188            #[deku(bits_read = "deku::rest.len()")]
1189            field: Vec<u8>,
1190        }"#),
1191
1192        // Valid Enum
1193        case::enum_empty(r#"#[deku(id_type = "u8")] enum Test {}"#),
1194        case::enum_all(r#"
1195        #[deku(id_type = "u8")]
1196        enum Test {
1197            #[deku(id = "1")]
1198            A,
1199            #[deku(id = "2")]
1200            B(#[deku(bits = 4)] u8),
1201            #[deku(id = "3")]
1202            C { field_n: u8 },
1203        }"#),
1204
1205        // TODO: these tests should error/warn eventually?
1206        // error: trying to store 9 bits in 8 bit type
1207        case::invalid_storage(r#"struct Test(#[deku(bits=9)] u8);"#),
1208        // warn: trying to set endian on a type which wouldn't make a difference
1209        case::invalid_endian(r#"struct Test(#[endian=big] u8);"#),
1210    )]
1211    fn test_macro(input: &str) {
1212        let parsed = parse_str(input).unwrap();
1213
1214        let data = DekuData::from_input(parsed).unwrap();
1215        let res_reader = data.emit_reader_checked();
1216        let res_writer = data.emit_writer_checked();
1217
1218        res_reader.unwrap();
1219        res_writer.unwrap();
1220    }
1221}