1use crate::{analyzer, ast};
18use quote::{format_ident, quote};
19use std::collections::BTreeSet;
20use std::collections::HashMap;
21use std::path::Path;
22use syn::LitInt;
23
24mod decoder;
25mod encoder;
26mod preamble;
27pub mod test;
28mod types;
29
30use decoder::FieldParser;
31pub use heck::ToUpperCamelCase;
32
33pub trait ToIdent {
34 fn to_ident(self) -> proc_macro2::Ident;
37}
38
39impl ToIdent for &'_ str {
40 fn to_ident(self) -> proc_macro2::Ident {
41 match self {
42 "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
43 | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod"
44 | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self" | "static" | "struct"
45 | "super" | "trait" | "true" | "type" | "unsafe" | "use" | "where" | "while"
46 | "async" | "await" | "dyn" | "abstract" | "become" | "box" | "do" | "final"
47 | "macro" | "override" | "priv" | "typeof" | "unsized" | "virtual" | "yield"
48 | "try" => format_ident!("r#{}", self),
49 _ => format_ident!("{}", self),
50 }
51 }
52}
53
54pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
61 let suffix = if n > 31 { format!("_{suffix}") } else { String::new() };
62 let hex_digits = format!("{:x}", (1u64 << n) - 1)
64 .as_bytes()
65 .rchunks(4)
66 .rev()
67 .map(|chunk| std::str::from_utf8(chunk).unwrap())
68 .collect::<Vec<&str>>()
69 .join("_");
70 syn::parse_str::<syn::LitInt>(&format!("0x{hex_digits}{suffix}")).unwrap()
71}
72
73fn packet_data_fields<'a>(
84 scope: &'a analyzer::Scope<'a>,
85 decl: &'a ast::Decl,
86) -> Vec<&'a ast::Field> {
87 let all_constraints = HashMap::<String, _>::from_iter(
88 scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
89 );
90
91 scope
92 .iter_fields(decl)
93 .filter(|f| f.id().is_some())
94 .filter(|f| !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
95 .filter(|f| !all_constraints.contains_key(f.id().unwrap()))
96 .collect::<Vec<_>>()
97}
98
99fn packet_constant_fields<'a>(
103 scope: &'a analyzer::Scope<'a>,
104 decl: &'a ast::Decl,
105) -> Vec<&'a ast::Field> {
106 let all_constraints = HashMap::<String, _>::from_iter(
107 scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
108 );
109
110 scope
111 .iter_fields(decl)
112 .filter(|f| f.id().is_some())
113 .filter(|f| all_constraints.contains_key(f.id().unwrap()))
114 .collect::<Vec<_>>()
115}
116
117fn constraint_value(
118 fields: &[&'_ ast::Field],
119 constraint: &ast::Constraint,
120) -> proc_macro2::TokenStream {
121 match constraint {
122 ast::Constraint { value: Some(value), .. } => {
123 let value = proc_macro2::Literal::usize_unsuffixed(*value);
124 quote!(#value)
125 }
126 ast::Constraint { tag_id: Some(tag_id), .. } => {
129 let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
130 let type_id = fields
131 .iter()
132 .filter_map(|f| match &f.desc {
133 ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
134 Some(type_id.to_ident())
135 }
136 _ => None,
137 })
138 .next()
139 .unwrap();
140 quote!(#type_id::#tag_id)
141 }
142 _ => unreachable!("Invalid constraint: {constraint:?}"),
143 }
144}
145
146fn constraint_value_str(fields: &[&'_ ast::Field], constraint: &ast::Constraint) -> String {
147 match constraint {
148 ast::Constraint { value: Some(value), .. } => {
149 format!("{}", value)
150 }
151 ast::Constraint { tag_id: Some(tag_id), .. } => {
152 let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
153 let type_id = fields
154 .iter()
155 .filter_map(|f| match &f.desc {
156 ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
157 Some(type_id.to_ident())
158 }
159 _ => None,
160 })
161 .next()
162 .unwrap();
163 format!("{}::{}", type_id, tag_id)
164 }
165 _ => unreachable!("Invalid constraint: {constraint:?}"),
166 }
167}
168
169fn implements_copy(scope: &analyzer::Scope<'_>, field: &ast::Field) -> bool {
170 match &field.desc {
171 ast::FieldDesc::Scalar { .. } => true,
172 ast::FieldDesc::Typedef { type_id, .. } => match &scope.typedef[type_id].desc {
173 ast::DeclDesc::Enum { .. } | ast::DeclDesc::CustomField { .. } => true,
174 ast::DeclDesc::Struct { .. } => false,
175 desc => unreachable!("unexpected declaration: {desc:?}"),
176 },
177 ast::FieldDesc::Array { .. } => false,
178 _ => todo!(),
179 }
180}
181
/// Generate the declaration of a packet or struct which has no parent.
///
/// Emits:
/// - the struct holding every data field, plus `payload` when the
///   declaration has one;
/// - accessors for those fields;
/// - the `Packet` implementation (`encoded_len`, `encode`, `decode`);
/// - when the declaration has children, a `<Id>Child` enum and a
///   `specialize` method that selects the child from its constraints.
fn generate_root_packet_decl(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let decl = scope.typedef[id];
    let name = id.to_ident();
    let child_name = format_ident!("{id}Child");

    // Fields stored in the generated struct: named, non-flag, unconstrained.
    let data_fields = packet_data_fields(scope, decl);
    let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
    let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    // Accessors return `Copy` fields by value and everything else by reference.
    let data_field_borrows = data_fields
        .iter()
        .map(|f| {
            if implements_copy(scope, f) {
                quote! {}
            } else {
                quote! { & }
            }
        })
        .collect::<Vec<_>>();
    let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
    let payload_accessor =
        decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });

    // Parser for every field declared by this declaration, reading from `buf`.
    let parser_span = format_ident!("buf");
    let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
    for field in decl.fields() {
        field_parser.add(field);
    }

    // Local variables bound by the field parser which initialize `Self`
    // (shorthand field init syntax) in `decode`.
    let mut parsed_field_ids = vec![];
    if decl.payload().is_some() {
        parsed_field_ids.push(format_ident!("payload"));
    }
    for f in &data_fields {
        let id = f.id().unwrap().to_ident();
        parsed_field_ids.push(id);
    }

    let (encode_fields, encoded_len) =
        encoder::encode(scope, schema, endianness, "buf".to_ident(), decl);

    let encode = quote! {
        fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
            #encode_fields
            Ok(())
        }
    };

    let encoded_len = quote! {
        fn encoded_len(&self) -> usize {
            #encoded_len
        }
    };

    // Returns the decoded packet together with the bytes left after it.
    let decode = quote! {
        fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
            #field_parser
            Ok((Self { #( #parsed_field_ids, )* }, buf))
        }
    };

    let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
    // Enum with one variant per child declaration, plus `None`.
    let child_struct = (!children_decl.is_empty()).then(|| {
        let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
        quote! {
            #[derive(Debug, Clone, PartialEq, Eq)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #child_name {
                #( #children_ids(#children_ids), )*
                None,
            }
        }
    });

    let specialize = (!children_decl.is_empty()).then(|| {
        // Every field constrained by at least one child, in sorted order.
        // These form the tuple that `specialize` matches on.
        let constraint_fields = children_decl
            .iter()
            .flat_map(|decl| decl.constraints().map(|c| c.id.to_owned()))
            .collect::<BTreeSet<_>>();
        let constraint_ids = constraint_fields.iter().map(|id| id.to_ident());
        let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());

        // One tuple pattern per child: the constrained value for each field
        // the child constrains, `_` for the others.
        let case_values = children_decl.iter().map(|child_decl| {
            let constraint_values = constraint_fields.iter().map(|id| {
                let constraint = child_decl.constraints().find(|c| &c.id == id);
                match constraint {
                    Some(constraint) => constraint_value(&data_fields, constraint),
                    None => quote! { _ },
                }
            });
            quote! { (#( #constraint_values, )*) }
        });

        // Values matching no child specialize to `None`.
        let default_case = quote! { _ => #child_name::None, };

        quote! {
            pub fn specialize(&self) -> Result<#child_name, DecodeError> {
                Ok(
                    match (#( self.#constraint_ids, )*) {
                        #( #case_values =>
                            #child_name::#children_ids(self.try_into()?), )*
                        #default_case
                    }
                )
            }
        }
    });

    quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #name {
            #( pub #data_field_ids: #data_field_types, )*
            #payload_field
        }

        #child_struct

        impl #name {
            #specialize
            #payload_accessor

            #(
            pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
                #data_field_borrows self.#data_field_ids
            }
            )*
        }

        impl Packet for #name {
            #encoded_len
            #encode
            #decode
        }
    }
}
350
351fn generate_derived_packet_decl(
357 scope: &analyzer::Scope<'_>,
358 schema: &analyzer::Schema,
359 endianness: ast::EndiannessValue,
360 id: &str,
361) -> proc_macro2::TokenStream {
362 let decl = scope.typedef[id];
363 let name = id.to_ident();
364 let parent_decl = scope.get_parent(decl).unwrap();
365 let parent_name = parent_decl.id().unwrap().to_ident();
366 let child_name = format_ident!("{id}Child");
367
368 let all_constraints = HashMap::<String, _>::from_iter(
370 scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
371 );
372
373 let all_fields = scope.iter_fields(decl).collect::<Vec<_>>();
374
375 let data_fields = packet_data_fields(scope, decl);
379 let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
380 let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
381 let data_field_borrows = data_fields
382 .iter()
383 .map(|f| {
384 if implements_copy(scope, f) {
385 quote! {}
386 } else {
387 quote! { & }
388 }
389 })
390 .collect::<Vec<_>>();
391 let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
392 let payload_accessor =
393 decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });
394
395 let parent_data_fields = packet_data_fields(scope, parent_decl);
396
397 let constant_fields = packet_constant_fields(scope, decl);
399 let constant_field_ids =
400 constant_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
401 let constant_field_types =
402 constant_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
403 let constant_field_values = constant_fields.iter().map(|f| {
404 let c = all_constraints.get(f.id().unwrap()).unwrap();
405 constraint_value(&all_fields, c)
406 });
407
408 let parser_span = format_ident!("buf");
410 let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
411 for field in decl.fields() {
412 field_parser.add(field);
413 }
414
415 let mut parsed_field_ids = vec![];
419 let mut copied_field_ids = vec![];
420 let mut cloned_field_ids = vec![];
421 if decl.payload().is_some() {
422 parsed_field_ids.push(format_ident!("payload"));
423 }
424 for f in &data_fields {
425 let id = f.id().unwrap().to_ident();
426 if decl.fields().any(|ff| f.id() == ff.id()) {
427 parsed_field_ids.push(id);
428 } else if implements_copy(scope, f) {
429 copied_field_ids.push(id);
430 } else {
431 cloned_field_ids.push(id);
432 }
433 }
434
435 let (partial_field_serializer, field_serializer, encoded_len) =
436 encoder::encode_partial(scope, schema, endianness, "buf".to_ident(), decl);
437
438 let encode_partial = quote! {
439 pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
440 #partial_field_serializer
441 Ok(())
442 }
443 };
444
445 let encode = quote! {
446 fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
447 #field_serializer
448 Ok(())
449 }
450 };
451
452 let encoded_len = quote! {
454 fn encoded_len(&self) -> usize {
455 #encoded_len
456 }
457 };
458
459 let constraint_checks = decl.constraints().map(|c| {
463 let field_id = c.id.to_ident();
464 let field_name = &c.id;
465 let packet_name = id;
466 let value = constraint_value(&parent_data_fields, c);
467 let value_str = constraint_value_str(&parent_data_fields, c);
468 quote! {
469 if parent.#field_id() != #value {
470 return Err(DecodeError::InvalidFieldValue {
471 packet: #packet_name,
472 field: #field_name,
473 expected: #value_str,
474 actual: format!("{:?}", parent.#field_id()),
475 })
476 }
477 }
478 });
479
480 let decode_partial = if parent_decl.payload().is_some() {
481 quote! {
486 fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
487 let mut buf: &[u8] = &parent.payload;
488 #( #constraint_checks )*
489 #field_parser
490 if buf.is_empty() {
491 Ok(Self {
492 #( #parsed_field_ids, )*
493 #( #copied_field_ids: parent.#copied_field_ids, )*
494 #( #cloned_field_ids: parent.#cloned_field_ids.clone(), )*
495 })
496 } else {
497 Err(DecodeError::TrailingBytes)
498 }
499 }
500 }
501 } else {
502 quote! {
507 fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
508 #( #constraint_checks )*
509 Ok(Self {
510 #( #copied_field_ids: parent.#copied_field_ids, )*
511 })
512 }
513 }
514 };
515
516 let decode =
517 quote! {
520 fn decode(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
521 let (parent, trailing_bytes) = #parent_name::decode(buf)?;
522 let packet = Self::decode_partial(&parent)?;
523 Ok((packet, trailing_bytes))
524 }
525 };
526
527 let into_parent = {
532 let parent_data_field_ids = parent_data_fields.iter().map(|f| f.id().unwrap().to_ident());
533 let parent_data_field_values = parent_data_fields.iter().map(|f| {
534 let id = f.id().unwrap().to_ident();
535 match all_constraints.get(f.id().unwrap()) {
536 Some(c) => constraint_value(&parent_data_fields, c),
537 None => quote! { packet.#id },
538 }
539 });
540 if parent_decl.payload().is_some() {
541 quote! {
542 impl TryFrom<&#name> for #parent_name {
543 type Error = EncodeError;
544 fn try_from(packet: &#name) -> Result<#parent_name, Self::Error> {
545 let mut payload = Vec::new();
546 packet.encode_partial(&mut payload)?;
547 Ok(#parent_name {
548 #( #parent_data_field_ids: #parent_data_field_values, )*
549 payload,
550 })
551 }
552 }
553
554 impl TryFrom<#name> for #parent_name {
555 type Error = EncodeError;
556 fn try_from(packet: #name) -> Result<#parent_name, Self::Error> {
557 (&packet).try_into()
558 }
559 }
560 }
561 } else {
562 quote! {
563 impl From<&#name> for #parent_name {
564 fn from(packet: &#name) -> #parent_name {
565 #parent_name {
566 #( #parent_data_field_ids: #parent_data_field_values, )*
567 }
568 }
569 }
570
571 impl From<#name> for #parent_name {
572 fn from(packet: #name) -> #parent_name {
573 (&packet).into()
574 }
575 }
576 }
577 }
578 };
579
580 let into_ancestors = scope.iter_parents(parent_decl).map(|ancestor_decl| {
581 let ancestor_name = ancestor_decl.id().unwrap().to_ident();
582 quote! {
583 impl TryFrom<&#name> for #ancestor_name {
584 type Error = EncodeError;
585 fn try_from(packet: &#name) -> Result<#ancestor_name, Self::Error> {
586 (&#parent_name::try_from(packet)?).try_into()
587 }
588 }
589
590 impl TryFrom<#name> for #ancestor_name {
591 type Error = EncodeError;
592 fn try_from(packet: #name) -> Result<#ancestor_name, Self::Error> {
593 (&packet).try_into()
594 }
595 }
596 }
597 });
598
599 let try_from_parent = quote! {
604 impl TryFrom<&#parent_name> for #name {
605 type Error = DecodeError;
606 fn try_from(parent: &#parent_name) -> Result<#name, Self::Error> {
607 #name::decode_partial(&parent)
608 }
609 }
610
611 impl TryFrom<#parent_name> for #name {
612 type Error = DecodeError;
613 fn try_from(parent: #parent_name) -> Result<#name, Self::Error> {
614 (&parent).try_into()
615 }
616 }
617 };
618
619 let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
623 let child_struct = (!children_decl.is_empty()).then(|| {
624 let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
625 quote! {
626 #[derive(Debug, Clone, PartialEq, Eq)]
627 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
628 pub enum #child_name {
629 #( #children_ids(#children_ids), )*
630 None,
631 }
632 }
633 });
634
635 let specialize = (!children_decl.is_empty()).then(|| {
639 let constraint_fields = children_decl
642 .iter()
643 .flat_map(|decl| decl.constraints().map(|c| c.id.to_owned()))
644 .collect::<BTreeSet<_>>();
645 let constraint_ids = constraint_fields.iter().map(|id| id.to_ident());
646 let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
647
648 let case_values = children_decl.iter().map(|child_decl| {
653 let constraint_values = constraint_fields.iter().map(|id| {
654 let constraint = child_decl.constraints().find(|c| &c.id == id);
655 match constraint {
656 Some(constraint) => constraint_value(&data_fields, constraint),
657 None => quote! { _ },
658 }
659 });
660 quote! { (#( #constraint_values, )*) }
661 });
662
663 let default_case = quote! { _ => #child_name::None, };
666
667 quote! {
668 pub fn specialize(&self) -> Result<#child_name, DecodeError> {
669 Ok(
670 match (#( self.#constraint_ids, )*) {
671 #( #case_values =>
672 #child_name::#children_ids(self.try_into()?), )*
673 #default_case
674 }
675 )
676 }
677 }
678 });
679
680 quote! {
681 #[derive(Debug, Clone, PartialEq, Eq)]
682 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
683 pub struct #name {
684 #( pub #data_field_ids: #data_field_types, )*
685 #payload_field
686 }
687
688 #try_from_parent
689 #into_parent
690 #( #into_ancestors )*
691
692 #child_struct
693
694 impl #name {
695 #specialize
696 #decode_partial
697 #encode_partial
698 #payload_accessor
699
700 #(
701 pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
702 #data_field_borrows self.#data_field_ids
703 }
704 )*
705
706 #(
707 pub fn #constant_field_ids(&self) -> #constant_field_types {
708 #constant_field_values
709 }
710 )*
711 }
712
713 impl Packet for #name {
714 #encoded_len
715 #encode
716 #decode
717 }
718 }
719}
720
721fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2::TokenStream {
728 fn enum_default_tag(tags: &[ast::Tag]) -> Option<ast::TagOther> {
730 tags.iter()
731 .filter_map(|tag| match tag {
732 ast::Tag::Other(tag) => Some(tag.clone()),
733 _ => None,
734 })
735 .next()
736 }
737
738 fn enum_is_complete(tags: &[ast::Tag], max: usize) -> bool {
741 let mut ranges = tags
742 .iter()
743 .filter_map(|tag| match tag {
744 ast::Tag::Value(tag) => Some((tag.value, tag.value)),
745 ast::Tag::Range(tag) => Some(tag.range.clone().into_inner()),
746 _ => None,
747 })
748 .collect::<Vec<_>>();
749 ranges.sort_unstable();
750 ranges.first().unwrap().0 == 0
751 && ranges.last().unwrap().1 == max
752 && ranges.windows(2).all(|window| {
753 if let [left, right] = window {
754 left.1 == right.0 - 1
755 } else {
756 false
757 }
758 })
759 }
760
761 fn enum_is_primitive(tags: &[ast::Tag]) -> bool {
763 tags.iter().all(|tag| matches!(tag, ast::Tag::Value(_)))
764 }
765
766 fn scalar_max(width: usize) -> usize {
768 if width >= usize::BITS as usize {
769 usize::MAX
770 } else {
771 (1 << width) - 1
772 }
773 }
774
775 fn format_tag_ident(id: &str) -> proc_macro2::TokenStream {
777 let id = format_ident!("{}", id.to_upper_camel_case());
778 quote! { #id }
779 }
780
781 fn format_value(value: usize) -> LitInt {
783 syn::parse_str::<syn::LitInt>(&format!("{:#x}", value)).unwrap()
784 }
785
786 let backing_type = types::Integer::new(width);
788 let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
789 let range_max = scalar_max(width);
790 let default_tag = enum_default_tag(tags);
791 let is_open = default_tag.is_some();
792 let is_complete = enum_is_complete(tags, scalar_max(width));
793 let is_primitive = enum_is_primitive(tags);
794 let name = id.to_ident();
795
796 let use_variant_values = is_primitive && (is_complete || !is_open);
799 let repr_u64 = use_variant_values.then(|| quote! { #[repr(u64)] });
800 let mut variants = vec![];
801 for tag in tags.iter() {
802 match tag {
803 ast::Tag::Value(tag) if use_variant_values => {
804 let id = format_tag_ident(&tag.id);
805 let value = format_value(tag.value);
806 variants.push(quote! { #id = #value })
807 }
808 ast::Tag::Value(tag) => variants.push(format_tag_ident(&tag.id)),
809 ast::Tag::Range(tag) => {
810 variants.extend(tag.tags.iter().map(|tag| format_tag_ident(&tag.id)));
811 let id = format_tag_ident(&tag.id);
812 variants.push(quote! { #id(Private<#backing_type>) })
813 }
814 ast::Tag::Other(_) => (),
815 }
816 }
817
818 let mut from_cases = vec![];
820 for tag in tags.iter() {
821 match tag {
822 ast::Tag::Value(tag) => {
823 let id = format_tag_ident(&tag.id);
824 let value = format_value(tag.value);
825 from_cases.push(quote! { #value => Ok(#name::#id) })
826 }
827 ast::Tag::Range(tag) => {
828 from_cases.extend(tag.tags.iter().map(|tag| {
829 let id = format_tag_ident(&tag.id);
830 let value = format_value(tag.value);
831 quote! { #value => Ok(#name::#id) }
832 }));
833 let id = format_tag_ident(&tag.id);
834 let start = format_value(*tag.range.start());
835 let end = format_value(*tag.range.end());
836 from_cases.push(quote! { #start ..= #end => Ok(#name::#id(Private(value))) })
837 }
838 ast::Tag::Other(_) => (),
839 }
840 }
841
842 let mut into_cases = vec![];
844 for tag in tags.iter() {
845 match tag {
846 ast::Tag::Value(tag) => {
847 let id = format_tag_ident(&tag.id);
848 let value = format_value(tag.value);
849 into_cases.push(quote! { #name::#id => #value })
850 }
851 ast::Tag::Range(tag) => {
852 into_cases.extend(tag.tags.iter().map(|tag| {
853 let id = format_tag_ident(&tag.id);
854 let value = format_value(tag.value);
855 quote! { #name::#id => #value }
856 }));
857 let id = format_tag_ident(&tag.id);
858 into_cases.push(quote! { #name::#id(Private(value)) => *value })
859 }
860 ast::Tag::Other(_) => (),
861 }
862 }
863
864 if !is_complete && is_open {
866 let unknown_id = format_tag_ident(&default_tag.unwrap().id);
867 let range_max = format_value(range_max);
868 variants.push(quote! { #unknown_id(Private<#backing_type>) });
869 from_cases.push(quote! { 0..=#range_max => Ok(#name::#unknown_id(Private(value))) });
870 into_cases.push(quote! { #name::#unknown_id(Private(value)) => *value });
871 }
872
873 if backing_type.width != width || (!is_complete && !is_open) {
876 from_cases.push(quote! { _ => Err(value) });
877 }
878
879 let derived_signed_into_types = [8, 16, 32, 64]
882 .into_iter()
883 .filter(|w| *w > width)
884 .map(|w| syn::parse_str::<syn::Type>(&format!("i{}", w)).unwrap());
885 let derived_unsigned_into_types = [8, 16, 32, 64]
886 .into_iter()
887 .filter(|w| *w >= width && *w != backing_type.width)
888 .map(|w| syn::parse_str::<syn::Type>(&format!("u{}", w)).unwrap());
889 let derived_into_types = derived_signed_into_types.chain(derived_unsigned_into_types);
890
891 quote! {
892 #repr_u64
893 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
894 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
895 #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
896 pub enum #name {
897 #(#variants,)*
898 }
899
900 impl TryFrom<#backing_type> for #name {
901 type Error = #backing_type;
902 fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
903 match value {
904 #(#from_cases,)*
905 }
906 }
907 }
908
909 impl From<&#name> for #backing_type {
910 fn from(value: &#name) -> Self {
911 match value {
912 #(#into_cases,)*
913 }
914 }
915 }
916
917 impl From<#name> for #backing_type {
918 fn from(value: #name) -> Self {
919 (&value).into()
920 }
921 }
922
923 #(impl From<#name> for #derived_into_types {
924 fn from(value: #name) -> Self {
925 #backing_type::from(value) as Self
926 }
927 })*
928 }
929}
930
931fn generate_custom_field_decl(
936 endianness: ast::EndiannessValue,
937 id: &str,
938 width: usize,
939) -> proc_macro2::TokenStream {
940 let name = id;
941 let id = id.to_ident();
942 let backing_type = types::Integer::new(width);
943 let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
944 let max_value = mask_bits(width, &format!("u{}", backing_type.width));
945 let size = proc_macro2::Literal::usize_unsuffixed(width / 8);
946
947 let read_value = types::get_uint(endianness, width, &format_ident!("buf"));
948 let read_value = if [8, 16, 32, 64].contains(&width) {
949 quote! { #read_value.into() }
950 } else {
951 quote! { (#read_value).try_into().unwrap() }
953 };
954
955 let write_value = types::put_uint(
956 endianness,
957 "e! { #backing_type::from(self) },
958 width,
959 &format_ident!("buf"),
960 );
961
962 let common = quote! {
963 impl From<&#id> for #backing_type {
964 fn from(value: &#id) -> #backing_type {
965 value.0
966 }
967 }
968
969 impl From<#id> for #backing_type {
970 fn from(value: #id) -> #backing_type {
971 value.0
972 }
973 }
974
975 impl Packet for #id {
976 fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
977 if buf.len() < #size {
978 return Err(DecodeError::InvalidLengthError {
979 obj: #name,
980 wanted: #size,
981 got: buf.len(),
982 })
983 }
984
985 Ok((#read_value, buf))
986 }
987
988 fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
989 #write_value;
990 Ok(())
991 }
992
993 fn encoded_len(&self) -> usize {
994 #size
995 }
996 }
997 };
998
999 if backing_type.width == width {
1000 quote! {
1001 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
1002 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1003 #[cfg_attr(feature = "serde", serde(from = #backing_type_str, into = #backing_type_str))]
1004 pub struct #id(#backing_type);
1005
1006 #common
1007
1008 impl From<#backing_type> for #id {
1009 fn from(value: #backing_type) -> Self {
1010 #id(value)
1011 }
1012 }
1013 }
1014 } else {
1015 quote! {
1016 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
1017 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
1018 #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
1019 pub struct #id(#backing_type);
1020
1021 #common
1022
1023 impl TryFrom<#backing_type> for #id {
1024 type Error = #backing_type;
1025 fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
1026 if value > #max_value {
1027 Err(value)
1028 } else {
1029 Ok(#id(value))
1030 }
1031 }
1032 }
1033 }
1034 }
1035}
1036
1037fn generate_decl(
1038 scope: &analyzer::Scope<'_>,
1039 schema: &analyzer::Schema,
1040 file: &ast::File,
1041 decl: &ast::Decl,
1042) -> proc_macro2::TokenStream {
1043 match &decl.desc {
1044 ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
1045 match scope.get_parent(decl) {
1046 None => generate_root_packet_decl(scope, schema, file.endianness.value, id),
1047 Some(_) => generate_derived_packet_decl(scope, schema, file.endianness.value, id),
1048 }
1049 }
1050 ast::DeclDesc::Enum { id, tags, width } => generate_enum_decl(id, tags, *width),
1051 ast::DeclDesc::CustomField { id, width: Some(width), .. } => {
1052 generate_custom_field_decl(file.endianness.value, id, *width)
1053 }
1054 ast::DeclDesc::CustomField { .. } => {
1055 quote!()
1058 }
1059 _ => todo!("unsupported Decl::{:?}", decl),
1060 }
1061}
1062
1063pub fn generate_tokens(
1068 sources: &ast::SourceDatabase,
1069 file: &ast::File,
1070 custom_fields: &[String],
1071) -> proc_macro2::TokenStream {
1072 let source = sources.get(file.file).expect("could not read source");
1073 let preamble = preamble::generate(Path::new(source.name()));
1074
1075 let scope = analyzer::Scope::new(file).expect("could not create scope");
1076 let schema = analyzer::Schema::new(file);
1077 let custom_fields = custom_fields.iter().map(|custom_field| {
1078 syn::parse_str::<syn::Path>(custom_field)
1079 .unwrap_or_else(|err| panic!("invalid path '{custom_field}': {err:?}"))
1080 });
1081 let decls = file.declarations.iter().map(|decl| generate_decl(&scope, &schema, file, decl));
1082 quote! {
1083 #preamble
1084 #(use #custom_fields;)*
1085
1086 #(#decls)*
1087 }
1088}
1089
1090pub fn generate(
1095 sources: &ast::SourceDatabase,
1096 file: &ast::File,
1097 custom_fields: &[String],
1098) -> String {
1099 let syntax_tree =
1100 syn::parse2(generate_tokens(sources, file, custom_fields)).expect("Could not parse code");
1101 prettyplease::unparse(&syntax_tree)
1102}
1103
1104#[cfg(test)]
1105mod tests {
1106 use super::*;
1107 use crate::analyzer;
1108 use crate::ast;
1109 use crate::parser::parse_inline;
1110 use crate::test_utils::{assert_snapshot_eq, format_rust};
1111 use paste::paste;
1112
1113 macro_rules! make_pdl_test {
1128 ($name:ident, $code:expr, $endianness:ident) => {
1129 paste! {
1130 #[test]
1131 fn [< test_ $name _ $endianness >]() {
1132 let name = stringify!($name);
1133 let endianness = stringify!($endianness);
1134 let code = format!("{endianness}_packets\n{}", $code);
1135 let mut db = ast::SourceDatabase::new();
1136 let file = parse_inline(&mut db, "test", code).unwrap();
1137 let file = analyzer::analyze(&file).unwrap();
1138 let actual_code = generate(&db, &file, &[]);
1139 assert_snapshot_eq(
1140 &format!("tests/generated/rust/{name}_{endianness}.rs"),
1141 &format_rust(&actual_code),
1142 );
1143 }
1144 }
1145 };
1146 }
1147
1148 macro_rules! test_pdl {
1154 ($name:ident, $code:expr $(,)?) => {
1155 make_pdl_test!($name, $code, little_endian);
1156 make_pdl_test!($name, $code, big_endian);
1157 };
1158 }
1159
1160 test_pdl!(packet_decl_empty, "packet Foo {}");
1161
1162 test_pdl!(packet_decl_8bit_scalar, " packet Foo { x: 8 }");
1163 test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
1164 test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");
1165
1166 test_pdl!(
1167 enum_declaration,
1168 r#"
1169 enum IncompleteTruncatedClosed : 3 {
1170 A = 0,
1171 B = 1,
1172 }
1173
1174 enum IncompleteTruncatedOpen : 3 {
1175 A = 0,
1176 B = 1,
1177 UNKNOWN = ..
1178 }
1179
1180 enum IncompleteTruncatedClosedWithRange : 3 {
1181 A = 0,
1182 B = 1..6 {
1183 X = 1,
1184 Y = 2,
1185 }
1186 }
1187
1188 enum IncompleteTruncatedOpenWithRange : 3 {
1189 A = 0,
1190 B = 1..6 {
1191 X = 1,
1192 Y = 2,
1193 },
1194 UNKNOWN = ..
1195 }
1196
1197 enum CompleteTruncated : 3 {
1198 A = 0,
1199 B = 1,
1200 C = 2,
1201 D = 3,
1202 E = 4,
1203 F = 5,
1204 G = 6,
1205 H = 7,
1206 }
1207
1208 enum CompleteTruncatedWithRange : 3 {
1209 A = 0,
1210 B = 1..7 {
1211 X = 1,
1212 Y = 2,
1213 }
1214 }
1215
1216 enum CompleteWithRange : 8 {
1217 A = 0,
1218 B = 1,
1219 C = 2..255,
1220 }
1221 "#
1222 );
1223
1224 test_pdl!(
1225 custom_field_declaration,
1226 r#"
1227 // Still unsupported.
1228 // custom_field Dynamic "dynamic"
1229
1230 // Should generate a type with From<u32> implementation.
1231 custom_field ExactSize : 32 "exact_size"
1232
1233 // Should generate a type with TryFrom<u32> implementation.
1234 custom_field TruncatedSize : 24 "truncated_size"
1235 "#
1236 );
1237
1238 test_pdl!(
1239 packet_decl_simple_scalars,
1240 r#"
1241 packet Foo {
1242 x: 8,
1243 y: 16,
1244 z: 24,
1245 }
1246 "#
1247 );
1248
1249 test_pdl!(
1250 packet_decl_complex_scalars,
1251 r#"
1252 packet Foo {
1253 a: 3,
1254 b: 8,
1255 c: 5,
1256 d: 24,
1257 e: 12,
1258 f: 4,
1259 }
1260 "#,
1261 );
1262
1263 test_pdl!(
1266 packet_decl_mask_scalar_value,
1267 r#"
1268 packet Foo {
1269 a: 2,
1270 b: 24,
1271 c: 6,
1272 }
1273 "#,
1274 );
1275
1276 test_pdl!(
1277 struct_decl_complex_scalars,
1278 r#"
1279 struct Foo {
1280 a: 3,
1281 b: 8,
1282 c: 5,
1283 d: 24,
1284 e: 12,
1285 f: 4,
1286 }
1287 "#,
1288 );
1289
1290 test_pdl!(packet_decl_8bit_enum, " enum Foo : 8 { A = 1, B = 2 } packet Bar { x: Foo }");
1291 test_pdl!(packet_decl_24bit_enum, "enum Foo : 24 { A = 1, B = 2 } packet Bar { x: Foo }");
1292 test_pdl!(packet_decl_64bit_enum, "enum Foo : 64 { A = 1, B = 2 } packet Bar { x: Foo }");
1293
1294 test_pdl!(
1295 packet_decl_mixed_scalars_enums,
1296 "
1297 enum Enum7 : 7 {
1298 A = 1,
1299 B = 2,
1300 }
1301
1302 enum Enum9 : 9 {
1303 A = 1,
1304 B = 2,
1305 }
1306
1307 packet Foo {
1308 x: Enum7,
1309 y: 5,
1310 z: Enum9,
1311 w: 3,
1312 }
1313 "
1314 );
1315
1316 test_pdl!(packet_decl_8bit_scalar_array, " packet Foo { x: 8[3] }");
1317 test_pdl!(packet_decl_24bit_scalar_array, "packet Foo { x: 24[5] }");
1318 test_pdl!(packet_decl_64bit_scalar_array, "packet Foo { x: 64[7] }");
1319
    // Fixed-count arrays of enum elements, for enum widths of 8, 24 and
    // 64 bits. The variant `FOO_BAR` is written in SCREAMING_SNAKE_CASE —
    // presumably to exercise conversion via `ToUpperCamelCase`; confirm.
    test_pdl!(
        packet_decl_8bit_enum_array,
        "enum Foo : 8 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[3] }"
    );
    test_pdl!(
        packet_decl_24bit_enum_array,
        "enum Foo : 24 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[5] }"
    );
    test_pdl!(
        packet_decl_64bit_enum_array,
        "enum Foo : 64 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[7] }"
    );
1332
    // 24-bit scalar array whose element count comes from a preceding 5-bit
    // `_count_(x)` field; a 3-bit scalar named `padding` (an ordinary field,
    // not the `_padding_` construct) fills out the first byte.
    test_pdl!(
        packet_decl_array_dynamic_count,
        "
        packet Foo {
            _count_(x): 5,
            padding: 3,
            x: 24[]
        }
        "
    );
1343
    // 24-bit scalar array whose size comes from a preceding 5-bit
    // `_size_(x)` field; a 3-bit scalar named `padding` (an ordinary field,
    // not the `_padding_` construct) fills out the first byte.
    test_pdl!(
        packet_decl_array_dynamic_size,
        "
        packet Foo {
            _size_(x): 5,
            padding: 3,
            x: 24[]
        }
        "
    );
1354
    // Array of struct `Foo`, sized by a 40-bit `_size_(x)` field. Each `Foo`
    // carries its own counted `16[]` array, so elements do not share a
    // fixed width.
    test_pdl!(
        packet_decl_array_unknown_element_width_dynamic_size,
        "
        struct Foo {
            _count_(a): 40,
            a: 16[],
        }

        packet Bar {
            _size_(x): 40,
            x: Foo[],
        }
        "
    );
1369
    // Array of struct `Foo`, counted by a 40-bit `_count_(x)` field. Each
    // `Foo` carries its own counted `16[]` array, so elements do not share
    // a fixed width.
    test_pdl!(
        packet_decl_array_unknown_element_width_dynamic_count,
        "
        struct Foo {
            _count_(a): 40,
            a: 16[],
        }

        packet Bar {
            _count_(x): 40,
            x: Foo[],
        }
        "
    );
1384
    // Array `a` of variable-width struct `Foo`, with no count or size field
    // of its own, immediately followed by a `_padding_ [128]` field.
    // NOTE(review): presumably the padding fixes the encoded extent of `a`
    // to 128 bytes — confirm against the PDL reference.
    test_pdl!(
        packet_decl_array_with_padding,
        "
        struct Foo {
            _count_(a): 40,
            a: 16[],
        }

        packet Bar {
            a: Foo[],
            _padding_ [128],
        }
        "
    );
1399
    // Array of struct `Foo` (whose only field is an unsized `8[]` array) with
    // a 5-bit `_elementsize_(x)` field; `x` has no `_count_` or `_size_`
    // field. A 3-bit scalar named `padding` fills out the first byte.
    test_pdl!(
        packet_decl_array_dynamic_element_size,
        "
        struct Foo {
            inner: 8[]
        }
        packet Bar {
            _elementsize_(x): 5,
            padding: 3,
            x: Foo[]
        }
        "
    );
1413
    // Array of struct `Foo` (an unsized `8[]` array inside) described by both
    // a 4-bit `_size_(x)` and a 4-bit `_elementsize_(x)` field.
    test_pdl!(
        packet_decl_array_dynamic_element_size_dynamic_size,
        "
        struct Foo {
            inner: 8[]
        }
        packet Bar {
            _size_(x): 4,
            _elementsize_(x): 4,
            x: Foo[]
        }
        "
    );
1427
    // Array of struct `Foo` (an unsized `8[]` array inside) described by both
    // a 4-bit `_count_(x)` and a 4-bit `_elementsize_(x)` field.
    test_pdl!(
        packet_decl_array_dynamic_element_size_dynamic_count,
        "
        struct Foo {
            inner: 8[]
        }
        packet Bar {
            _count_(x): 4,
            _elementsize_(x): 4,
            x: Foo[]
        }
        "
    );
1441
    // Fixed-count (4) array of struct `Foo` (an unsized `8[]` array inside)
    // combined with a 5-bit `_elementsize_(x)` field.
    test_pdl!(
        packet_decl_array_dynamic_element_size_static_count,
        "
        struct Foo {
            inner: 8[]
        }
        packet Bar {
            _elementsize_(x): 5,
            padding: 3,
            x: Foo[4]
        }
        "
    );
1455
    // Edge case of the previous test: the fixed element count is exactly 1.
    test_pdl!(
        packet_decl_array_dynamic_element_size_static_count_1,
        "
        struct Foo {
            inner: 8[]
        }
        packet Bar {
            _elementsize_(x): 5,
            padding: 3,
            x: Foo[1]
        }
        "
    );
1469
    // Packet whose only field is a 40-bit `_reserved_` field.
    test_pdl!(
        packet_decl_reserved_field,
        "
        packet Foo {
            _reserved_: 40,
        }
        "
    );
1478
    // Packet whose fields use two custom field types, one 24-bit and one
    // 32-bit.
    // NOTE(review): here the 24-bit type is labelled "exact" and the 32-bit
    // one "truncated", the reverse of `custom_field_declaration` (where the
    // 32-bit type is the exact-size one). The literal is left unchanged
    // because it is test input; confirm whether the labels are intentional.
    test_pdl!(
        packet_decl_custom_field,
        r#"
        custom_field Bar1 : 24 "exact"
        custom_field Bar2 : 32 "truncated"

        packet Foo {
            a: Bar1,
            b: Bar2,
        }
        "#
    );
1491
    // A 7-bit `_fixed_` scalar constant (value 7) followed by a 57-bit
    // scalar, 64 bits in total.
    test_pdl!(
        packet_decl_fixed_scalar_field,
        "
        packet Foo {
            _fixed_ = 7 : 7,
            b: 57,
        }
        "
    );
1501
    // A `_fixed_` enum constant (`Enum7::A`, 7 bits) followed by a 57-bit
    // scalar, 64 bits in total.
    test_pdl!(
        packet_decl_fixed_enum_field,
        "
        enum Enum7 : 7 {
            A = 1,
            B = 2,
        }

        packet Foo {
            _fixed_ = A : Enum7,
            b: 57,
        }
        "
    );
1516
    // Payload whose size comes from a preceding 8-bit `_size_(_payload_)`
    // field, followed by a trailing 16-bit scalar.
    test_pdl!(
        packet_decl_payload_field_variable_size,
        "
        packet Foo {
            a: 8,
            _size_(_payload_): 8,
            _payload_,
            b: 16,
        }
        "
    );
1528
    // Payload with no size field, declared as the last field after a 24-bit
    // scalar.
    test_pdl!(
        packet_decl_payload_field_unknown_size,
        "
        packet Foo {
            a: 24,
            _payload_,
        }
        "
    );
1538
    // Payload with no size field, declared first and followed by a 24-bit
    // scalar.
    test_pdl!(
        packet_decl_payload_field_unknown_size_terminal,
        "
        packet Foo {
            _payload_,
            a: 24,
        }
        "
    );
1548
    // Parent packet `Foo` with a sized payload and two children, each
    // constraining a different parent field: `Bar` a scalar (`a = 100`) and
    // `Baz` an enum (`b = B`).
    test_pdl!(
        packet_decl_child_packets,
        "
        enum Enum16 : 16 {
            A = 1,
            B = 2,
        }

        packet Foo {
            a: 8,
            b: Enum16,
            _size_(_payload_): 8,
            _payload_
        }

        packet Bar : Foo (a = 100) {
            x: 8,
        }

        packet Baz : Foo (b = B) {
            y: 16,
        }
        "
    );
1573
    // Four-level packet hierarchy Parent -> Child -> GrandChild ->
    // GrandGrandChild. Constraints are spread across levels; `GrandChild`
    // constrains both a field inherited from `Parent` (`bar`) and one from
    // `Child` (`quux`). The lower levels use `_body_` rather than `_payload_`.
    test_pdl!(
        packet_decl_grand_children,
        "
        enum Enum16 : 16 {
            A = 1,
            B = 2,
        }

        packet Parent {
            foo: Enum16,
            bar: Enum16,
            baz: Enum16,
            _size_(_payload_): 8,
            _payload_
        }

        packet Child : Parent (foo = A) {
            quux: Enum16,
            _payload_,
        }

        packet GrandChild : Child (bar = A, quux = A) {
            _body_,
        }

        packet GrandGrandChild : GrandChild (baz = A) {
            _body_,
        }
        "
    );
1604
    // Parent declares no `_payload_`, yet `Child` derives from it with the
    // constraint `v = A` and no fields of its own.
    test_pdl!(
        packet_decl_parent_with_no_payload,
        "
        enum Enum8 : 8 {
            A = 0,
        }

        packet Parent {
            v : Enum8,
        }

        packet Child : Parent (v = A) {
        }
        "
    );
1620
    // `AliasChild` derives from `Parent` with no constraint and re-declares
    // a payload. `NormalGrandChild1`/`NormalGrandChild2` constrain the
    // inherited `v` through it, while `NormalChild` constrains `v` directly
    // on `Parent`.
    test_pdl!(
        packet_decl_parent_with_alias_child,
        "
        enum Enum8 : 8 {
            A = 0,
            B = 1,
            C = 2,
        }

        packet Parent {
            v : Enum8,
            _payload_,
        }

        packet AliasChild : Parent {
            _payload_
        }

        packet NormalChild : Parent (v = A) {
        }

        packet NormalGrandChild1 : AliasChild (v = B) {
        }

        packet NormalGrandChild2 : AliasChild (v = C) {
            _payload_
        }
        "
    );
1650
    // Field named `type`, which is a Rust keyword. `ToIdent` in this file
    // emits such names as raw identifiers (`r#type`).
    test_pdl!(
        reserved_identifier,
        "
        packet Test {
            type: 8,
        }
        "
    );
1659
    // Payload sized by an 8-bit `_size_(_payload_)` field, with a `[+1]`
    // size modifier on the payload.
    // NOTE(review): presumably the encoded size value is offset by 1 from
    // the payload length — confirm against the PDL reference.
    test_pdl!(
        payload_with_size_modifier,
        "
        packet Test {
            _size_(_payload_): 8,
            _payload_ : [+1],
        }
        "
    );
1669
    // Same hierarchy as `packet_decl_child_packets`, declared with `struct`
    // instead of `packet`.
    test_pdl!(
        struct_decl_child_structs,
        "
        enum Enum16 : 16 {
            A = 1,
            B = 2,
        }

        struct Foo {
            a: 8,
            b: Enum16,
            _size_(_payload_): 8,
            _payload_
        }

        struct Bar : Foo (a = 100) {
            x: 8,
        }

        struct Baz : Foo (b = B) {
            y: 16,
        }
        "
    );
1694
    // Same four-level hierarchy as `packet_decl_grand_children`, declared
    // with `struct` instead of `packet`.
    test_pdl!(
        struct_decl_grand_children,
        "
        enum Enum16 : 16 {
            A = 1,
            B = 2,
        }

        struct Parent {
            foo: Enum16,
            bar: Enum16,
            baz: Enum16,
            _size_(_payload_): 8,
            _payload_
        }

        struct Child : Parent (foo = A) {
            quux: Enum16,
            _payload_,
        }

        struct GrandChild : Child (bar = A, quux = A) {
            _body_,
        }

        struct GrandGrandChild : GrandChild (baz = A) {
            _body_,
        }
        "
    );
1725}