pdl_compiler/backends/rust/
mod.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Rust compiler backend.
16
17use crate::{analyzer, ast};
18use quote::{format_ident, quote};
19use std::collections::{BTreeMap, BTreeSet, HashMap};
20use std::path::Path;
21use syn::LitInt;
22
23mod decoder;
24mod encoder;
25mod preamble;
26pub mod test;
27mod types;
28
29use decoder::FieldParser;
30pub use heck::ToUpperCamelCase;
31
32pub trait ToIdent {
33    /// Generate a sanitized rust identifier.
34    /// Rust specific keywords are renamed for validity.
35    fn to_ident(self) -> proc_macro2::Ident;
36}
37
38impl ToIdent for &'_ str {
39    fn to_ident(self) -> proc_macro2::Ident {
40        match self {
41            "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
42            | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod"
43            | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self" | "static" | "struct"
44            | "super" | "trait" | "true" | "type" | "unsafe" | "use" | "where" | "while"
45            | "async" | "await" | "dyn" | "abstract" | "become" | "box" | "do" | "final"
46            | "macro" | "override" | "priv" | "typeof" | "unsized" | "virtual" | "yield"
47            | "try" => format_ident!("r#{}", self),
48            _ => format_ident!("{}", self),
49        }
50    }
51}
52
53/// Generate a bit-mask which masks out `n` least significant bits.
54///
55/// Literal integers in Rust default to the `i32` type. For this
56/// reason, if `n` is larger than 31, a suffix is added to the
57/// `LitInt` returned. This should either be `u64` or `usize`
58/// depending on where the result is used.
59pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
60    let suffix = if n > 31 { format!("_{suffix}") } else { String::new() };
61    // Format the hex digits as 0x1111_2222_3333_usize.
62    let hex_digits = format!("{:x}", (1u64 << n) - 1)
63        .as_bytes()
64        .rchunks(4)
65        .rev()
66        .map(|chunk| std::str::from_utf8(chunk).unwrap())
67        .collect::<Vec<&str>>()
68        .join("_");
69    syn::parse_str::<syn::LitInt>(&format!("0x{hex_digits}{suffix}")).unwrap()
70}
71
72/// Return the list of fields that will appear in the generated
73/// rust structs (<Packet> and <Packet>Builder).
74///
75///  - must be a named field
76///  - must not be a flag
77///  - must not appear in the packet constraints.
78///
79/// The fields are presented in declaration order, with ancestor
80/// fields declared first.
81/// The payload field _ if declared _ is handled separately.
82fn packet_data_fields<'a>(
83    scope: &'a analyzer::Scope<'a>,
84    decl: &'a ast::Decl,
85) -> Vec<&'a ast::Field> {
86    let all_constraints = HashMap::<String, _>::from_iter(
87        scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
88    );
89
90    scope
91        .iter_fields(decl)
92        .filter(|f| f.id().is_some())
93        .filter(|f| !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
94        .filter(|f| !all_constraints.contains_key(f.id().unwrap()))
95        .collect::<Vec<_>>()
96}
97
98/// Return the list of fields that have a constant value.
99/// The fields are presented in declaration order, with ancestor
100/// fields declared first.
101fn packet_constant_fields<'a>(
102    scope: &'a analyzer::Scope<'a>,
103    decl: &'a ast::Decl,
104) -> Vec<&'a ast::Field> {
105    let all_constraints = HashMap::<String, _>::from_iter(
106        scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
107    );
108
109    scope
110        .iter_fields(decl)
111        .filter(|f| f.id().is_some())
112        .filter(|f| all_constraints.contains_key(f.id().unwrap()))
113        .collect::<Vec<_>>()
114}
115
/// Value assigned to a field by a constraint.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum ConstraintValue {
    /// Integer value of a scalar field.
    Scalar(usize),
    /// Enum tag, stored as `(enum type identifier, tag identifier)`.
    Tag(String, String),
}
121
122impl quote::ToTokens for ConstraintValue {
123    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
124        tokens.extend(match self {
125            ConstraintValue::Scalar(s) => {
126                let s = proc_macro2::Literal::usize_unsuffixed(*s);
127                quote!(#s)
128            }
129            ConstraintValue::Tag(e, t) => {
130                let tag_id = format_ident!("{}", t.to_upper_camel_case());
131                let type_id = format_ident!("{}", e);
132                quote!(#type_id::#tag_id)
133            }
134        })
135    }
136}
137
138fn constraint_value_ast(
139    fields: &[&'_ ast::Field],
140    constraint: &ast::Constraint,
141) -> ConstraintValue {
142    match constraint {
143        ast::Constraint { value: Some(value), .. } => ConstraintValue::Scalar(*value),
144        ast::Constraint { tag_id: Some(tag_id), .. } => {
145            let type_id = fields
146                .iter()
147                .filter_map(|f| match &f.desc {
148                    ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
149                        Some(type_id)
150                    }
151                    _ => None,
152                })
153                .next()
154                .unwrap();
155            ConstraintValue::Tag(type_id.clone(), tag_id.clone())
156        }
157        _ => unreachable!("Invalid constraint: {constraint:?}"),
158    }
159}
160
161fn constraint_value(
162    fields: &[&'_ ast::Field],
163    constraint: &ast::Constraint,
164) -> proc_macro2::TokenStream {
165    match constraint {
166        ast::Constraint { value: Some(value), .. } => {
167            let value = proc_macro2::Literal::usize_unsuffixed(*value);
168            quote!(#value)
169        }
170        // TODO(mgeisler): include type_id in `ast::Constraint` and
171        // drop the packet_scope argument.
172        ast::Constraint { tag_id: Some(tag_id), .. } => {
173            let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
174            let type_id = fields
175                .iter()
176                .filter_map(|f| match &f.desc {
177                    ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
178                        Some(type_id.to_ident())
179                    }
180                    _ => None,
181                })
182                .next()
183                .unwrap();
184            quote!(#type_id::#tag_id)
185        }
186        _ => unreachable!("Invalid constraint: {constraint:?}"),
187    }
188}
189
190fn constraint_value_str(fields: &[&'_ ast::Field], constraint: &ast::Constraint) -> String {
191    match constraint {
192        ast::Constraint { value: Some(value), .. } => {
193            format!("{}", value)
194        }
195        ast::Constraint { tag_id: Some(tag_id), .. } => {
196            let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
197            let type_id = fields
198                .iter()
199                .filter_map(|f| match &f.desc {
200                    ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
201                        Some(type_id.to_ident())
202                    }
203                    _ => None,
204                })
205                .next()
206                .unwrap();
207            format!("{}::{}", type_id, tag_id)
208        }
209        _ => unreachable!("Invalid constraint: {constraint:?}"),
210    }
211}
212
213fn implements_copy(scope: &analyzer::Scope<'_>, field: &ast::Field) -> bool {
214    match &field.desc {
215        ast::FieldDesc::Scalar { .. } => true,
216        ast::FieldDesc::Typedef { type_id, .. } => match &scope.typedef[type_id].desc {
217            ast::DeclDesc::Enum { .. } | ast::DeclDesc::CustomField { .. } => true,
218            ast::DeclDesc::Struct { .. } => false,
219            desc => unreachable!("unexpected declaration: {desc:?}"),
220        },
221        ast::FieldDesc::Array { .. } => false,
222        _ => todo!(),
223    }
224}
225
226/// Generate the implementation of the specialize method.
227///
228/// The function is generated after selecting the information from the parent
229/// packet that can be used to identify with
230/// _certainty_ the child packet.
231///
232/// The discriminant information is:
233///     - field values
234///     - payload size, to disambiguate between children of
235///       identical constant size
236///
237/// The generator will raise warnings if ambiguities remain after all
238/// information is taken into account, i.e. two child packets map to the same
239/// constraints. In this case ambiguities are resolved by trying each child
240/// in order of declaration.
241fn generate_specialize_impl(
242    scope: &analyzer::Scope<'_>,
243    schema: &analyzer::Schema,
244    decl: &ast::Decl,
245    id: &str,
246    data_fields: &[&ast::Field],
247) -> Result<proc_macro2::TokenStream, String> {
248    #[derive(PartialEq, Eq)]
249    struct SpecializeCase {
250        id: String,
251        constraints: HashMap<String, ConstraintValue>,
252        size: analyzer::Size,
253    }
254
255    fn gather_specialize_cases(
256        scope: &analyzer::Scope<'_>,
257        schema: &analyzer::Schema,
258        id: &str,
259        decl: &ast::Decl,
260        data_fields: &[&ast::Field],
261        constraints: &HashMap<String, ConstraintValue>,
262        specialize_cases: &mut Vec<SpecializeCase>,
263    ) {
264        // Add local constraints to the context.
265        let mut constraints = constraints.clone();
266        for c in decl.constraints() {
267            if data_fields.iter().any(|f| f.id() == Some(&c.id)) {
268                constraints.insert(c.id.to_owned(), constraint_value_ast(data_fields, c));
269            }
270        }
271
272        // Generate specialize cases for the child declarations.
273        for decl in scope.iter_children(decl) {
274            gather_specialize_cases(
275                scope,
276                schema,
277                id,
278                decl,
279                data_fields,
280                &constraints,
281                specialize_cases,
282            );
283        }
284
285        // Add a case for the current declaration.
286        specialize_cases.push(SpecializeCase {
287            id: id.to_owned(),
288            constraints,
289            size: schema.decl_size(decl.key) + schema.payload_size(decl.key),
290        });
291    }
292
293    // Create match cases for each child declaration: the union of
294    // tuple of constaint values and packet sizes that will specialize to this
295    // declaration.
296    let mut specialize_cases = Vec::new();
297    for child_decl in scope.iter_children(decl) {
298        gather_specialize_cases(
299            scope,
300            schema,
301            child_decl.id().unwrap(),
302            child_decl,
303            data_fields,
304            &HashMap::new(),
305            &mut specialize_cases,
306        )
307    }
308
309    // List the identifiers of fields constituting the
310    // discriminant tuple.
311    let ids = specialize_cases
312        .iter()
313        .flat_map(|case| case.constraints.keys())
314        .cloned()
315        .collect::<BTreeSet<String>>()
316        .into_iter()
317        .collect::<Vec<String>>();
318
319    fn make_specialize_case(ids: &[String], case: &SpecializeCase) -> Vec<Option<ConstraintValue>> {
320        ids.iter().map(|id| case.constraints.get(id).cloned()).collect::<Vec<_>>()
321    }
322
323    fn check_specialize_cases(
324        ids: &[String],
325        with_size: bool,
326        specialize_cases: &[SpecializeCase],
327    ) -> Result<(), String> {
328        // Check unicity of constraints.
329        let mut grouped_cases = HashMap::new();
330        for case in specialize_cases {
331            let constraints = make_specialize_case(ids, case);
332            match grouped_cases.insert(
333                (constraints, if with_size { case.size } else { analyzer::Size::Unknown }),
334                case.id.clone(),
335            ) {
336                Some(id) if id != case.id => {
337                    return Err(format!("{} and {} cannot be disambiguated", id, case.id))
338                }
339                _ => (),
340            }
341        }
342
343        Ok(())
344    }
345
346    // Check if constraints are un-amiguous, and whether the packet size
347    // is required to disambiguate.
348    // TODO(henrichataing) ambiguities should be resolved by trying each
349    // case until one is successfully parsed.
350    check_specialize_cases(&ids, true, &specialize_cases)?;
351    let with_size = check_specialize_cases(&ids, false, &specialize_cases).is_err();
352
353    // Finally group match cases by matching child declaration.
354    let mut grouped_cases = BTreeMap::new();
355    for case in specialize_cases {
356        let constraints = make_specialize_case(&ids, &case);
357        let size = if with_size { case.size } else { analyzer::Size::Unknown };
358        if constraints.iter().any(Option::is_some) || size != analyzer::Size::Unknown {
359            grouped_cases
360                .entry(case.id.clone())
361                .or_insert(BTreeSet::new())
362                .insert((constraints, size));
363        }
364    }
365
366    // Build the case values and case branches.
367    // The case are ordered by child declaration order.
368    let mut case_values = vec![];
369    let mut case_ids = vec![];
370    let child_name = format_ident!("{id}Child");
371
372    for (id, cases) in grouped_cases {
373        case_ids.push(format_ident!("{id}"));
374        case_values.push(
375            cases
376                .iter()
377                .map(|(constraints, size)| {
378                    let mut case = constraints
379                        .iter()
380                        .map(|v| match v {
381                            Some(v) => quote!(#v),
382                            None => quote!(_),
383                        })
384                        .collect::<Vec<_>>();
385                    if with_size {
386                        case.push(match size {
387                            analyzer::Size::Static(s) => {
388                                let s = proc_macro2::Literal::usize_unsuffixed(s / 8);
389                                quote!(#s)
390                            }
391                            _ => quote!(_),
392                        });
393                    }
394                    case
395                })
396                .collect::<Vec<_>>(),
397        );
398    }
399
400    let mut field_values = ids
401        .iter()
402        .map(|id| {
403            let id = id.to_ident();
404            quote!(self.#id)
405        })
406        .collect::<Vec<_>>();
407    if with_size {
408        field_values.push(quote!(self.payload.len()));
409    }
410
411    // TODO(henrichataing) the default case is necessary only if the match
412    // is non-exhaustive.
413    Ok(quote! {
414        pub fn specialize(&self) -> Result<#child_name, DecodeError> {
415            Ok(
416                match ( #( #field_values ),* ) {
417                    #( #( ( #( #case_values ),* ) )|* =>
418                        #child_name::#case_ids(self.try_into()?), )*
419                    _ => #child_name::None,
420                }
421            )
422        }
423    })
424}
425
/// Generate code for a root packet declaration.
///
/// The generated tokens contain:
///  - the struct `<id>`, with one public member per data field (see
///    `packet_data_fields`) plus `payload: Vec<u8>` when `decl.payload()`
///    is set,
///  - the enum `<id>Child` and the method `specialize`, only when the packet
///    has child declarations,
///  - an accessor for the payload and for each data field; fields whose type
///    implements `Copy` are returned by value, the others by reference,
///  - the `Packet` trait implementation (`encoded_len`, `encode`, `decode`).
///
/// # Arguments
/// * `scope` - Scope used to look up declarations, fields and constraints.
/// * `schema` - Analyzer schema, passed to the parser and encoder generators
///   and used for packet sizes in `specialize`.
/// * `endianness` - File endianness
/// * `id` - Packet identifier.
///
/// # Panics
///
/// Panics if `id` is not a declaration of `scope`, or if the children of the
/// packet cannot be disambiguated (the result of `generate_specialize_impl`
/// is unwrapped).
fn generate_root_packet_decl(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let decl = scope.typedef[id];
    let name = id.to_ident();
    let child_name = format_ident!("{id}Child");

    // List the fields that will appear as members of the generated rust
    // struct <Packet>. The payload field — if declared — is handled separately.
    let data_fields = packet_data_fields(scope, decl);
    let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
    let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    // Accessors return `Copy` fields by value and the others by reference.
    let data_field_borrows = data_fields
        .iter()
        .map(|f| {
            if implements_copy(scope, f) {
                quote! {}
            } else {
                quote! { & }
            }
        })
        .collect::<Vec<_>>();
    let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
    let payload_accessor =
        decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });

    // Generate the parser for the fields declared by this packet.
    let parser_span = format_ident!("buf");
    let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
    for field in decl.fields() {
        field_parser.add(field);
    }

    // Identifiers of the struct members initialized by the generated `decode`:
    // the payload (if declared) followed by all data fields. Root packets have
    // no parent to copy members from, so every member comes out of
    // `field_parser`. The struct is built with field-init shorthand, which
    // presumably relies on the parser binding locals with these same names.
    let mut parsed_field_ids = vec![];
    if decl.payload().is_some() {
        parsed_field_ids.push(format_ident!("payload"));
    }
    for f in &data_fields {
        let id = f.id().unwrap().to_ident();
        parsed_field_ids.push(id);
    }

    let (encode_fields, encoded_len) =
        encoder::encode(scope, schema, endianness, "buf".to_ident(), decl);

    let encode = quote! {
         fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
            #encode_fields
            Ok(())
        }
    };

    // Compute the encoded length of the packet.
    let encoded_len = quote! {
        fn encoded_len(&self) -> usize {
            #encoded_len
        }
    };

    // The implementation of decode for root packets contains the full
    // parser implementation.
    let decode = quote! {
       fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
           #field_parser
           Ok((Self { #( #parsed_field_ids, )* }, buf))
       }
    };

    // Provide the implementation of the enum listing child declarations of the
    // current declaration. This enum is only provided for declarations that
    // have child packets.
    let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
    let child_struct = (!children_decl.is_empty()).then(|| {
        let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
        quote! {
            #[derive(Debug, Clone, PartialEq, Eq)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #child_name {
                #( #children_ids(#children_ids), )*
                None,
            }
        }
    });

    // Provide the implementation of the specialization function.
    // The specialization function is only provided for declarations that have
    // child packets. Ambiguous children make the generator panic here.
    let specialize = (!children_decl.is_empty())
        .then(|| generate_specialize_impl(scope, schema, decl, id, &data_fields).unwrap());

    quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #name {
            #( pub #data_field_ids: #data_field_types, )*
            #payload_field
        }

        #child_struct

        impl #name {
            #specialize
            #payload_accessor

            #(
            pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
                #data_field_borrows self.#data_field_ids
            }
            )*
        }

        impl Packet for #name {
            #encoded_len
            #encode
            #decode
        }
    }
}
555
556/// Generate code for a derived packet declaration
557///
558/// # Arguments
559/// * `endianness` - File endianness
560/// * `id` - Packet identifier.
561fn generate_derived_packet_decl(
562    scope: &analyzer::Scope<'_>,
563    schema: &analyzer::Schema,
564    endianness: ast::EndiannessValue,
565    id: &str,
566) -> proc_macro2::TokenStream {
567    let decl = scope.typedef[id];
568    let name = id.to_ident();
569    let parent_decl = scope.get_parent(decl).unwrap();
570    let parent_name = parent_decl.id().unwrap().to_ident();
571    let child_name = format_ident!("{id}Child");
572
573    // Extract all constraint values from the parent declarations.
574    let all_constraints = HashMap::<String, _>::from_iter(
575        scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
576    );
577
578    let all_fields = scope.iter_fields(decl).collect::<Vec<_>>();
579
580    // Return the list of fields that will appear in the generated
581    // rust structs (<Packet> and <Packet>Builder).
582    // The payload field _ if declared _ is handled separately.
583    let data_fields = packet_data_fields(scope, decl);
584    let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
585    let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
586    let data_field_borrows = data_fields
587        .iter()
588        .map(|f| {
589            if implements_copy(scope, f) {
590                quote! {}
591            } else {
592                quote! { & }
593            }
594        })
595        .collect::<Vec<_>>();
596    let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
597    let payload_accessor =
598        decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });
599
600    let parent_data_fields = packet_data_fields(scope, parent_decl);
601
602    // Return the list of fields that have a constant value.
603    let constant_fields = packet_constant_fields(scope, decl);
604    let constant_field_ids =
605        constant_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
606    let constant_field_types =
607        constant_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
608    let constant_field_values = constant_fields.iter().map(|f| {
609        let c = all_constraints.get(f.id().unwrap()).unwrap();
610        constraint_value(&all_fields, c)
611    });
612
613    // Generate field parsing and serialization.
614    let parser_span = format_ident!("buf");
615    let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
616    for field in decl.fields() {
617        field_parser.add(field);
618    }
619
620    // For the implementation of decode_partial, sort the data field identifiers
621    // between parsed fields (extracted from the payload), and copied fields
622    // (copied from the parent).
623    let mut parsed_field_ids = vec![];
624    let mut copied_field_ids = vec![];
625    let mut cloned_field_ids = vec![];
626    if decl.payload().is_some() {
627        parsed_field_ids.push(format_ident!("payload"));
628    }
629    for f in &data_fields {
630        let id = f.id().unwrap().to_ident();
631        if decl.fields().any(|ff| f.id() == ff.id()) {
632            parsed_field_ids.push(id);
633        } else if implements_copy(scope, f) {
634            copied_field_ids.push(id);
635        } else {
636            cloned_field_ids.push(id);
637        }
638    }
639
640    let (partial_field_serializer, field_serializer, encoded_len) =
641        encoder::encode_partial(scope, schema, endianness, "buf".to_ident(), decl);
642
643    let encode_partial = quote! {
644        pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
645            #partial_field_serializer
646            Ok(())
647        }
648    };
649
650    let encode = quote! {
651         fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
652            #field_serializer
653            Ok(())
654        }
655    };
656
657    // Compute the encoded length of the packet.
658    let encoded_len = quote! {
659        fn encoded_len(&self) -> usize {
660            #encoded_len
661        }
662    };
663
664    // Constraint checks are only run for constraints added to this declaration
665    // and not parent constraints which are expected to have been validated
666    // earlier.
667    let constraint_checks = decl.constraints().map(|c| {
668        let field_id = c.id.to_ident();
669        let field_name = &c.id;
670        let packet_name = id;
671        let value = constraint_value(&parent_data_fields, c);
672        let value_str = constraint_value_str(&parent_data_fields, c);
673        quote! {
674            if parent.#field_id() != #value {
675                return Err(DecodeError::InvalidFieldValue {
676                    packet: #packet_name,
677                    field: #field_name,
678                    expected: #value_str,
679                    actual: format!("{:?}", parent.#field_id()),
680                })
681            }
682        }
683    });
684
685    let decode_partial = if parent_decl.payload().is_some() {
686        // Generate an implementation of decode_partial that will decode
687        // data fields present in the parent payload.
688        // TODO(henrichataing) add constraint validation to decode_partial,
689        // return DecodeError::InvalidConstraint.
690        quote! {
691            fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
692                let mut buf: &[u8] = &parent.payload;
693                #( #constraint_checks )*
694                #field_parser
695                if buf.is_empty() {
696                    Ok(Self {
697                        #( #parsed_field_ids, )*
698                        #( #copied_field_ids: parent.#copied_field_ids, )*
699                        #( #cloned_field_ids: parent.#cloned_field_ids.clone(), )*
700                    })
701                } else {
702                    Err(DecodeError::TrailingBytes)
703                }
704            }
705        }
706    } else {
707        // Generate an implementation of decode_partial that will only copy
708        // data fields present in the parent.
709        // TODO(henrichataing) add constraint validation to decode_partial,
710        // return DecodeError::InvalidConstraint.
711        quote! {
712            fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
713                #( #constraint_checks )*
714                Ok(Self {
715                    #( #copied_field_ids: parent.#copied_field_ids, )*
716                })
717            }
718        }
719    };
720
721    let decode =
722        // The implementation of decode for derived packets relies on
723        // the parent packet parser.
724        quote! {
725            fn decode(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
726                let (parent, trailing_bytes) = #parent_name::decode(buf)?;
727                let packet = Self::decode_partial(&parent)?;
728                Ok((packet, trailing_bytes))
729            }
730        };
731
732    // Provide the implementation of conversion helpers from
733    // the current packet to its parent packets. The implementation
734    // is explicit for the immediate parent, and derived using other
735    // Into<> implementations for the ancestors.
736    let into_parent = {
737        let parent_data_field_ids = parent_data_fields.iter().map(|f| f.id().unwrap().to_ident());
738        let parent_data_field_values = parent_data_fields.iter().map(|f| {
739            let id = f.id().unwrap().to_ident();
740            match all_constraints.get(f.id().unwrap()) {
741                Some(c) => constraint_value(&parent_data_fields, c),
742                None => quote! { packet.#id },
743            }
744        });
745        if parent_decl.payload().is_some() {
746            quote! {
747                impl TryFrom<&#name> for #parent_name {
748                    type Error = EncodeError;
749                    fn try_from(packet: &#name) -> Result<#parent_name, Self::Error> {
750                        let mut payload = Vec::new();
751                        packet.encode_partial(&mut payload)?;
752                        Ok(#parent_name {
753                            #( #parent_data_field_ids: #parent_data_field_values, )*
754                            payload,
755                        })
756                    }
757                }
758
759                impl TryFrom<#name> for #parent_name {
760                    type Error = EncodeError;
761                    fn try_from(packet: #name) -> Result<#parent_name, Self::Error> {
762                        (&packet).try_into()
763                    }
764                }
765            }
766        } else {
767            quote! {
768                impl From<&#name> for #parent_name {
769                    fn from(packet: &#name) -> #parent_name {
770                        #parent_name {
771                            #( #parent_data_field_ids: #parent_data_field_values, )*
772                        }
773                    }
774                }
775
776                impl From<#name> for #parent_name {
777                    fn from(packet: #name) -> #parent_name {
778                        (&packet).into()
779                    }
780                }
781            }
782        }
783    };
784
785    let into_ancestors = scope.iter_parents(parent_decl).map(|ancestor_decl| {
786        let ancestor_name = ancestor_decl.id().unwrap().to_ident();
787        quote! {
788            impl TryFrom<&#name> for #ancestor_name {
789                type Error = EncodeError;
790                fn try_from(packet: &#name) -> Result<#ancestor_name, Self::Error> {
791                    (&#parent_name::try_from(packet)?).try_into()
792                }
793            }
794
795            impl TryFrom<#name> for #ancestor_name {
796                type Error = EncodeError;
797                fn try_from(packet: #name) -> Result<#ancestor_name, Self::Error> {
798                    (&packet).try_into()
799                }
800            }
801        }
802    });
803
804    // Provide the implementation of conversion helper from
805    // the parent packet. This function is actually the parse
806    // implementation. This helper is provided only if the packet has a
807    // parent declaration.
808    let from_parent = quote! {
809        impl TryFrom<&#parent_name> for #name {
810            type Error = DecodeError;
811            fn try_from(parent: &#parent_name) -> Result<#name, Self::Error> {
812                #name::decode_partial(&parent)
813            }
814        }
815
816        impl TryFrom<#parent_name> for #name {
817            type Error = DecodeError;
818            fn try_from(parent: #parent_name) -> Result<#name, Self::Error> {
819                (&parent).try_into()
820            }
821        }
822    };
823
824    // Provide the implementation of conversion helpers from
825    // the ancestor packets.
826    let from_ancestors = scope.iter_parents(parent_decl).map(|ancestor_decl| {
827        let ancestor_name = ancestor_decl.id().unwrap().to_ident();
828        quote! {
829            impl TryFrom<&#ancestor_name> for #name {
830                type Error = DecodeError;
831                fn try_from(packet: &#ancestor_name) -> Result<#name, Self::Error> {
832                    (&#parent_name::try_from(packet)?).try_into()
833                }
834            }
835
836            impl TryFrom<#ancestor_name> for #name {
837                type Error = DecodeError;
838                fn try_from(packet: #ancestor_name) -> Result<#name, Self::Error> {
839                    (&packet).try_into()
840                }
841            }
842        }
843    });
844
845    // Provide the implementation of the enum listing child declarations of the
846    // current declaration. This enum is only provided for declarations that
847    // have child packets.
848    let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
849    let child_struct = (!children_decl.is_empty()).then(|| {
850        let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
851        quote! {
852            #[derive(Debug, Clone, PartialEq, Eq)]
853            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
854            pub enum #child_name {
855                #( #children_ids(#children_ids), )*
856                None,
857            }
858        }
859    });
860
861    // Provide the implementation of the specialization function.
862    // The specialization function is only provided for declarations that have
863    // child packets.
864    let specialize = (!children_decl.is_empty())
865        .then(|| generate_specialize_impl(scope, schema, decl, id, &data_fields).unwrap());
866
867    quote! {
868        #[derive(Debug, Clone, PartialEq, Eq)]
869        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
870        pub struct #name {
871            #( pub #data_field_ids: #data_field_types, )*
872            #payload_field
873        }
874
875        #into_parent
876        #from_parent
877        #( #into_ancestors )*
878        #( #from_ancestors )*
879
880        #child_struct
881
882        impl #name {
883            #specialize
884            #decode_partial
885            #encode_partial
886            #payload_accessor
887
888            #(
889            pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
890                #data_field_borrows self.#data_field_ids
891            }
892            )*
893
894            #(
895            pub fn #constant_field_ids(&self) -> #constant_field_types {
896                #constant_field_values
897            }
898            )*
899        }
900
901        impl Packet for #name {
902            #encoded_len
903            #encode
904            #decode
905        }
906    }
907}
908
909/// Generate an enum declaration.
910///
911/// # Arguments
912/// * `id` - Enum identifier.
913/// * `tags` - List of enum tags.
914/// * `width` - Width of the backing type of the enum, in bits.
915fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2::TokenStream {
916    // Determine if the enum is open, i.e. a default tag is defined.
917    fn enum_default_tag(tags: &[ast::Tag]) -> Option<ast::TagOther> {
918        tags.iter()
919            .filter_map(|tag| match tag {
920                ast::Tag::Other(tag) => Some(tag.clone()),
921                _ => None,
922            })
923            .next()
924    }
925
926    // Determine if the enum is complete, i.e. all values in the backing
927    // integer range have a matching tag in the original declaration.
928    fn enum_is_complete(tags: &[ast::Tag], max: usize) -> bool {
929        let mut ranges = tags
930            .iter()
931            .filter_map(|tag| match tag {
932                ast::Tag::Value(tag) => Some((tag.value, tag.value)),
933                ast::Tag::Range(tag) => Some(tag.range.clone().into_inner()),
934                _ => None,
935            })
936            .collect::<Vec<_>>();
937        ranges.sort_unstable();
938        ranges.first().unwrap().0 == 0
939            && ranges.last().unwrap().1 == max
940            && ranges.windows(2).all(|window| {
941                if let [left, right] = window {
942                    left.1 == right.0 - 1
943                } else {
944                    false
945                }
946            })
947    }
948
949    // Determine if the enum is primitive, i.e. does not contain any tag range.
950    fn enum_is_primitive(tags: &[ast::Tag]) -> bool {
951        tags.iter().all(|tag| matches!(tag, ast::Tag::Value(_)))
952    }
953
954    // Return the maximum value for the scalar type.
955    fn scalar_max(width: usize) -> usize {
956        if width >= usize::BITS as usize {
957            usize::MAX
958        } else {
959            (1 << width) - 1
960        }
961    }
962
963    // Format an enum tag identifier to rust upper caml case.
964    fn format_tag_ident(id: &str) -> proc_macro2::TokenStream {
965        let id = format_ident!("{}", id.to_upper_camel_case());
966        quote! { #id }
967    }
968
969    // Format a constant value as hexadecimal constant.
970    fn format_value(value: usize) -> LitInt {
971        syn::parse_str::<syn::LitInt>(&format!("{:#x}", value)).unwrap()
972    }
973
974    // Backing type for the enum.
975    let backing_type = types::Integer::new(width);
976    let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
977    let range_max = scalar_max(width);
978    let default_tag = enum_default_tag(tags);
979    let is_open = default_tag.is_some();
980    let is_complete = enum_is_complete(tags, scalar_max(width));
981    let is_primitive = enum_is_primitive(tags);
982    let name = id.to_ident();
983
984    // Generate the variant cases for the enum declaration.
985    // Tags declared in ranges are flattened in the same declaration.
986    let use_variant_values = is_primitive && (is_complete || !is_open);
987    let repr_u64 = use_variant_values.then(|| quote! { #[repr(u64)] });
988    let mut variants = vec![];
989    for tag in tags.iter() {
990        match tag {
991            ast::Tag::Value(tag) if use_variant_values => {
992                let id = format_tag_ident(&tag.id);
993                let value = format_value(tag.value);
994                variants.push(quote! { #id = #value })
995            }
996            ast::Tag::Value(tag) => variants.push(format_tag_ident(&tag.id)),
997            ast::Tag::Range(tag) => {
998                variants.extend(tag.tags.iter().map(|tag| format_tag_ident(&tag.id)));
999                let id = format_tag_ident(&tag.id);
1000                variants.push(quote! { #id(Private<#backing_type>) })
1001            }
1002            ast::Tag::Other(_) => (),
1003        }
1004    }
1005
1006    // Generate the cases for parsing the enum value from an integer.
1007    let mut from_cases = vec![];
1008    for tag in tags.iter() {
1009        match tag {
1010            ast::Tag::Value(tag) => {
1011                let id = format_tag_ident(&tag.id);
1012                let value = format_value(tag.value);
1013                from_cases.push(quote! { #value => Ok(#name::#id) })
1014            }
1015            ast::Tag::Range(tag) => {
1016                from_cases.extend(tag.tags.iter().map(|tag| {
1017                    let id = format_tag_ident(&tag.id);
1018                    let value = format_value(tag.value);
1019                    quote! { #value => Ok(#name::#id) }
1020                }));
1021                let id = format_tag_ident(&tag.id);
1022                let start = format_value(*tag.range.start());
1023                let end = format_value(*tag.range.end());
1024                from_cases.push(quote! { #start ..= #end => Ok(#name::#id(Private(value))) })
1025            }
1026            ast::Tag::Other(_) => (),
1027        }
1028    }
1029
1030    // Generate the cases for serializing the enum value to an integer.
1031    let mut into_cases = vec![];
1032    for tag in tags.iter() {
1033        match tag {
1034            ast::Tag::Value(tag) => {
1035                let id = format_tag_ident(&tag.id);
1036                let value = format_value(tag.value);
1037                into_cases.push(quote! { #name::#id => #value })
1038            }
1039            ast::Tag::Range(tag) => {
1040                into_cases.extend(tag.tags.iter().map(|tag| {
1041                    let id = format_tag_ident(&tag.id);
1042                    let value = format_value(tag.value);
1043                    quote! { #name::#id => #value }
1044                }));
1045                let id = format_tag_ident(&tag.id);
1046                into_cases.push(quote! { #name::#id(Private(value)) => *value })
1047            }
1048            ast::Tag::Other(_) => (),
1049        }
1050    }
1051
1052    // Generate a default case if the enum is open and incomplete.
1053    if !is_complete && is_open {
1054        let unknown_id = format_tag_ident(&default_tag.unwrap().id);
1055        let range_max = format_value(range_max);
1056        variants.push(quote! { #unknown_id(Private<#backing_type>) });
1057        from_cases.push(quote! { 0..=#range_max => Ok(#name::#unknown_id(Private(value))) });
1058        into_cases.push(quote! { #name::#unknown_id(Private(value)) => *value });
1059    }
1060
1061    // Generate an error case if the enum size is lower than the backing
1062    // type size, or if the enum is closed or incomplete.
1063    if backing_type.width != width || (!is_complete && !is_open) {
1064        from_cases.push(quote! { _ => Err(value) });
1065    }
1066
1067    // Derive other Into<uN> and Into<iN> implementations from the explicit
1068    // implementation, where the type is larger than the backing type.
1069    let derived_signed_into_types = [8, 16, 32, 64]
1070        .into_iter()
1071        .filter(|w| *w > width)
1072        .map(|w| syn::parse_str::<syn::Type>(&format!("i{}", w)).unwrap());
1073    let derived_unsigned_into_types = [8, 16, 32, 64]
1074        .into_iter()
1075        .filter(|w| *w >= width && *w != backing_type.width)
1076        .map(|w| syn::parse_str::<syn::Type>(&format!("u{}", w)).unwrap());
1077    let derived_into_types = derived_signed_into_types.chain(derived_unsigned_into_types);
1078
1079    quote! {
1080        #repr_u64
1081        #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
1082        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1083        #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
1084        pub enum #name {
1085            #(#variants,)*
1086        }
1087
1088        impl TryFrom<#backing_type> for #name {
1089            type Error = #backing_type;
1090            fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
1091                match value {
1092                    #(#from_cases,)*
1093                }
1094            }
1095        }
1096
1097        impl From<&#name> for #backing_type {
1098            fn from(value: &#name) -> Self {
1099                match value {
1100                    #(#into_cases,)*
1101                }
1102            }
1103        }
1104
1105        impl From<#name> for #backing_type {
1106            fn from(value: #name) -> Self {
1107                (&value).into()
1108            }
1109        }
1110
1111        #(impl From<#name> for #derived_into_types {
1112            fn from(value: #name) -> Self {
1113                #backing_type::from(value) as Self
1114            }
1115        })*
1116    }
1117}
1118
1119/// Generate the declaration for a custom field of static size.
1120///
1121/// * `id` - Enum identifier.
1122/// * `width` - Width of the backing type of the enum, in bits.
1123fn generate_custom_field_decl(
1124    endianness: ast::EndiannessValue,
1125    id: &str,
1126    width: usize,
1127) -> proc_macro2::TokenStream {
1128    let name = id;
1129    let id = id.to_ident();
1130    let backing_type = types::Integer::new(width);
1131    let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
1132    let max_value = mask_bits(width, &format!("u{}", backing_type.width));
1133    let size = proc_macro2::Literal::usize_unsuffixed(width / 8);
1134
1135    let read_value = types::get_uint(endianness, width, &format_ident!("buf"));
1136    let read_value = if [8, 16, 32, 64].contains(&width) {
1137        quote! { #read_value.into() }
1138    } else {
1139        // The value is masked when read, and the conversion must succeed.
1140        quote! { (#read_value).try_into().unwrap() }
1141    };
1142
1143    let write_value = types::put_uint(
1144        endianness,
1145        &quote! { #backing_type::from(self) },
1146        width,
1147        &format_ident!("buf"),
1148    );
1149
1150    let common = quote! {
1151        impl From<&#id> for #backing_type {
1152            fn from(value: &#id) -> #backing_type {
1153                value.0
1154            }
1155        }
1156
1157        impl From<#id> for #backing_type {
1158            fn from(value: #id) -> #backing_type {
1159                value.0
1160            }
1161        }
1162
1163        impl Packet for #id {
1164            fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
1165                if buf.len() < #size {
1166                    return Err(DecodeError::InvalidLengthError {
1167                        obj: #name,
1168                        wanted: #size,
1169                        got: buf.len(),
1170                    })
1171                }
1172
1173                Ok((#read_value, buf))
1174            }
1175
1176            fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
1177                #write_value;
1178                Ok(())
1179            }
1180
1181            fn encoded_len(&self) -> usize {
1182                #size
1183            }
1184        }
1185    };
1186
1187    if backing_type.width == width {
1188        quote! {
1189            #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
1190            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1191            #[cfg_attr(feature = "serde", serde(from = #backing_type_str, into = #backing_type_str))]
1192            pub struct #id(#backing_type);
1193
1194            #common
1195
1196            impl From<#backing_type> for #id {
1197                fn from(value: #backing_type) -> Self {
1198                    #id(value)
1199                }
1200            }
1201        }
1202    } else {
1203        quote! {
1204            #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
1205            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1206            #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
1207            pub struct #id(#backing_type);
1208
1209            #common
1210
1211            impl TryFrom<#backing_type> for #id {
1212                type Error = #backing_type;
1213                fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
1214                    if value > #max_value {
1215                        Err(value)
1216                    } else {
1217                        Ok(#id(value))
1218                    }
1219                }
1220            }
1221        }
1222    }
1223}
1224
1225fn generate_decl(
1226    scope: &analyzer::Scope<'_>,
1227    schema: &analyzer::Schema,
1228    file: &ast::File,
1229    decl: &ast::Decl,
1230) -> proc_macro2::TokenStream {
1231    match &decl.desc {
1232        ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
1233            match scope.get_parent(decl) {
1234                None => generate_root_packet_decl(scope, schema, file.endianness.value, id),
1235                Some(_) => generate_derived_packet_decl(scope, schema, file.endianness.value, id),
1236            }
1237        }
1238        ast::DeclDesc::Enum { id, tags, width } => generate_enum_decl(id, tags, *width),
1239        ast::DeclDesc::CustomField { id, width: Some(width), .. } => {
1240            generate_custom_field_decl(file.endianness.value, id, *width)
1241        }
1242        ast::DeclDesc::CustomField { .. } => {
1243            // No need to generate anything for a custom field,
1244            // we just assume it will be in scope.
1245            quote!()
1246        }
1247        _ => todo!("unsupported Decl::{:?}", decl),
1248    }
1249}
1250
1251/// Generate Rust code from an AST.
1252///
1253/// The code is not formatted, pipe it through `rustfmt` to get
1254/// readable source code.
1255pub fn generate_tokens(
1256    sources: &ast::SourceDatabase,
1257    file: &ast::File,
1258    custom_fields: &[String],
1259) -> proc_macro2::TokenStream {
1260    let source = sources.get(file.file).expect("could not read source");
1261    let preamble = preamble::generate(Path::new(source.name()));
1262    let scope = analyzer::Scope::new(file).expect("could not create scope");
1263    let schema = analyzer::Schema::new(file);
1264    let custom_fields = custom_fields.iter().map(|custom_field| {
1265        syn::parse_str::<syn::Path>(custom_field)
1266            .unwrap_or_else(|err| panic!("invalid path '{custom_field}': {err:?}"))
1267    });
1268    let decls = file.declarations.iter().map(|decl| generate_decl(&scope, &schema, file, decl));
1269    quote! {
1270        #preamble
1271        #(use #custom_fields;)*
1272
1273        #(#decls)*
1274    }
1275}
1276
1277/// Generate formatted Rust code from an AST.
1278///
1279/// The code is not formatted, pipe it through `rustfmt` to get
1280/// readable source code.
1281pub fn generate(
1282    sources: &ast::SourceDatabase,
1283    file: &ast::File,
1284    custom_fields: &[String],
1285) -> String {
1286    let syntax_tree =
1287        syn::parse2(generate_tokens(sources, file, custom_fields)).expect("Could not parse code");
1288    prettyplease::unparse(&syntax_tree)
1289}
1290
1291#[cfg(test)]
1292mod tests {
1293    use super::*;
1294    use crate::analyzer;
1295    use crate::ast;
1296    use crate::parser::parse_inline;
1297    use crate::test_utils::{assert_snapshot_eq, format_rust};
1298    use paste::paste;
1299
1300    /// Create a unit test for the given PDL `code`.
1301    ///
1302    /// The unit test will compare the generated Rust code for all
1303    /// declarations with previously saved snapshots. The snapshots
1304    /// are read from `"tests/generated/{name}_{endianness}_{id}.rs"`
1305    /// where `is` taken from the declaration.
1306    ///
1307    /// When adding new tests or modifying existing ones, use
1308    /// `UPDATE_SNAPSHOTS=1 cargo test` to automatically populate the
1309    /// snapshots with the expected output.
1310    ///
1311    /// The `code` cannot have an endianness declaration, instead you
1312    /// must supply either `little_endian` or `big_endian` as
1313    /// `endianness`.
1314    macro_rules! make_pdl_test {
1315        ($name:ident, $code:expr, $endianness:ident) => {
1316            paste! {
1317                #[test]
1318                fn [< test_ $name _ $endianness >]() {
1319                    let name = stringify!($name);
1320                    let endianness = stringify!($endianness);
1321                    let code = format!("{endianness}_packets\n{}", $code);
1322                    let mut db = ast::SourceDatabase::new();
1323                    let file = parse_inline(&mut db, "test", code).unwrap();
1324                    let file = analyzer::analyze(&file).unwrap();
1325                    let actual_code = generate(&db, &file, &[]);
1326                    assert_snapshot_eq(
1327                        &format!("tests/generated/rust/{name}_{endianness}.rs"),
1328                        &format_rust(&actual_code),
1329                    );
1330                }
1331            }
1332        };
1333    }
1334
1335    /// Create little- and bit-endian tests for the given PDL `code`.
1336    ///
1337    /// The `code` cannot have an endianness declaration: we will
1338    /// automatically generate unit tests for both
1339    /// "little_endian_packets" and "big_endian_packets".
1340    macro_rules! test_pdl {
1341        ($name:ident, $code:expr $(,)?) => {
1342            make_pdl_test!($name, $code, little_endian);
1343            make_pdl_test!($name, $code, big_endian);
1344        };
1345    }
1346
1347    test_pdl!(packet_decl_empty, "packet Foo {}");
1348
1349    test_pdl!(packet_decl_8bit_scalar, " packet Foo { x:  8 }");
1350    test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
1351    test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");
1352
1353    test_pdl!(
1354        enum_declaration,
1355        r#"
1356        enum IncompleteTruncatedClosed : 3 {
1357            A = 0,
1358            B = 1,
1359        }
1360
1361        enum IncompleteTruncatedOpen : 3 {
1362            A = 0,
1363            B = 1,
1364            UNKNOWN = ..
1365        }
1366
1367        enum IncompleteTruncatedClosedWithRange : 3 {
1368            A = 0,
1369            B = 1..6 {
1370                X = 1,
1371                Y = 2,
1372            }
1373        }
1374
1375        enum IncompleteTruncatedOpenWithRange : 3 {
1376            A = 0,
1377            B = 1..6 {
1378                X = 1,
1379                Y = 2,
1380            },
1381            UNKNOWN = ..
1382        }
1383
1384        enum CompleteTruncated : 3 {
1385            A = 0,
1386            B = 1,
1387            C = 2,
1388            D = 3,
1389            E = 4,
1390            F = 5,
1391            G = 6,
1392            H = 7,
1393        }
1394
1395        enum CompleteTruncatedWithRange : 3 {
1396            A = 0,
1397            B = 1..7 {
1398                X = 1,
1399                Y = 2,
1400            }
1401        }
1402
1403        enum CompleteWithRange : 8 {
1404            A = 0,
1405            B = 1,
1406            C = 2..255,
1407        }
1408        "#
1409    );
1410
1411    test_pdl!(
1412        custom_field_declaration,
1413        r#"
1414        // Still unsupported.
1415        // custom_field Dynamic "dynamic"
1416
1417        // Should generate a type with From<u32> implementation.
1418        custom_field ExactSize : 32 "exact_size"
1419
1420        // Should generate a type with TryFrom<u32> implementation.
1421        custom_field TruncatedSize : 24 "truncated_size"
1422        "#
1423    );
1424
1425    test_pdl!(
1426        packet_decl_simple_scalars,
1427        r#"
1428          packet Foo {
1429            x: 8,
1430            y: 16,
1431            z: 24,
1432          }
1433        "#
1434    );
1435
1436    test_pdl!(
1437        packet_decl_complex_scalars,
1438        r#"
1439          packet Foo {
1440            a: 3,
1441            b: 8,
1442            c: 5,
1443            d: 24,
1444            e: 12,
1445            f: 4,
1446          }
1447        "#,
1448    );
1449
1450    // Test that we correctly mask a byte-sized value in the middle of
1451    // a chunk.
1452    test_pdl!(
1453        packet_decl_mask_scalar_value,
1454        r#"
1455          packet Foo {
1456            a: 2,
1457            b: 24,
1458            c: 6,
1459          }
1460        "#,
1461    );
1462
1463    test_pdl!(
1464        struct_decl_complex_scalars,
1465        r#"
1466          struct Foo {
1467            a: 3,
1468            b: 8,
1469            c: 5,
1470            d: 24,
1471            e: 12,
1472            f: 4,
1473          }
1474        "#,
1475    );
1476
1477    test_pdl!(packet_decl_8bit_enum, " enum Foo :  8 { A = 1, B = 2 } packet Bar { x: Foo }");
1478    test_pdl!(packet_decl_24bit_enum, "enum Foo : 24 { A = 1, B = 2 } packet Bar { x: Foo }");
1479    test_pdl!(packet_decl_64bit_enum, "enum Foo : 64 { A = 1, B = 2 } packet Bar { x: Foo }");
1480
1481    test_pdl!(
1482        packet_decl_mixed_scalars_enums,
1483        "
1484          enum Enum7 : 7 {
1485            A = 1,
1486            B = 2,
1487          }
1488
1489          enum Enum9 : 9 {
1490            A = 1,
1491            B = 2,
1492          }
1493
1494          packet Foo {
1495            x: Enum7,
1496            y: 5,
1497            z: Enum9,
1498            w: 3,
1499          }
1500        "
1501    );
1502
1503    test_pdl!(packet_decl_8bit_scalar_array, " packet Foo { x:  8[3] }");
1504    test_pdl!(packet_decl_24bit_scalar_array, "packet Foo { x: 24[5] }");
1505    test_pdl!(packet_decl_64bit_scalar_array, "packet Foo { x: 64[7] }");
1506
1507    test_pdl!(
1508        packet_decl_8bit_enum_array,
1509        "enum Foo :  8 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[3] }"
1510    );
1511    test_pdl!(
1512        packet_decl_24bit_enum_array,
1513        "enum Foo : 24 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[5] }"
1514    );
1515    test_pdl!(
1516        packet_decl_64bit_enum_array,
1517        "enum Foo : 64 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[7] }"
1518    );
1519
1520    test_pdl!(
1521        packet_decl_array_dynamic_count,
1522        "
1523          packet Foo {
1524            _count_(x): 5,
1525            padding: 3,
1526            x: 24[]
1527          }
1528        "
1529    );
1530
1531    test_pdl!(
1532        packet_decl_array_dynamic_size,
1533        "
1534          packet Foo {
1535            _size_(x): 5,
1536            padding: 3,
1537            x: 24[]
1538          }
1539        "
1540    );
1541
1542    test_pdl!(
1543        packet_decl_array_unknown_element_width_dynamic_size,
1544        "
1545          struct Foo {
1546            _count_(a): 40,
1547            a: 16[],
1548          }
1549
1550          packet Bar {
1551            _size_(x): 40,
1552            x: Foo[],
1553          }
1554        "
1555    );
1556
1557    test_pdl!(
1558        packet_decl_array_unknown_element_width_dynamic_count,
1559        "
1560          struct Foo {
1561            _count_(a): 40,
1562            a: 16[],
1563          }
1564
1565          packet Bar {
1566            _count_(x): 40,
1567            x: Foo[],
1568          }
1569        "
1570    );
1571
1572    test_pdl!(
1573        packet_decl_array_with_padding,
1574        "
1575          struct Foo {
1576            _count_(a): 40,
1577            a: 16[],
1578          }
1579
1580          packet Bar {
1581            a: Foo[],
1582            _padding_ [128],
1583          }
1584        "
1585    );
1586
1587    test_pdl!(
1588        packet_decl_array_dynamic_element_size,
1589        "
1590          struct Foo {
1591            inner: 8[]
1592          }
1593          packet Bar {
1594            _elementsize_(x): 5,
1595            padding: 3,
1596            x: Foo[]
1597          }
1598        "
1599    );
1600
1601    test_pdl!(
1602        packet_decl_array_dynamic_element_size_dynamic_size,
1603        "
1604          struct Foo {
1605            inner: 8[]
1606          }
1607          packet Bar {
1608            _size_(x): 4,
1609            _elementsize_(x): 4,
1610            x: Foo[]
1611          }
1612        "
1613    );
1614
1615    test_pdl!(
1616        packet_decl_array_dynamic_element_size_dynamic_count,
1617        "
1618          struct Foo {
1619            inner: 8[]
1620          }
1621          packet Bar {
1622            _count_(x): 4,
1623            _elementsize_(x): 4,
1624            x: Foo[]
1625          }
1626        "
1627    );
1628
1629    test_pdl!(
1630        packet_decl_array_dynamic_element_size_static_count,
1631        "
1632          struct Foo {
1633            inner: 8[]
1634          }
1635          packet Bar {
1636            _elementsize_(x): 5,
1637            padding: 3,
1638            x: Foo[4]
1639          }
1640        "
1641    );
1642
1643    test_pdl!(
1644        packet_decl_array_dynamic_element_size_static_count_1,
1645        "
1646          struct Foo {
1647            inner: 8[]
1648          }
1649          packet Bar {
1650            _elementsize_(x): 5,
1651            padding: 3,
1652            x: Foo[1]
1653          }
1654        "
1655    );
1656
1657    test_pdl!(
1658        packet_decl_reserved_field,
1659        "
1660          packet Foo {
1661            _reserved_: 40,
1662          }
1663        "
1664    );
1665
1666    test_pdl!(
1667        packet_decl_custom_field,
1668        r#"
1669          custom_field Bar1 : 24 "exact"
1670          custom_field Bar2 : 32 "truncated"
1671
1672          packet Foo {
1673            a: Bar1,
1674            b: Bar2,
1675          }
1676        "#
1677    );
1678
1679    test_pdl!(
1680        packet_decl_fixed_scalar_field,
1681        "
1682          packet Foo {
1683            _fixed_ = 7 : 7,
1684            b: 57,
1685          }
1686        "
1687    );
1688
1689    test_pdl!(
1690        packet_decl_fixed_enum_field,
1691        "
1692          enum Enum7 : 7 {
1693            A = 1,
1694            B = 2,
1695          }
1696
1697          packet Foo {
1698              _fixed_ = A : Enum7,
1699              b: 57,
1700          }
1701        "
1702    );
1703
1704    test_pdl!(
1705        packet_decl_payload_field_variable_size,
1706        "
1707          packet Foo {
1708              a: 8,
1709              _size_(_payload_): 8,
1710              _payload_,
1711              b: 16,
1712          }
1713        "
1714    );
1715
1716    test_pdl!(
1717        packet_decl_payload_field_unknown_size,
1718        "
1719          packet Foo {
1720              a: 24,
1721              _payload_,
1722          }
1723        "
1724    );
1725
1726    test_pdl!(
1727        packet_decl_payload_field_unknown_size_terminal,
1728        "
1729          packet Foo {
1730              _payload_,
1731              a: 24,
1732          }
1733        "
1734    );
1735
1736    test_pdl!(
1737        packet_decl_child_packets,
1738        "
1739          enum Enum16 : 16 {
1740            A = 1,
1741            B = 2,
1742          }
1743
1744          packet Foo {
1745              a: 8,
1746              b: Enum16,
1747              _size_(_payload_): 8,
1748              _payload_
1749          }
1750
1751          packet Bar : Foo (a = 100) {
1752              x: 8,
1753          }
1754
1755          packet Baz : Foo (b = B) {
1756              y: 16,
1757          }
1758        "
1759    );
1760
1761    test_pdl!(
1762        packet_decl_grand_children,
1763        "
1764          enum Enum16 : 16 {
1765            A = 1,
1766            B = 2,
1767          }
1768
1769          packet Parent {
1770              foo: Enum16,
1771              bar: Enum16,
1772              baz: Enum16,
1773              _size_(_payload_): 8,
1774              _payload_
1775          }
1776
1777          packet Child : Parent (foo = A) {
1778              quux: Enum16,
1779              _payload_,
1780          }
1781
1782          packet GrandChild : Child (bar = A, quux = A) {
1783              _body_,
1784          }
1785
1786          packet GrandGrandChild : GrandChild (baz = A) {
1787              _body_,
1788          }
1789        "
1790    );
1791
1792    test_pdl!(
1793        packet_decl_parent_with_no_payload,
1794        "
1795          enum Enum8 : 8 {
1796            A = 0,
1797          }
1798
1799          packet Parent {
1800            v : Enum8,
1801          }
1802
1803          packet Child : Parent (v = A) {
1804          }
1805        "
1806    );
1807
1808    test_pdl!(
1809        packet_decl_parent_with_alias_child,
1810        "
1811          enum Enum8 : 8 {
1812            A = 0,
1813            B = 1,
1814            C = 2,
1815          }
1816
1817          packet Parent {
1818            v : Enum8,
1819            _payload_,
1820          }
1821
1822          packet AliasChild : Parent {
1823            _payload_
1824          }
1825
1826          packet NormalChild : Parent (v = A) {
1827          }
1828
1829          packet NormalGrandChild1 : AliasChild (v = B) {
1830          }
1831
1832          packet NormalGrandChild2 : AliasChild (v = C) {
1833              _payload_
1834          }
1835        "
1836    );
1837
1838    test_pdl!(
1839        reserved_identifier,
1840        "
1841          packet Test {
1842            type: 8,
1843          }
1844        "
1845    );
1846
1847    test_pdl!(
1848        payload_with_size_modifier,
1849        "
1850        packet Test {
1851            _size_(_payload_): 8,
1852            _payload_ : [+1],
1853        }
1854        "
1855    );
1856
1857    test_pdl!(
1858        struct_decl_child_structs,
1859        "
1860          enum Enum16 : 16 {
1861            A = 1,
1862            B = 2,
1863          }
1864
1865          struct Foo {
1866              a: 8,
1867              b: Enum16,
1868              _size_(_payload_): 8,
1869              _payload_
1870          }
1871
1872          struct Bar : Foo (a = 100) {
1873              x: 8,
1874          }
1875
1876          struct Baz : Foo (b = B) {
1877              y: 16,
1878          }
1879        "
1880    );
1881
1882    test_pdl!(
1883        struct_decl_grand_children,
1884        "
1885          enum Enum16 : 16 {
1886            A = 1,
1887            B = 2,
1888          }
1889
1890          struct Parent {
1891              foo: Enum16,
1892              bar: Enum16,
1893              baz: Enum16,
1894              _size_(_payload_): 8,
1895              _payload_
1896          }
1897
1898          struct Child : Parent (foo = A) {
1899              quux: Enum16,
1900              _payload_,
1901          }
1902
1903          struct GrandChild : Child (bar = A, quux = A) {
1904              _body_,
1905          }
1906
1907          struct GrandGrandChild : GrandChild (baz = A) {
1908              _body_,
1909          }
1910        "
1911    );
1912}