1use crate::{analyzer, ast};
18use quote::{format_ident, quote};
19use std::collections::BTreeSet;
20use std::collections::HashMap;
21use std::path::Path;
22use syn::LitInt;
23
24mod parser;
25mod preamble;
26mod serializer;
27pub mod test;
28mod types;
29
30use parser::FieldParser;
31use serializer::FieldSerializer;
32
33pub use heck::ToUpperCamelCase;
34
35pub trait ToIdent {
36 fn to_ident(self) -> proc_macro2::Ident;
39}
40
41impl ToIdent for &'_ str {
42 fn to_ident(self) -> proc_macro2::Ident {
43 match self {
44 "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
45 | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod"
46 | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self" | "static" | "struct"
47 | "super" | "trait" | "true" | "type" | "unsafe" | "use" | "where" | "while"
48 | "async" | "await" | "dyn" | "abstract" | "become" | "box" | "do" | "final"
49 | "macro" | "override" | "priv" | "typeof" | "unsized" | "virtual" | "yield"
50 | "try" => format_ident!("r#{}", self),
51 _ => format_ident!("{}", self),
52 }
53 }
54}
55
56pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
63 let suffix = if n > 31 { format!("_{suffix}") } else { String::new() };
64 let hex_digits = format!("{:x}", (1u64 << n) - 1)
66 .as_bytes()
67 .rchunks(4)
68 .rev()
69 .map(|chunk| std::str::from_utf8(chunk).unwrap())
70 .collect::<Vec<&str>>()
71 .join("_");
72 syn::parse_str::<syn::LitInt>(&format!("0x{hex_digits}{suffix}")).unwrap()
73}
74
/// Generate the body of the `get_size` method for a packet or struct.
///
/// Returns `(constant_width, tokens)`:
/// - `constant_width` is the sum, in bits, of the statically known sizes of `fields`.
///   The padded size is used when the schema reports one; otherwise the static field
///   size is used.
/// - `tokens` is an expression computing the total size in bytes. It adds the constant
///   part (`constant_width / 8`) to one term per dynamically sized field. When there is
///   neither, it is `0`.
///
/// `is_packet` selects how a payload/body field is measured:
/// - for packets, through `self.child.get_total_size()`;
/// - for structs, through `self.payload.len()`.
///
/// # Panics
///
/// - Panics if a non-optional scalar field has no static size (the `assert!` on `cond`).
/// - Panics on field kinds not handled below (`Unsupported field type`).
fn generate_packet_size_getter<'a>(
    scope: &analyzer::Scope<'a>,
    schema: &analyzer::Schema,
    fields: impl Iterator<Item = &'a ast::Field>,
    is_packet: bool,
) -> (usize, proc_macro2::TokenStream) {
    // Bits contributed by fields whose size is known at generation time.
    let mut constant_width = 0;
    // Byte-size expressions for fields whose size is only known at runtime.
    let mut dynamic_widths = Vec::new();

    for field in fields {
        // The padded size, when the schema reports one, takes precedence over the
        // static field size.
        if let Some(width) =
            schema.padded_size(field.key).or(schema.field_size(field.key).static_())
        {
            constant_width += width;
            continue;
        }

        let decl = scope.get_type_declaration(field);
        dynamic_widths.push(match &field.desc {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body => {
                if is_packet {
                    quote! {
                        self.child.get_total_size()
                    }
                } else {
                    quote! {
                        self.payload.len()
                    }
                }
            }
            // A scalar only reaches this point when it is conditional: it then counts
            // `width / 8` bytes if present and 0 otherwise.
            ast::FieldDesc::Scalar { id, width } => {
                assert!(field.cond.is_some());
                let id = id.to_ident();
                let width = syn::Index::from(*width / 8);
                quote!(if self.#id.is_some() { #width } else { 0 })
            }
            // Conditional typedef fields are stored as `Option`s.
            ast::FieldDesc::Typedef { id, type_id, .. } if field.cond.is_some() => {
                let id = id.to_ident();
                match &scope.typedef[type_id].desc {
                    ast::DeclDesc::Enum { width, .. } => {
                        let width = syn::Index::from(*width / 8);
                        quote!(if self.#id.is_some() { #width } else { 0 })
                    }
                    _ => {
                        let type_id = type_id.to_ident();
                        quote! {
                            self.#id
                                .as_ref()
                                .map(#type_id::get_size)
                                .unwrap_or(0)
                        }
                    }
                }
            }
            ast::FieldDesc::Typedef { id, .. } => {
                let id = id.to_ident();
                quote!(self.#id.get_size())
            }
            ast::FieldDesc::Array { id, width, .. } => {
                let id = id.to_ident();
                match &decl {
                    // Element sizes are summed one by one.
                    // NOTE(review): this calls `get_size()` on custom field elements, but
                    // `generate_custom_field_decl` does not visibly define such a method —
                    // confirm it is provided elsewhere.
                    Some(ast::Decl {
                        desc: ast::DeclDesc::Struct { .. } | ast::DeclDesc::CustomField { .. },
                        ..
                    }) => {
                        quote! {
                            self.#id.iter().map(|elem| elem.get_size()).sum::<usize>()
                        }
                    }
                    // Fixed-width elements: element count times element byte width. The
                    // `* 1` multiplication is omitted for single-byte elements.
                    Some(ast::Decl { desc: ast::DeclDesc::Enum { width, .. }, .. }) => {
                        let width = syn::Index::from(width / 8);
                        let mul_width = (width.index > 1).then(|| quote!(* #width));
                        quote! {
                            self.#id.len() #mul_width
                        }
                    }
                    // Remaining arrays are expected to be scalar arrays carrying a `width`.
                    _ => {
                        let width = syn::Index::from(width.unwrap() / 8);
                        let mul_width = (width.index > 1).then(|| quote!(* #width));
                        quote! {
                            self.#id.len() #mul_width
                        }
                    }
                }
            }
            _ => panic!("Unsupported field type: {field:?}"),
        });
    }

    // The constant part is emitted first, as a byte count.
    // NOTE(review): `constant_width / 8` truncates; presumably the analyzer guarantees
    // byte alignment — confirm.
    if constant_width > 0 {
        let width = syn::Index::from(constant_width / 8);
        dynamic_widths.insert(0, quote!(#width));
    }
    // Guarantee a non-empty sum expression.
    if dynamic_widths.is_empty() {
        dynamic_widths.push(quote!(0))
    }

    (
        constant_width,
        quote! {
            #(#dynamic_widths)+*
        },
    )
}
179
180fn top_level_packet<'a>(scope: &analyzer::Scope<'a>, packet_name: &'a str) -> &'a ast::Decl {
181 let mut decl = scope.typedef[packet_name];
182 while let ast::DeclDesc::Packet { parent_id: Some(parent_id), .. }
183 | ast::DeclDesc::Struct { parent_id: Some(parent_id), .. } = &decl.desc
184 {
185 decl = scope.typedef[parent_id];
186 }
187 decl
188}
189
190fn find_constrained_parent_fields<'a>(
196 scope: &analyzer::Scope<'a>,
197 id: &str,
198) -> Vec<&'a ast::Field> {
199 let all_parent_fields: HashMap<String, &'a ast::Field> = HashMap::from_iter(
200 scope
201 .iter_parent_fields(scope.typedef[id])
202 .filter_map(|f| f.id().map(|id| (id.to_string(), f))),
203 );
204
205 let mut fields = Vec::new();
206 let mut field_names = BTreeSet::new();
207 let mut children = scope.iter_children(scope.typedef[id]).collect::<Vec<_>>();
208
209 while let Some(child) = children.pop() {
210 if let ast::DeclDesc::Packet { id, constraints, .. }
211 | ast::DeclDesc::Struct { id, constraints, .. } = &child.desc
212 {
213 for constraint in constraints {
214 if field_names.insert(&constraint.id)
215 && all_parent_fields.contains_key(&constraint.id)
216 {
217 fields.push(all_parent_fields[&constraint.id]);
218 }
219 }
220 children.extend(scope.iter_children(scope.typedef[id]).collect::<Vec<_>>());
221 }
222 }
223
224 fields
225}
226
/// Generate the data struct backing a packet or struct declaration, and its `impl` block.
///
/// Naming and visibility:
/// - For a packet `Foo`, the struct is named `FooData` and its fields and `parse` are private.
/// - For a struct `Foo`, it is named `Foo` and they are `pub`.
///
/// Every named, non-flag field becomes a struct field. When the declaration has children
/// or a payload, one extra field is added: `child: FooDataChild` for packets,
/// `payload: Vec<u8>` for structs.
///
/// The generated impl provides `conforms`, `parse`, `parse_inner`, `write_to`,
/// `get_total_size` and `get_size`. For packets, `parse`/`parse_inner` take one extra
/// argument per ancestor field returned by `find_constrained_parent_fields`.
///
/// Returns `(struct_declaration, struct_impl)`.
fn generate_data_struct(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    id: &str,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let decl = scope.typedef[id];
    let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });

    // Names of the input (`bytes`) and output (`buffer`) variables in the generated code.
    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_parser = FieldParser::new(scope, schema, endianness, id, &span);
    let mut field_serializer =
        FieldSerializer::new(scope, schema, endianness, id, &serializer_span);
    for field in decl.fields() {
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_parser.done();

    // Extra parse arguments: ancestor fields constrained by descendants (packets only).
    let (parse_arg_names, parse_arg_types) = if is_packet {
        let fields = find_constrained_parent_fields(scope, id);
        let names = fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
        let types = fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
        (names, types)
    } else {
        (Vec::new(), Vec::new())
    };

    // `conforms` only checks that the input covers the statically sized part.
    let (constant_width, packet_size) =
        generate_packet_size_getter(scope, schema, decl.fields(), is_packet);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
        let constant_width = syn::Index::from(constant_width / 8);
        quote! { #span.len() >= #constant_width }
    };

    let visibility = if is_packet { quote!() } else { quote!(pub) };
    let has_payload = decl.payload().is_some();
    let has_children = scope.iter_children(decl).next().is_some();

    let struct_name = if is_packet { format_ident!("{id}Data") } else { id.to_ident() };
    // Flags are not stored; presumably they are derived from other fields when serializing.
    let backed_fields = decl
        .fields()
        .filter(|f| f.id().is_some() && !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
        .collect::<Vec<_>>();

    let mut field_names =
        backed_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
    let mut field_types = backed_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();

    if has_children || has_payload {
        if is_packet {
            field_names.push(format_ident!("child"));
            let field_type = format_ident!("{id}DataChild");
            field_types.push(quote!(#field_type));
        } else {
            field_names.push(format_ident!("payload"));
            field_types.push(quote!(Vec<u8>));
        }
    }

    let data_struct_decl = quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #struct_name {
            #(#visibility #field_names: #field_types,)*
        }
    };

    // `parse_inner` relies on `field_parser` binding a local for every name in
    // `field_names`, which the struct literal then collects.
    let data_struct_impl = quote! {
        impl #struct_name {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            #visibility fn parse(
                #span: &[u8] #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self, DecodeError> {
                let mut cell = Cell::new(#span);
                let packet = Self::parse_inner(&mut cell #(, #parse_arg_names)*)?;
                Ok(packet)
            }

            fn parse_inner(
                mut #span: &mut Cell<&[u8]> #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self, DecodeError> {
                #field_parser
                Ok(Self {
                    #(#field_names,)*
                })
            }

            fn write_to<T: BufMut>(&self, buffer: &mut T) -> Result<(), EncodeError> {
                #field_serializer
                Ok(())
            }

            fn get_total_size(&self) -> usize {
                self.get_size()
            }

            fn get_size(&self) -> usize {
                #packet_size
            }
        }
    };

    (data_struct_decl, data_struct_impl)
}
343
344pub fn constraint_to_value(
347 all_fields: &HashMap<String, &'_ ast::Field>,
348 constraint: &ast::Constraint,
349) -> proc_macro2::TokenStream {
350 match constraint {
351 ast::Constraint { value: Some(value), .. } => {
352 let value = proc_macro2::Literal::usize_unsuffixed(*value);
353 quote!(#value)
354 }
355 ast::Constraint { tag_id: Some(tag_id), .. } => {
358 let type_id = match &all_fields[&constraint.id].desc {
359 ast::FieldDesc::Typedef { type_id, .. } => type_id.to_ident(),
360 _ => unreachable!("Invalid constraint: {constraint:?}"),
361 };
362 let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
363 quote!(#type_id::#tag_id)
364 }
365 _ => unreachable!("Invalid constraint: {constraint:?}"),
366 }
367}
368
/// Generate all the code for a packet declaration `Foo` (the packet named `id`).
///
/// This emits:
/// - the `FooDataChild` and `FooChild` enums, when `Foo` has children or a payload;
/// - the `FooData` struct and impl, from `generate_data_struct`;
/// - the `Foo` struct, which holds one `*Data` value per packet of the parent chain,
///   from the top-level packet down to `Foo`;
/// - `FooBuilder`, with one public field per unconstrained field plus an optional
///   `payload`, and its `build` method;
/// - `impl Packet for Foo`, and `TryFrom<Foo>` for `Bytes` and `Vec<u8>`;
/// - when `Foo` is not the top-level packet, `From<Foo>` for every ancestor packet
///   and `TryFrom<TopLevel>` for `Foo`;
/// - `From<FooBuilder>` for `Foo` and for every ancestor packet;
/// - one `get_<field>` getter per named, non-flag field, plus `specialize`/`get_payload`
///   where applicable.
fn generate_packet_decl(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let decl = scope.typedef[id];
    let top_level = top_level_packet(scope, id);
    let top_level_id = top_level.id().unwrap();
    let top_level_packet = top_level_id.to_ident();
    let top_level_data = format_ident!("{top_level_id}Data");
    let top_level_id_lower = top_level_id.to_lowercase().to_ident();

    let span = format_ident!("bytes");
    let id_lower = id.to_lowercase().to_ident();
    let id_packet = id.to_ident();
    let id_child = format_ident!("{id}Child");
    let id_data_child = format_ident!("{id}DataChild");
    let id_builder = format_ident!("{id}Builder");

    // After the reverse, the chain starts at the top-level packet and ends with `id`.
    // `new` below relies on this: it starts from the top-level data and skips the first
    // entry.
    let mut parents = scope.iter_parents_and_self(decl).collect::<Vec<_>>();
    parents.reverse();

    let parent_ids = parents.iter().map(|p| p.id().unwrap()).collect::<Vec<_>>();
    let parent_shifted_ids = parent_ids.iter().skip(1).map(|id| id.to_ident());
    let parent_lower_ids =
        parent_ids.iter().map(|id| id.to_lowercase().to_ident()).collect::<Vec<_>>();
    let parent_shifted_lower_ids = parent_lower_ids.iter().skip(1).collect::<Vec<_>>();
    let parent_packet = parent_ids.iter().map(|id| id.to_ident());
    let parent_data = parent_ids.iter().map(|id| format_ident!("{id}Data"));
    let parent_data_child = parent_ids.iter().map(|id| format_ident!("{id}DataChild"));

    // Named, non-flag fields returned by `iter_fields`, sorted by name. They presumably
    // include inherited fields: the getter lookup below searches the whole parent chain.
    let all_fields = {
        let mut fields = scope
            .iter_fields(decl)
            .filter(|f| f.id().is_some() && !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
            .collect::<Vec<_>>();
        fields.sort_by_key(|f| f.id());
        fields
    };
    let all_named_fields =
        HashMap::from_iter(all_fields.iter().map(|f| (f.id().unwrap().to_string(), *f)));

    let all_field_names = all_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
    let all_field_types = all_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    let all_field_borrows =
        all_fields.iter().map(|f| types::rust_borrow(f, scope)).collect::<Vec<_>>();
    let all_field_getter_names =
        all_fields.iter().map(|f| format_ident!("get_{}", f.id().unwrap()));
    // For each field, the `self.<packet>` data struct of the chain that declares it.
    let all_field_self_field = all_fields.iter().map(|f| {
        for (parent, parent_id) in parents.iter().zip(parent_lower_ids.iter()) {
            if parent.fields().any(|ff| ff.id() == f.id()) {
                return quote!(self.#parent_id);
            }
        }
        unreachable!("Could not find {f:?} in parent chain");
    });

    let all_constraints = HashMap::<String, _>::from_iter(
        scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
    );

    // Constrained fields get fixed values in `build`, so only the others are builder inputs.
    let unconstrained_fields = all_fields
        .iter()
        .filter(|f| !all_constraints.contains_key(f.id().unwrap()))
        .collect::<Vec<_>>();
    let unconstrained_field_names =
        unconstrained_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
    let unconstrained_field_types = unconstrained_fields.iter().map(|f| types::rust_type(f));

    // `build` constructs the data structs bottom-up, from `id` (idx 0) to the top-level
    // packet. Each level wraps the data struct built at the previous level as its child.
    let rev_parents = parents.iter().rev().collect::<Vec<_>>();
    let builder_assignments = rev_parents.iter().enumerate().map(|(idx, parent)| {
        let parent_id = parent.id().unwrap();
        let parent_id_lower = parent_id.to_lowercase().to_ident();
        let parent_data = format_ident!("{parent_id}Data");
        let parent_data_child = format_ident!("{parent_id}DataChild");

        let named_fields = {
            let mut names = parent
                .fields()
                .filter(|f| !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
                .filter_map(ast::Field::id)
                .collect::<Vec<_>>();
            names.sort_unstable();
            names
        };

        // Constrained fields take their constraint value; others come from the builder.
        let mut field = named_fields.iter().map(|id| id.to_ident()).collect::<Vec<_>>();
        let mut value = named_fields
            .iter()
            .map(|&id| match all_constraints.get(id) {
                Some(constraint) => constraint_to_value(&all_named_fields, constraint),
                None => {
                    let id = id.to_ident();
                    quote!(self.#id)
                }
            })
            .collect::<Vec<_>>();

        if parent.payload().is_some() {
            field.push(format_ident!("child"));
            if idx == 0 {
                // The packet itself: its child is the raw builder payload, if any.
                value.push(quote! {
                    match self.payload {
                        None => #parent_data_child::None,
                        Some(bytes) => #parent_data_child::Payload(bytes),
                    }
                });
            } else {
                // An ancestor: its child is the data struct bound at the previous level.
                let prev_parent_id = rev_parents[idx - 1].id().unwrap();
                let prev_parent_id_lower = prev_parent_id.to_lowercase().to_ident();
                let prev_parent_id = prev_parent_id.to_ident();
                value.push(quote! {
                    #parent_data_child::#prev_parent_id(#prev_parent_id_lower)
                });
            }
        } else if scope.iter_children(parent).next().is_some() {
            field.push(format_ident!("child"));
            value.push(quote! { #parent_data_child::None });
        }

        quote! {
            let #parent_id_lower = #parent_data {
                #(#field: #value,)*
            };
        }
    });

    let children = scope.iter_children(decl).collect::<Vec<_>>();
    let has_payload = decl.payload().is_some();
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child = children.iter().map(|child| child.id().unwrap().to_ident()).collect::<Vec<_>>();
    let child_data = child.iter().map(|child| format_ident!("{child}Data")).collect::<Vec<_>>();
    // `get_payload` is only generated for leaf packets that carry a payload.
    let get_payload = (children.is_empty() && has_payload).then(|| {
        quote! {
            pub fn get_payload(&self) -> &[u8] {
                match &self.#id_lower.child {
                    #id_data_child::Payload(bytes) => &bytes,
                    #id_data_child::None => &[],
                }
            }
        }
    });
    let child_declaration = has_children_or_payload.then(|| {
        quote! {
            #[derive(Debug, Clone, PartialEq, Eq)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #id_data_child {
                #(#child(#child_data),)*
                Payload(Bytes),
                None,
            }

            impl #id_data_child {
                fn get_total_size(&self) -> usize {
                    match self {
                        #(#id_data_child::#child(value) => value.get_total_size(),)*
                        #id_data_child::Payload(bytes) => bytes.len(),
                        #id_data_child::None => 0,
                    }
                }
            }

            #[derive(Debug, Clone, PartialEq, Eq)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #id_child {
                #(#child(#child),)*
                Payload(Bytes),
                None,
            }
        }
    });
    // `specialize` rebuilds the matching child packet from a clone of the top-level data.
    let specialize = has_children_or_payload.then(|| {
        quote! {
            pub fn specialize(&self) -> #id_child {
                match &self.#id_lower.child {
                    #(
                        #id_data_child::#child(_) =>
                        #id_child::#child(#child::new(self.#top_level_id_lower.clone()).unwrap()),
                    )*
                    #id_data_child::Payload(payload) => #id_child::Payload(payload.clone()),
                    #id_data_child::None => #id_child::None,
                }
            }
        }
    });

    let builder_payload_field = has_children_or_payload.then(|| {
        quote! {
            pub payload: Option<Bytes>
        }
    });

    // Every packet of the chain except `id` itself.
    let ancestor_packets = parent_ids[..parent_ids.len() - 1].iter().map(|id| id.to_ident());
    let impl_from_and_try_from = (top_level_id != id).then(|| {
        quote! {
            #(
                impl From<#id_packet> for #ancestor_packets {
                    fn from(packet: #id_packet) -> #ancestor_packets {
                        #ancestor_packets::new(packet.#top_level_id_lower).unwrap()
                    }
                }
            )*

            impl TryFrom<#top_level_packet> for #id_packet {
                type Error = DecodeError;
                fn try_from(packet: #top_level_packet) -> Result<#id_packet, Self::Error> {
                    #id_packet::new(packet.#top_level_id_lower)
                }
            }
        }
    });

    let (data_struct_decl, data_struct_impl) = generate_data_struct(scope, schema, endianness, id);

    quote! {
        #child_declaration

        #data_struct_decl

        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_packet {
            #(
                #[cfg_attr(feature = "serde", serde(flatten))]
                #parent_lower_ids: #parent_data,
            )*
        }

        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_builder {
            #(pub #unconstrained_field_names: #unconstrained_field_types,)*
            #builder_payload_field
        }

        #data_struct_impl

        impl Packet for #id_packet {
            fn encoded_len(&self) -> usize {
                self.get_size()
            }
            fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
                self.#top_level_id_lower.write_to(buf)
            }
            fn decode(_: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
                unimplemented!("Rust legacy does not implement full packet trait")
            }
        }

        impl TryFrom<#id_packet> for Bytes {
            type Error = EncodeError;
            fn try_from(packet: #id_packet) -> Result<Self, Self::Error> {
                packet.encode_to_bytes()
            }
        }

        impl TryFrom<#id_packet> for Vec<u8> {
            type Error = EncodeError;
            fn try_from(packet: #id_packet) -> Result<Self, Self::Error> {
                packet.encode_to_vec()
            }
        }

        #impl_from_and_try_from

        impl #id_packet {
            pub fn parse(#span: &[u8]) -> Result<Self, DecodeError> {
                let mut cell = Cell::new(#span);
                let packet = Self::parse_inner(&mut cell)?;
                Ok(packet)
            }

            fn parse_inner(mut bytes: &mut Cell<&[u8]>) -> Result<Self, DecodeError> {
                let data = #top_level_data::parse_inner(&mut bytes)?;
                Self::new(data)
            }

            #specialize

            // Walk down the chain from the top-level data. Each level's child must be the
            // next packet of the chain, otherwise `InvalidChildError` is returned.
            fn new(#top_level_id_lower: #top_level_data) -> Result<Self, DecodeError> {
                #(
                    let #parent_shifted_lower_ids = match &#parent_lower_ids.child {
                        #parent_data_child::#parent_shifted_ids(value) => value.clone(),
                        _ => return Err(DecodeError::InvalidChildError {
                            expected: stringify!(#parent_data_child::#parent_shifted_ids),
                            actual: format!("{:?}", &#parent_lower_ids.child),
                        }),
                    };
                )*
                Ok(Self { #(#parent_lower_ids),* })
            }

            #(pub fn #all_field_getter_names(&self) -> #all_field_borrows #all_field_types {
                #all_field_borrows #all_field_self_field.#all_field_names
            })*

            #get_payload

            fn write_to(&self, buffer: &mut impl BufMut) -> Result<(), EncodeError> {
                self.#id_lower.write_to(buffer)
            }

            pub fn get_size(&self) -> usize {
                self.#top_level_id_lower.get_size()
            }
        }

        impl #id_builder {
            pub fn build(self) -> #id_packet {
                #(#builder_assignments;)*
                #id_packet::new(#top_level_id_lower).unwrap()
            }
        }

        #(
            impl From<#id_builder> for #parent_packet {
                fn from(builder: #id_builder) -> #parent_packet {
                    builder.build().into()
                }
            }
        )*
    }
}
700
701fn generate_struct_decl(
703 scope: &analyzer::Scope<'_>,
704 schema: &analyzer::Schema,
705 endianness: ast::EndiannessValue,
706 id: &str,
707) -> proc_macro2::TokenStream {
708 let (struct_decl, struct_impl) = generate_data_struct(scope, schema, endianness, id);
709 quote! {
710 #struct_decl
711 #struct_impl
712 }
713}
714
715fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2::TokenStream {
727 fn enum_default_tag(tags: &[ast::Tag]) -> Option<ast::TagOther> {
729 tags.iter()
730 .filter_map(|tag| match tag {
731 ast::Tag::Other(tag) => Some(tag.clone()),
732 _ => None,
733 })
734 .next()
735 }
736
737 fn enum_is_complete(tags: &[ast::Tag], max: usize) -> bool {
740 let mut ranges = tags
741 .iter()
742 .filter_map(|tag| match tag {
743 ast::Tag::Value(tag) => Some((tag.value, tag.value)),
744 ast::Tag::Range(tag) => Some(tag.range.clone().into_inner()),
745 _ => None,
746 })
747 .collect::<Vec<_>>();
748 ranges.sort_unstable();
749 ranges.first().unwrap().0 == 0
750 && ranges.last().unwrap().1 == max
751 && ranges.windows(2).all(|window| {
752 if let [left, right] = window {
753 left.1 == right.0 - 1
754 } else {
755 false
756 }
757 })
758 }
759
760 fn enum_is_primitive(tags: &[ast::Tag]) -> bool {
762 tags.iter().all(|tag| matches!(tag, ast::Tag::Value(_)))
763 }
764
765 fn scalar_max(width: usize) -> usize {
767 if width >= usize::BITS as usize {
768 usize::MAX
769 } else {
770 (1 << width) - 1
771 }
772 }
773
774 fn format_tag_ident(id: &str) -> proc_macro2::TokenStream {
776 let id = format_ident!("{}", id.to_upper_camel_case());
777 quote! { #id }
778 }
779
780 fn format_value(value: usize) -> LitInt {
782 syn::parse_str::<syn::LitInt>(&format!("{:#x}", value)).unwrap()
783 }
784
785 let backing_type = types::Integer::new(width);
787 let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
788 let range_max = scalar_max(width);
789 let default_tag = enum_default_tag(tags);
790 let is_open = default_tag.is_some();
791 let is_complete = enum_is_complete(tags, scalar_max(width));
792 let is_primitive = enum_is_primitive(tags);
793 let name = id.to_ident();
794
795 let use_variant_values = is_primitive && (is_complete || !is_open);
798 let repr_u64 = use_variant_values.then(|| quote! { #[repr(u64)] });
799 let mut variants = vec![];
800 for tag in tags.iter() {
801 match tag {
802 ast::Tag::Value(tag) if use_variant_values => {
803 let id = format_tag_ident(&tag.id);
804 let value = format_value(tag.value);
805 variants.push(quote! { #id = #value })
806 }
807 ast::Tag::Value(tag) => variants.push(format_tag_ident(&tag.id)),
808 ast::Tag::Range(tag) => {
809 variants.extend(tag.tags.iter().map(|tag| format_tag_ident(&tag.id)));
810 let id = format_tag_ident(&tag.id);
811 variants.push(quote! { #id(Private<#backing_type>) })
812 }
813 ast::Tag::Other(_) => (),
814 }
815 }
816
817 let mut from_cases = vec![];
819 for tag in tags.iter() {
820 match tag {
821 ast::Tag::Value(tag) => {
822 let id = format_tag_ident(&tag.id);
823 let value = format_value(tag.value);
824 from_cases.push(quote! { #value => Ok(#name::#id) })
825 }
826 ast::Tag::Range(tag) => {
827 from_cases.extend(tag.tags.iter().map(|tag| {
828 let id = format_tag_ident(&tag.id);
829 let value = format_value(tag.value);
830 quote! { #value => Ok(#name::#id) }
831 }));
832 let id = format_tag_ident(&tag.id);
833 let start = format_value(*tag.range.start());
834 let end = format_value(*tag.range.end());
835 from_cases.push(quote! { #start ..= #end => Ok(#name::#id(Private(value))) })
836 }
837 ast::Tag::Other(_) => (),
838 }
839 }
840
841 let mut into_cases = vec![];
843 for tag in tags.iter() {
844 match tag {
845 ast::Tag::Value(tag) => {
846 let id = format_tag_ident(&tag.id);
847 let value = format_value(tag.value);
848 into_cases.push(quote! { #name::#id => #value })
849 }
850 ast::Tag::Range(tag) => {
851 into_cases.extend(tag.tags.iter().map(|tag| {
852 let id = format_tag_ident(&tag.id);
853 let value = format_value(tag.value);
854 quote! { #name::#id => #value }
855 }));
856 let id = format_tag_ident(&tag.id);
857 into_cases.push(quote! { #name::#id(Private(value)) => *value })
858 }
859 ast::Tag::Other(_) => (),
860 }
861 }
862
863 if !is_complete && is_open {
865 let unknown_id = format_tag_ident(&default_tag.unwrap().id);
866 let range_max = format_value(range_max);
867 variants.push(quote! { #unknown_id(Private<#backing_type>) });
868 from_cases.push(quote! { 0..=#range_max => Ok(#name::#unknown_id(Private(value))) });
869 into_cases.push(quote! { #name::#unknown_id(Private(value)) => *value });
870 }
871
872 if backing_type.width != width || (!is_complete && !is_open) {
875 from_cases.push(quote! { _ => Err(value) });
876 }
877
878 let derived_signed_into_types = [8, 16, 32, 64]
881 .into_iter()
882 .filter(|w| *w > width)
883 .map(|w| syn::parse_str::<syn::Type>(&format!("i{}", w)).unwrap());
884 let derived_unsigned_into_types = [8, 16, 32, 64]
885 .into_iter()
886 .filter(|w| *w >= width && *w != backing_type.width)
887 .map(|w| syn::parse_str::<syn::Type>(&format!("u{}", w)).unwrap());
888 let derived_into_types = derived_signed_into_types.chain(derived_unsigned_into_types);
889
890 quote! {
891 #repr_u64
892 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
893 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
894 #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
895 pub enum #name {
896 #(#variants,)*
897 }
898
899 impl TryFrom<#backing_type> for #name {
900 type Error = #backing_type;
901 fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
902 match value {
903 #(#from_cases,)*
904 }
905 }
906 }
907
908 impl From<&#name> for #backing_type {
909 fn from(value: &#name) -> Self {
910 match value {
911 #(#into_cases,)*
912 }
913 }
914 }
915
916 impl From<#name> for #backing_type {
917 fn from(value: #name) -> Self {
918 (&value).into()
919 }
920 }
921
922 #(impl From<#name> for #derived_into_types {
923 fn from(value: #name) -> Self {
924 #backing_type::from(value) as Self
925 }
926 })*
927 }
928}
929
930fn generate_custom_field_decl(id: &str, width: usize) -> proc_macro2::TokenStream {
935 let id = id.to_ident();
936 let backing_type = types::Integer::new(width);
937 let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
938 let max_value = mask_bits(width, &format!("u{}", backing_type.width));
939 let common = quote! {
940 impl From<&#id> for #backing_type {
941 fn from(value: &#id) -> #backing_type {
942 value.0
943 }
944 }
945
946 impl From<#id> for #backing_type {
947 fn from(value: #id) -> #backing_type {
948 value.0
949 }
950 }
951 };
952
953 if backing_type.width == width {
954 quote! {
955 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
956 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
957 #[cfg_attr(feature = "serde", serde(from = #backing_type_str, into = #backing_type_str))]
958 pub struct #id(#backing_type);
959
960 #common
961
962 impl From<#backing_type> for #id {
963 fn from(value: #backing_type) -> Self {
964 #id(value)
965 }
966 }
967 }
968 } else {
969 quote! {
970 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
971 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
972 #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
973 pub struct #id(#backing_type);
974
975 #common
976
977 impl TryFrom<#backing_type> for #id {
978 type Error = #backing_type;
979 fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
980 if value > #max_value {
981 Err(value)
982 } else {
983 Ok(#id(value))
984 }
985 }
986 }
987 }
988 }
989}
990
991fn generate_decl(
992 scope: &analyzer::Scope<'_>,
993 schema: &analyzer::Schema,
994 file: &ast::File,
995 decl: &ast::Decl,
996) -> proc_macro2::TokenStream {
997 match &decl.desc {
998 ast::DeclDesc::Packet { id, .. } => {
999 generate_packet_decl(scope, schema, file.endianness.value, id)
1000 }
1001 ast::DeclDesc::Struct { id, parent_id: None, .. } => {
1002 generate_struct_decl(scope, schema, file.endianness.value, id)
1008 }
1009 ast::DeclDesc::Enum { id, tags, width } => generate_enum_decl(id, tags, *width),
1010 ast::DeclDesc::CustomField { id, width: Some(width), .. } => {
1011 generate_custom_field_decl(id, *width)
1012 }
1013 _ => todo!("unsupported Decl::{:?}", decl),
1014 }
1015}
1016
1017pub fn generate_tokens(
1022 sources: &ast::SourceDatabase,
1023 file: &ast::File,
1024) -> proc_macro2::TokenStream {
1025 let source = sources.get(file.file).expect("could not read source");
1026 let preamble = preamble::generate(Path::new(source.name()));
1027
1028 let scope = analyzer::Scope::new(file).expect("could not create scope");
1029 let schema = analyzer::Schema::new(file);
1030 let decls = file.declarations.iter().map(|decl| generate_decl(&scope, &schema, file, decl));
1031 quote! {
1032 #preamble
1033
1034 #(#decls)*
1035 }
1036}
1037
1038pub fn generate(sources: &ast::SourceDatabase, file: &ast::File) -> String {
1043 let syntax_tree = syn::parse2(generate_tokens(sources, file)).expect("Could not parse code");
1044 prettyplease::unparse(&syntax_tree)
1045}
1046
1047#[cfg(test)]
1048mod tests {
1049 use super::*;
1050 use crate::analyzer;
1051 use crate::ast;
1052 use crate::parser::parse_inline;
1053 use crate::test_utils::{assert_snapshot_eq, format_rust};
1054 use googletest::prelude::{elements_are, eq, expect_that};
1055 use paste::paste;
1056
1057 pub fn parse_str(text: &str) -> ast::File {
1063 let mut db = ast::SourceDatabase::new();
1064 let file = parse_inline(&mut db, "stdin", String::from(text)).expect("parse error");
1065 analyzer::analyze(&file).expect("analyzer error")
1066 }
1067
1068 #[googletest::test]
1069 fn test_find_constrained_parent_fields() -> googletest::Result<()> {
1070 let code = "
1071 little_endian_packets
1072 packet Parent {
1073 a: 8,
1074 b: 8,
1075 c: 8,
1076 _payload_,
1077 }
1078 packet Child: Parent(a = 10) {
1079 x: 8,
1080 _payload_,
1081 }
1082 packet GrandChild: Child(b = 20) {
1083 y: 8,
1084 _payload_,
1085 }
1086 packet GrandGrandChild: GrandChild(c = 30) {
1087 z: 8,
1088 }
1089 ";
1090 let file = parse_str(code);
1091 let scope = analyzer::Scope::new(&file).unwrap();
1092 let find_fields = |id| {
1093 find_constrained_parent_fields(&scope, id)
1094 .iter()
1095 .map(|field| field.id().unwrap())
1096 .collect::<Vec<_>>()
1097 };
1098
1099 expect_that!(find_fields("Parent"), elements_are![]);
1100 expect_that!(find_fields("Child"), elements_are![eq("b"), eq("c")]);
1101 expect_that!(find_fields("GrandChild"), elements_are![eq("c")]);
1102 expect_that!(find_fields("GrandGrandChild"), elements_are![]);
1103 Ok(())
1104 }
1105
1106 macro_rules! make_pdl_test {
1121 ($name:ident, $code:expr, $endianness:ident) => {
1122 paste! {
1123 #[test]
1124 fn [< test_ $name _ $endianness >]() {
1125 let name = stringify!($name);
1126 let endianness = stringify!($endianness);
1127 let code = format!("{endianness}_packets\n{}", $code);
1128 let mut db = ast::SourceDatabase::new();
1129 let file = parse_inline(&mut db, "test", code).unwrap();
1130 let file = analyzer::analyze(&file).unwrap();
1131 let actual_code = generate(&db, &file);
1132 assert_snapshot_eq(
1133 &format!("tests/generated/rust_legacy/{name}_{endianness}.rs"),
1134 &format_rust(&actual_code),
1135 );
1136 }
1137 }
1138 };
1139 }
1140
1141 macro_rules! test_pdl {
1147 ($name:ident, $code:expr $(,)?) => {
1148 make_pdl_test!($name, $code, little_endian);
1149 make_pdl_test!($name, $code, big_endian);
1150 };
1151 }
1152
// Each `test_pdl!` invocation expands to two snapshot tests, one for
// `little_endian` and one for `big_endian`. Each is compared with
// `tests/generated/rust_legacy/<name>_<endianness>.rs`.

// A packet with no fields.
test_pdl!(packet_decl_empty, "packet Foo {}");

// A single scalar field of 8, 24 and 64 bits respectively.
test_pdl!(packet_decl_8bit_scalar, " packet Foo { x: 8 }");
test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");
1158
// Enum declarations, each named after the case it exercises:
// - values that do or do not cover the whole value space (incomplete/complete);
// - open (`UNKNOWN = ..`) or closed;
// - with or without tag ranges (`B = 1..6 { ... }`).
test_pdl!(
enum_declaration,
r#"
enum IncompleteTruncatedClosed : 3 {
A = 0,
B = 1,
}

enum IncompleteTruncatedOpen : 3 {
A = 0,
B = 1,
UNKNOWN = ..
}

enum IncompleteTruncatedClosedWithRange : 3 {
A = 0,
B = 1..6 {
X = 1,
Y = 2,
}
}

enum IncompleteTruncatedOpenWithRange : 3 {
A = 0,
B = 1..6 {
X = 1,
Y = 2,
},
UNKNOWN = ..
}

enum CompleteTruncated : 3 {
A = 0,
B = 1,
C = 2,
D = 3,
E = 4,
F = 5,
G = 6,
H = 7,
}

enum CompleteTruncatedWithRange : 3 {
A = 0,
B = 1..7 {
X = 1,
Y = 2,
}
}

enum CompleteWithRange : 8 {
A = 0,
B = 1,
C = 2..255,
}
"#
);
1216
// Custom field declarations of 32 and 24 bits. The PDL comments inside the
// source state which conversion trait each generated type is expected to get.
test_pdl!(
custom_field_declaration,
r#"
// Still unsupported.
// custom_field Dynamic "dynamic"

// Should generate a type with From<u32> implementation.
custom_field ExactSize : 32 "exact_size"

// Should generate a type with TryFrom<u32> implementation.
custom_field TruncatedSize : 24 "truncated_size"
"#
);
1230
// Byte-aligned scalars of 8, 16 and 24 bits.
test_pdl!(
packet_decl_simple_scalars,
r#"
packet Foo {
x: 8,
y: 16,
z: 24,
}
"#
);

// Scalars of 3, 8, 5, 24, 12 and 4 bits. Several of them are not
// byte-aligned on their own and share bytes with their neighbours.
test_pdl!(
packet_decl_complex_scalars,
r#"
packet Foo {
a: 3,
b: 8,
c: 5,
d: 24,
e: 12,
f: 4,
}
"#,
);

// A 24-bit scalar placed between a 2-bit and a 6-bit scalar (32 bits total).
test_pdl!(
packet_decl_mask_scalar_value,
r#"
packet Foo {
a: 2,
b: 24,
c: 6,
}
"#,
);

// Same field layout as `packet_decl_complex_scalars`, declared as a struct.
test_pdl!(
struct_decl_complex_scalars,
r#"
struct Foo {
a: 3,
b: 8,
c: 5,
d: 24,
e: 12,
f: 4,
}
"#,
);
1282
// A single enum-typed field whose enum is 8, 24 and 64 bits wide.
test_pdl!(packet_decl_8bit_enum, " enum Foo : 8 { A = 1, B = 2 } packet Bar { x: Foo }");
test_pdl!(packet_decl_24bit_enum, "enum Foo : 24 { A = 1, B = 2 } packet Bar { x: Foo }");
test_pdl!(packet_decl_64bit_enum, "enum Foo : 64 { A = 1, B = 2 } packet Bar { x: Foo }");

// Enums of 7 and 9 bits interleaved with 5- and 3-bit scalars
// (7 + 5 + 9 + 3 = 24 bits).
test_pdl!(
packet_decl_mixed_scalars_enums,
"
enum Enum7 : 7 {
A = 1,
B = 2,
}

enum Enum9 : 9 {
A = 1,
B = 2,
}

packet Foo {
x: Enum7,
y: 5,
z: Enum9,
w: 3,
}
"
);
1308
// Fixed-size scalar arrays with 8-, 24- and 64-bit elements.
test_pdl!(packet_decl_8bit_scalar_array, " packet Foo { x: 8[3] }");
test_pdl!(packet_decl_24bit_scalar_array, "packet Foo { x: 24[5] }");
test_pdl!(packet_decl_64bit_scalar_array, "packet Foo { x: 64[7] }");

// Fixed-size enum arrays with 8-, 24- and 64-bit elements.
test_pdl!(
packet_decl_8bit_enum_array,
"enum Foo : 8 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[3] }"
);
test_pdl!(
packet_decl_24bit_enum_array,
"enum Foo : 24 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[5] }"
);
test_pdl!(
packet_decl_64bit_enum_array,
"enum Foo : 64 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[7] }"
);
1325
// Dynamic array of 24-bit elements counted by a 5-bit `_count_` field.
// A 3-bit scalar named `padding` sits between them.
test_pdl!(
packet_decl_array_dynamic_count,
"
packet Foo {
_count_(x): 5,
padding: 3,
x: 24[]
}
"
);

// Same layout, but the array is described by a 5-bit `_size_` field.
test_pdl!(
packet_decl_array_dynamic_size,
"
packet Foo {
_size_(x): 5,
padding: 3,
x: 24[]
}
"
);

// Array of `Foo` structs described by a `_size_` field. Each `Foo` contains
// its own counted array, so the element width is not fixed.
test_pdl!(
packet_decl_array_unknown_element_width_dynamic_size,
"
struct Foo {
_count_(a): 40,
a: 16[],
}

packet Bar {
_size_(x): 40,
x: Foo[],
}
"
);

// Same element type, but the array is described by a `_count_` field.
test_pdl!(
packet_decl_array_unknown_element_width_dynamic_count,
"
struct Foo {
_count_(a): 40,
a: 16[],
}

packet Bar {
_count_(x): 40,
x: Foo[],
}
"
);

// Array of variable-width structs with no size or count field, followed by
// a `_padding_ [128]` field.
test_pdl!(
packet_decl_array_with_padding,
"
struct Foo {
_count_(a): 40,
a: 16[],
}

packet Bar {
a: Foo[],
_padding_ [128],
}
"
);
1392
// A 40-bit `_reserved_` field as the only field.
test_pdl!(
packet_decl_reserved_field,
"
packet Foo {
_reserved_: 40,
}
"
);
1401
// Packet fields typed with custom fields of 24 and 32 bits.
test_pdl!(
packet_decl_custom_field,
r#"
custom_field Bar1 : 24 "exact"
custom_field Bar2 : 32 "truncated"

packet Foo {
a: Bar1,
b: Bar2,
}
"#
);
1414
// A 7-bit `_fixed_` scalar value (7) followed by a 57-bit scalar.
test_pdl!(
packet_decl_fixed_scalar_field,
"
packet Foo {
_fixed_ = 7 : 7,
b: 57,
}
"
);

// A `_fixed_` value taken from a 7-bit enum, followed by a 57-bit scalar.
test_pdl!(
packet_decl_fixed_enum_field,
"
enum Enum7 : 7 {
A = 1,
B = 2,
}

packet Foo {
_fixed_ = A : Enum7,
b: 57,
}
"
);
1439
// Payload described by a `_size_(_payload_)` field and followed by a
// 16-bit scalar.
test_pdl!(
packet_decl_payload_field_variable_size,
"
packet Foo {
a: 8,
_size_(_payload_): 8,
_payload_,
b: 16,
}
"
);

// Payload with no size field, placed last in the packet.
test_pdl!(
packet_decl_payload_field_unknown_size,
"
packet Foo {
a: 24,
_payload_,
}
"
);

// Payload with no size field, followed by a 24-bit scalar.
test_pdl!(
packet_decl_payload_field_unknown_size_terminal,
"
packet Foo {
_payload_,
a: 24,
}
"
);
1471
// A parent with a sized payload and two children. One child constrains the
// scalar `a`; the other constrains the enum `b`.
test_pdl!(
packet_decl_child_packets,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}

packet Foo {
a: 8,
b: Enum16,
_size_(_payload_): 8,
_payload_
}

packet Bar : Foo (a = 100) {
x: 8,
}

packet Baz : Foo (b = B) {
y: 16,
}
"
);

// Four-level hierarchy. Descendants constrain fields declared on `Parent`
// (`foo`, `bar`, `baz`) and on `Child` (`quux`). The lower levels use
// `_body_`.
test_pdl!(
packet_decl_grand_children,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}

packet Parent {
foo: Enum16,
bar: Enum16,
baz: Enum16,
_size_(_payload_): 8,
_payload_
}

packet Child : Parent (foo = A) {
quux: Enum16,
_payload_,
}

packet GrandChild : Child (bar = A, quux = A) {
_body_,
}

packet GrandGrandChild : GrandChild (baz = A) {
_body_,
}
"
);

// A child whose parent declares no payload field.
test_pdl!(
packet_decl_parent_with_no_payload,
"
enum Enum8 : 8 {
A = 0,
}

packet Parent {
v : Enum8,
}

packet Child : Parent (v = A) {
}
"
);

// `AliasChild` derives from `Parent` without constraints. Its own children
// then constrain `v`, alongside a direct constrained child `NormalChild`.
test_pdl!(
packet_decl_parent_with_alias_child,
"
enum Enum8 : 8 {
A = 0,
B = 1,
C = 2,
}

packet Parent {
v : Enum8,
_payload_,
}

packet AliasChild : Parent {
_payload_
}

packet NormalChild : Parent (v = A) {
}

packet NormalGrandChild1 : AliasChild (v = B) {
}

packet NormalGrandChild2 : AliasChild (v = C) {
_payload_
}
"
);
1573
// A field named `type`, which is a Rust keyword. `ToIdent::to_ident` maps
// such names to raw identifiers (`r#type`).
test_pdl!(
reserved_identifier,
"
packet Test {
type: 8,
}
"
);

// A `_size_(_payload_)` field combined with a `[+1]` size modifier on the
// payload.
test_pdl!(
payload_with_size_modifier,
"
packet Test {
_size_(_payload_): 8,
_payload_ : [+1],
}
"
);
1592
1593 }