//! Rust code generation backend: emits parser and serializer implementations
//! for the declarations of an analyzed PDL file.
use crate::{analyzer, ast};
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::path::Path;
use syn::LitInt;
mod parser;
mod preamble;
mod serializer;
pub mod test;
mod types;
pub use heck::ToUpperCamelCase;
use parser::FieldParser;
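/// Converts a string into a Rust identifier, escaping names that collide
/// with Rust keywords using raw identifier syntax (e.g. `type` becomes
/// `r#type`).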
pub trait ToIdent {
fn to_ident(self) -> proc_macro2::Ident;
}
impl ToIdent for &'_ str {
fn to_ident(self) -> proc_macro2::Ident {
match self {
"as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
| "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match" | "mod"
| "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self" | "static" | "struct"
| "super" | "trait" | "true" | "type" | "unsafe" | "use" | "where" | "while"
| "async" | "await" | "dyn" | "abstract" | "become" | "box" | "do" | "final"
| "macro" | "override" | "priv" | "typeof" | "unsized" | "virtual" | "yield"
| "try" => format_ident!("r#{}", self),
_ => format_ident!("{}", self),
}
}
}
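/// Builds a hexadecimal bit-mask literal with the `n` low bits set, grouping
/// hex digits by four. The type `suffix` is appended only for masks wider
/// than 31 bits, so narrower masks keep their inferred type. For example,
/// `mask_bits(14, "u16")` yields `0x3fff` and `mask_bits(48, "u64")` yields
/// `0xffff_ffff_ffff_u64`.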
pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
let suffix = if n > 31 { format!("_{suffix}") } else { String::new() };
let hex_digits = format!("{:x}", (1u64 << n) - 1)
.as_bytes()
.rchunks(4)
.rev()
.map(|chunk| std::str::from_utf8(chunk).unwrap())
.collect::<Vec<&str>>()
.join("_");
syn::parse_str::<syn::LitInt>(&format!("0x{hex_digits}{suffix}")).unwrap()
}
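/// Returns the fields of `decl` (as enumerated by `Scope::iter_fields`) that
/// become struct members: named fields, excluding flags and fields whose
/// value is fixed by a constraint.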
fn packet_data_fields<'a>(
scope: &'a analyzer::Scope<'a>,
decl: &'a ast::Decl,
) -> Vec<&'a ast::Field> {
let all_constraints = HashMap::<String, _>::from_iter(
scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
);
scope
.iter_fields(decl)
.filter(|f| f.id().is_some())
.filter(|f| !matches!(&f.desc, ast::FieldDesc::Flag { .. }))
.filter(|f| !all_constraints.contains_key(f.id().unwrap()))
.collect::<Vec<_>>()
}
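/// Returns the named fields of `decl` whose value is fixed by a constraint;
/// these are emitted as constant accessor methods instead of struct members.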
fn packet_constant_fields<'a>(
scope: &'a analyzer::Scope<'a>,
decl: &'a ast::Decl,
) -> Vec<&'a ast::Field> {
let all_constraints = HashMap::<String, _>::from_iter(
scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
);
scope
.iter_fields(decl)
.filter(|f| f.id().is_some())
.filter(|f| all_constraints.contains_key(f.id().unwrap()))
.collect::<Vec<_>>()
}
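/// Renders a constraint value as a token stream: a bare integer literal for
/// scalar constraints, or an enum tag path (`Type::Tag`) whose type is looked
/// up from the constrained field in `fields`.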
fn constraint_value(
fields: &[&'_ ast::Field],
constraint: &ast::Constraint,
) -> proc_macro2::TokenStream {
match constraint {
ast::Constraint { value: Some(value), .. } => {
let value = proc_macro2::Literal::usize_unsuffixed(*value);
quote!(#value)
}
ast::Constraint { tag_id: Some(tag_id), .. } => {
let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
let type_id = fields
.iter()
.filter_map(|f| match &f.desc {
ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
Some(type_id.to_ident())
}
_ => None,
})
.next()
.unwrap();
quote!(#type_id::#tag_id)
}
_ => unreachable!("Invalid constraint: {constraint:?}"),
}
}
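/// Same as `constraint_value`, but rendered as a plain string for use in
/// decode error messages.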
fn constraint_value_str(fields: &[&'_ ast::Field], constraint: &ast::Constraint) -> String {
match constraint {
ast::Constraint { value: Some(value), .. } => {
format!("{}", value)
}
ast::Constraint { tag_id: Some(tag_id), .. } => {
let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
let type_id = fields
.iter()
.filter_map(|f| match &f.desc {
ast::FieldDesc::Typedef { id, type_id } if id == &constraint.id => {
Some(type_id.to_ident())
}
_ => None,
})
.next()
.unwrap();
format!("{}::{}", type_id, tag_id)
}
_ => unreachable!("Invalid constraint: {constraint:?}"),
}
}
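/// Whether the Rust type generated for the field implements `Copy`:
/// scalars, enums, and custom fields do; structs and arrays do not.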
fn implements_copy(scope: &analyzer::Scope<'_>, field: &ast::Field) -> bool {
match &field.desc {
ast::FieldDesc::Scalar { .. } => true,
ast::FieldDesc::Typedef { type_id, .. } => match &scope.typedef[type_id].desc {
ast::DeclDesc::Enum { .. } | ast::DeclDesc::CustomField { .. } => true,
ast::DeclDesc::Struct { .. } => false,
desc => unreachable!("unexpected declaration: {desc:?}"),
},
ast::FieldDesc::Array { .. } => false,
_ => todo!(),
}
}
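/// Generates the declaration for a packet or struct with no parent: the
/// data struct, field accessors, the `Packet` impl, and, when children
/// exist, the child enum and a `specialize` method.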
fn generate_root_packet_decl(
scope: &analyzer::Scope<'_>,
schema: &analyzer::Schema,
endianness: ast::EndiannessValue,
id: &str,
) -> proc_macro2::TokenStream {
let decl = scope.typedef[id];
let name = id.to_ident();
let child_name = format_ident!("{id}Child");
let data_fields = packet_data_fields(scope, decl);
let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
let data_field_borrows = data_fields
.iter()
.map(|f| {
if implements_copy(scope, f) {
quote! {}
} else {
quote! { & }
}
})
.collect::<Vec<_>>();
let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
let payload_accessor =
decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });
let parser_span = format_ident!("buf");
let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
for field in decl.fields() {
field_parser.add(field);
}
let mut parsed_field_ids = vec![];
if decl.payload().is_some() {
parsed_field_ids.push(format_ident!("payload"));
}
for f in &data_fields {
let id = f.id().unwrap().to_ident();
parsed_field_ids.push(id);
}
let (encode_fields, encoded_len) =
serializer::encode(scope, schema, endianness, "buf".to_ident(), decl);
let encode = quote! {
fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
#encode_fields
Ok(())
}
};
let encoded_len = quote! {
fn encoded_len(&self) -> usize {
#encoded_len
}
};
let decode = quote! {
fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
#field_parser
Ok((Self { #( #parsed_field_ids, )* }, buf))
}
};
let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
let child_struct = (!children_decl.is_empty()).then(|| {
let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
quote! {
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum #child_name {
#( #children_ids(#children_ids), )*
None,
}
}
});
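// specialize() narrows a decoded packet to one of its children by matching
// the constrained field values declared by each child.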
let specialize = (!children_decl.is_empty()).then(|| {
let constraint_fields = children_decl
.iter()
.flat_map(|decl| decl.constraints().map(|c| c.id.to_owned()))
.collect::<BTreeSet<_>>();
let constraint_ids = constraint_fields.iter().map(|id| id.to_ident());
let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
let case_values = children_decl.iter().map(|child_decl| {
let constraint_values = constraint_fields.iter().map(|id| {
let constraint = child_decl.constraints().find(|c| &c.id == id);
match constraint {
Some(constraint) => constraint_value(&data_fields, constraint),
None => quote! { _ },
}
});
quote! { (#( #constraint_values, )*) }
});
let default_case = quote! { _ => #child_name::None, };
quote! {
pub fn specialize(&self) -> Result<#child_name, DecodeError> {
Ok(
match (#( self.#constraint_ids, )*) {
#( #case_values =>
#child_name::#children_ids(self.try_into()?), )*
#default_case
}
)
}
}
});
quote! {
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct #name {
#( pub #data_field_ids: #data_field_types, )*
#payload_field
}
#child_struct
impl #name {
#specialize
#payload_accessor
#(
pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
#data_field_borrows self.#data_field_ids
}
)*
}
impl Packet for #name {
#encoded_len
#encode
#decode
}
}
}
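/// Generates the declaration for a packet or struct that derives from a
/// parent: in addition to the struct and `Packet` impl, this emits
/// `decode_partial`/`encode_partial`, constant accessors for constrained
/// fields, and conversions to and from the parent and ancestor types.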
fn generate_derived_packet_decl(
scope: &analyzer::Scope<'_>,
schema: &analyzer::Schema,
endianness: ast::EndiannessValue,
id: &str,
) -> proc_macro2::TokenStream {
let decl = scope.typedef[id];
let name = id.to_ident();
let parent_decl = scope.get_parent(decl).unwrap();
let parent_name = parent_decl.id().unwrap().to_ident();
let child_name = format_ident!("{id}Child");
let all_constraints = HashMap::<String, _>::from_iter(
scope.iter_constraints(decl).map(|c| (c.id.to_string(), c)),
);
let all_fields = scope.iter_fields(decl).collect::<Vec<_>>();
let data_fields = packet_data_fields(scope, decl);
let data_field_ids = data_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
let data_field_types = data_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
let data_field_borrows = data_fields
.iter()
.map(|f| {
if implements_copy(scope, f) {
quote! {}
} else {
quote! { & }
}
})
.collect::<Vec<_>>();
let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
let payload_accessor =
decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });
let parent_data_fields = packet_data_fields(scope, parent_decl);
let constant_fields = packet_constant_fields(scope, decl);
let constant_field_ids =
constant_fields.iter().map(|f| f.id().unwrap().to_ident()).collect::<Vec<_>>();
let constant_field_types =
constant_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
let constant_field_values = constant_fields.iter().map(|f| {
let c = all_constraints.get(f.id().unwrap()).unwrap();
constraint_value(&all_fields, c)
});
let parser_span = format_ident!("buf");
let mut field_parser = FieldParser::new(scope, schema, endianness, id, &parser_span);
for field in decl.fields() {
field_parser.add(field);
}
let mut parsed_field_ids = vec![];
let mut copied_field_ids = vec![];
let mut cloned_field_ids = vec![];
if decl.payload().is_some() {
parsed_field_ids.push(format_ident!("payload"));
}
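// Fields declared directly on this packet are produced by the field parser;
// fields inherited from the parent are copied (for Copy types) or cloned
// from the parent value in decode_partial.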
for f in &data_fields {
let id = f.id().unwrap().to_ident();
if decl.fields().any(|ff| f.id() == ff.id()) {
parsed_field_ids.push(id);
} else if implements_copy(scope, f) {
copied_field_ids.push(id);
} else {
cloned_field_ids.push(id);
}
}
let (partial_field_serializer, field_serializer, encoded_len) =
serializer::encode_partial(scope, schema, endianness, "buf".to_ident(), decl);
let encode_partial = quote! {
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
#partial_field_serializer
Ok(())
}
};
let encode = quote! {
fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
#field_serializer
Ok(())
}
};
let encoded_len = quote! {
fn encoded_len(&self) -> usize {
#encoded_len
}
};
let constraint_checks = decl.constraints().map(|c| {
let field_id = c.id.to_ident();
let field_name = &c.id;
let packet_name = id;
let value = constraint_value(&parent_data_fields, c);
let value_str = constraint_value_str(&parent_data_fields, c);
quote! {
if parent.#field_id() != #value {
return Err(DecodeError::InvalidFieldValue {
packet: #packet_name,
field: #field_name,
expected: #value_str,
actual: format!("{:?}", parent.#field_id()),
})
}
}
});
let decode_partial = if parent_decl.payload().is_some() {
quote! {
fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
#( #constraint_checks )*
#field_parser
if buf.is_empty() {
Ok(Self {
#( #parsed_field_ids, )*
#( #copied_field_ids: parent.#copied_field_ids, )*
#( #cloned_field_ids: parent.#cloned_field_ids.clone(), )*
})
} else {
Err(DecodeError::TrailingBytes)
}
}
}
} else {
quote! {
fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
#( #constraint_checks )*
Ok(Self {
#( #copied_field_ids: parent.#copied_field_ids, )*
})
}
}
};
let decode = quote! {
fn decode(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
let (parent, trailing_bytes) = #parent_name::decode(buf)?;
let packet = Self::decode_partial(&parent)?;
Ok((packet, trailing_bytes))
}
};
let into_parent = {
let parent_data_field_ids = parent_data_fields.iter().map(|f| f.id().unwrap().to_ident());
let parent_data_field_values = parent_data_fields.iter().map(|f| {
let id = f.id().unwrap().to_ident();
match all_constraints.get(f.id().unwrap()) {
Some(c) => constraint_value(&parent_data_fields, c),
None => quote! { packet.#id },
}
});
if parent_decl.payload().is_some() {
quote! {
impl TryFrom<&#name> for #parent_name {
type Error = EncodeError;
fn try_from(packet: &#name) -> Result<#parent_name, Self::Error> {
let mut payload = Vec::new();
packet.encode_partial(&mut payload)?;
Ok(#parent_name {
#( #parent_data_field_ids: #parent_data_field_values, )*
payload,
})
}
}
impl TryFrom<#name> for #parent_name {
type Error = EncodeError;
fn try_from(packet: #name) -> Result<#parent_name, Self::Error> {
(&packet).try_into()
}
}
}
} else {
quote! {
impl From<&#name> for #parent_name {
fn from(packet: &#name) -> #parent_name {
#parent_name {
#( #parent_data_field_ids: #parent_data_field_values, )*
}
}
}
impl From<#name> for #parent_name {
fn from(packet: #name) -> #parent_name {
(&packet).into()
}
}
}
}
};
let into_ancestors = scope.iter_parents(parent_decl).map(|ancestor_decl| {
let ancestor_name = ancestor_decl.id().unwrap().to_ident();
quote! {
impl TryFrom<&#name> for #ancestor_name {
type Error = EncodeError;
fn try_from(packet: &#name) -> Result<#ancestor_name, Self::Error> {
(&#parent_name::try_from(packet)?).try_into()
}
}
impl TryFrom<#name> for #ancestor_name {
type Error = EncodeError;
fn try_from(packet: #name) -> Result<#ancestor_name, Self::Error> {
(&packet).try_into()
}
}
}
});
let try_from_parent = quote! {
impl TryFrom<&#parent_name> for #name {
type Error = DecodeError;
fn try_from(parent: &#parent_name) -> Result<#name, Self::Error> {
#name::decode_partial(&parent)
}
}
impl TryFrom<#parent_name> for #name {
type Error = DecodeError;
fn try_from(parent: #parent_name) -> Result<#name, Self::Error> {
(&parent).try_into()
}
}
};
let children_decl = scope.iter_children(decl).collect::<Vec<_>>();
let child_struct = (!children_decl.is_empty()).then(|| {
let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
quote! {
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum #child_name {
#( #children_ids(#children_ids), )*
None,
}
}
});
let specialize = (!children_decl.is_empty()).then(|| {
let constraint_fields = children_decl
.iter()
.flat_map(|decl| decl.constraints().map(|c| c.id.to_owned()))
.collect::<BTreeSet<_>>();
let constraint_ids = constraint_fields.iter().map(|id| id.to_ident());
let children_ids = children_decl.iter().map(|decl| decl.id().unwrap().to_ident());
let case_values = children_decl.iter().map(|child_decl| {
let constraint_values = constraint_fields.iter().map(|id| {
let constraint = child_decl.constraints().find(|c| &c.id == id);
match constraint {
Some(constraint) => constraint_value(&data_fields, constraint),
None => quote! { _ },
}
});
quote! { (#( #constraint_values, )*) }
});
let default_case = quote! { _ => #child_name::None, };
quote! {
pub fn specialize(&self) -> Result<#child_name, DecodeError> {
Ok(
match (#( self.#constraint_ids, )*) {
#( #case_values =>
#child_name::#children_ids(self.try_into()?), )*
#default_case
}
)
}
}
});
quote! {
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct #name {
#( pub #data_field_ids: #data_field_types, )*
#payload_field
}
#try_from_parent
#into_parent
#( #into_ancestors )*
#child_struct
impl #name {
#specialize
#decode_partial
#encode_partial
#payload_accessor
#(
pub fn #data_field_ids(&self) -> #data_field_borrows #data_field_types {
#data_field_borrows self.#data_field_ids
}
)*
#(
pub fn #constant_field_ids(&self) -> #constant_field_types {
#constant_field_values
}
)*
}
impl Packet for #name {
#encoded_len
#encode
#decode
}
}
}
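/// Generates a Rust enum backed by an integer of `width` bits. Range tags
/// become variants wrapping a `Private` value; an open enum (one with a `..`
/// default tag) gains an extra variant capturing unknown values. Conversions
/// to and from the backing type are emitted, with an `Err` fallback arm
/// whenever some backing-type values remain unmapped.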
fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2::TokenStream {
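// Returns the default (`..`) tag of the enum, if any; its presence makes
// the enum open.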
fn enum_default_tag(tags: &[ast::Tag]) -> Option<ast::TagOther> {
tags.iter()
.filter_map(|tag| match tag {
ast::Tag::Other(tag) => Some(tag.clone()),
_ => None,
})
.next()
}
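// True if the tag values and ranges contiguously cover 0..=max.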
fn enum_is_complete(tags: &[ast::Tag], max: usize) -> bool {
let mut ranges = tags
.iter()
.filter_map(|tag| match tag {
ast::Tag::Value(tag) => Some((tag.value, tag.value)),
ast::Tag::Range(tag) => Some(tag.range.clone().into_inner()),
_ => None,
})
.collect::<Vec<_>>();
ranges.sort_unstable();
ranges.first().unwrap().0 == 0
&& ranges.last().unwrap().1 == max
&& ranges.windows(2).all(|window| {
if let [left, right] = window {
left.1 == right.0 - 1
} else {
false
}
})
}
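// True if every tag is a plain valued tag (no ranges).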
fn enum_is_primitive(tags: &[ast::Tag]) -> bool {
tags.iter().all(|tag| matches!(tag, ast::Tag::Value(_)))
}
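// Largest value representable in `width` bits, saturating at usize::MAX.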
fn scalar_max(width: usize) -> usize {
if width >= usize::BITS as usize {
usize::MAX
} else {
(1 << width) - 1
}
}
fn format_tag_ident(id: &str) -> proc_macro2::TokenStream {
let id = format_ident!("{}", id.to_upper_camel_case());
quote! { #id }
}
fn format_value(value: usize) -> LitInt {
syn::parse_str::<syn::LitInt>(&format!("{:#x}", value)).unwrap()
}
let backing_type = types::Integer::new(width);
let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
let range_max = scalar_max(width);
let default_tag = enum_default_tag(tags);
let is_open = default_tag.is_some();
let is_complete = enum_is_complete(tags, scalar_max(width));
let is_primitive = enum_is_primitive(tags);
let name = id.to_ident();
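// A primitive enum that is complete or closed contains only unit variants,
// so explicit discriminant values and #[repr(u64)] can be emitted.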
let use_variant_values = is_primitive && (is_complete || !is_open);
let repr_u64 = use_variant_values.then(|| quote! { #[repr(u64)] });
let mut variants = vec![];
for tag in tags.iter() {
match tag {
ast::Tag::Value(tag) if use_variant_values => {
let id = format_tag_ident(&tag.id);
let value = format_value(tag.value);
variants.push(quote! { #id = #value })
}
ast::Tag::Value(tag) => variants.push(format_tag_ident(&tag.id)),
ast::Tag::Range(tag) => {
variants.extend(tag.tags.iter().map(|tag| format_tag_ident(&tag.id)));
let id = format_tag_ident(&tag.id);
variants.push(quote! { #id(Private<#backing_type>) })
}
ast::Tag::Other(_) => (),
}
}
let mut from_cases = vec![];
for tag in tags.iter() {
match tag {
ast::Tag::Value(tag) => {
let id = format_tag_ident(&tag.id);
let value = format_value(tag.value);
from_cases.push(quote! { #value => Ok(#name::#id) })
}
ast::Tag::Range(tag) => {
from_cases.extend(tag.tags.iter().map(|tag| {
let id = format_tag_ident(&tag.id);
let value = format_value(tag.value);
quote! { #value => Ok(#name::#id) }
}));
let id = format_tag_ident(&tag.id);
let start = format_value(*tag.range.start());
let end = format_value(*tag.range.end());
from_cases.push(quote! { #start ..= #end => Ok(#name::#id(Private(value))) })
}
ast::Tag::Other(_) => (),
}
}
let mut into_cases = vec![];
for tag in tags.iter() {
match tag {
ast::Tag::Value(tag) => {
let id = format_tag_ident(&tag.id);
let value = format_value(tag.value);
into_cases.push(quote! { #name::#id => #value })
}
ast::Tag::Range(tag) => {
into_cases.extend(tag.tags.iter().map(|tag| {
let id = format_tag_ident(&tag.id);
let value = format_value(tag.value);
quote! { #name::#id => #value }
}));
let id = format_tag_ident(&tag.id);
into_cases.push(quote! { #name::#id(Private(value)) => *value })
}
ast::Tag::Other(_) => (),
}
}
if !is_complete && is_open {
let unknown_id = format_tag_ident(&default_tag.unwrap().id);
let range_max = format_value(range_max);
variants.push(quote! { #unknown_id(Private<#backing_type>) });
from_cases.push(quote! { 0..=#range_max => Ok(#name::#unknown_id(Private(value))) });
into_cases.push(quote! { #name::#unknown_id(Private(value)) => *value });
}
if backing_type.width != width || (!is_complete && !is_open) {
from_cases.push(quote! { _ => Err(value) });
}
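// Emit infallible conversions from the enum into every signed and unsigned
// integer type wide enough to hold the backing value.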
let derived_signed_into_types = [8, 16, 32, 64]
.into_iter()
.filter(|w| *w > width)
.map(|w| syn::parse_str::<syn::Type>(&format!("i{}", w)).unwrap());
let derived_unsigned_into_types = [8, 16, 32, 64]
.into_iter()
.filter(|w| *w >= width && *w != backing_type.width)
.map(|w| syn::parse_str::<syn::Type>(&format!("u{}", w)).unwrap());
let derived_into_types = derived_signed_into_types.chain(derived_unsigned_into_types);
quote! {
#repr_u64
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
pub enum #name {
#(#variants,)*
}
impl TryFrom<#backing_type> for #name {
type Error = #backing_type;
fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
match value {
#(#from_cases,)*
}
}
}
impl From<&#name> for #backing_type {
fn from(value: &#name) -> Self {
match value {
#(#into_cases,)*
}
}
}
impl From<#name> for #backing_type {
fn from(value: #name) -> Self {
(&value).into()
}
}
#(impl From<#name> for #derived_into_types {
fn from(value: #name) -> Self {
#backing_type::from(value) as Self
}
})*
}
}
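/// Generates the declaration of a custom field with a known bit width: a
/// newtype wrapper around the backing integer with `From`/`TryFrom`
/// conversions (depending on whether the width fills the backing type) and
/// a `Packet` impl for raw de/serialization.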
fn generate_custom_field_decl(
endianness: ast::EndiannessValue,
id: &str,
width: usize,
) -> proc_macro2::TokenStream {
let name = id;
let id = id.to_ident();
let backing_type = types::Integer::new(width);
let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
let max_value = mask_bits(width, &format!("u{}", backing_type.width));
let size = proc_macro2::Literal::usize_unsuffixed(width / 8);
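// Exact-width custom fields implement From<uN> and convert with `.into()`;
// truncated widths only implement TryFrom, but the parsed value is limited
// to `width` bits, so the unwrap is not expected to fail.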
let read_value = types::get_uint(endianness, width, &format_ident!("buf"));
let read_value = if [8, 16, 32, 64].contains(&width) {
quote! { #read_value.into() }
} else {
quote! { (#read_value).try_into().unwrap() }
};
let write_value = types::put_uint(
endianness,
&quote! { #backing_type::from(self) },
width,
&format_ident!("buf"),
);
let common = quote! {
impl From<&#id> for #backing_type {
fn from(value: &#id) -> #backing_type {
value.0
}
}
impl From<#id> for #backing_type {
fn from(value: #id) -> #backing_type {
value.0
}
}
impl Packet for #id {
fn decode(mut buf: &[u8]) -> Result<(Self, &[u8]), DecodeError> {
if buf.len() < #size {
return Err(DecodeError::InvalidLengthError {
obj: #name,
wanted: #size,
got: buf.len(),
})
}
Ok((#read_value, buf))
}
fn encode(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
#write_value;
Ok(())
}
fn encoded_len(&self) -> usize {
#size
}
}
};
if backing_type.width == width {
quote! {
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(from = #backing_type_str, into = #backing_type_str))]
pub struct #id(#backing_type);
#common
impl From<#backing_type> for #id {
fn from(value: #backing_type) -> Self {
#id(value)
}
}
}
} else {
quote! {
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
pub struct #id(#backing_type);
#common
impl TryFrom<#backing_type> for #id {
type Error = #backing_type;
fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
if value > #max_value {
Err(value)
} else {
Ok(#id(value))
}
}
}
}
}
}
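/// Generates the code for a single top-level declaration, dispatching on
/// the declaration kind and on whether it has a parent. Custom fields
/// without a declared width generate no code here.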
fn generate_decl(
scope: &analyzer::Scope<'_>,
schema: &analyzer::Schema,
file: &ast::File,
decl: &ast::Decl,
) -> proc_macro2::TokenStream {
match &decl.desc {
ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
match scope.get_parent(decl) {
None => generate_root_packet_decl(scope, schema, file.endianness.value, id),
Some(_) => generate_derived_packet_decl(scope, schema, file.endianness.value, id),
}
}
ast::DeclDesc::Enum { id, tags, width } => generate_enum_decl(id, tags, *width),
ast::DeclDesc::CustomField { id, width: Some(width), .. } => {
generate_custom_field_decl(file.endianness.value, id, *width)
}
ast::DeclDesc::CustomField { .. } => {
quote!()
}
_ => todo!("unsupported Decl::{:?}", decl),
}
}
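/// Generates the full token stream for a PDL file: the preamble, `use`
/// statements for the externally provided custom field paths, and the code
/// for each declaration.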
pub fn generate_tokens(
sources: &ast::SourceDatabase,
file: &ast::File,
custom_fields: &[String],
) -> proc_macro2::TokenStream {
let source = sources.get(file.file).expect("could not read source");
let preamble = preamble::generate(Path::new(source.name()));
let scope = analyzer::Scope::new(file).expect("could not create scope");
let schema = analyzer::Schema::new(file);
let custom_fields = custom_fields.iter().map(|custom_field| {
syn::parse_str::<syn::Path>(custom_field)
.unwrap_or_else(|err| panic!("invalid path '{custom_field}': {err:?}"))
});
let decls = file.declarations.iter().map(|decl| generate_decl(&scope, &schema, file, decl));
quote! {
#preamble
#(use #custom_fields;)*
#(#decls)*
}
}
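/// Generates the Rust source for a PDL file, pretty-printed with
/// `prettyplease`.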
pub fn generate(
sources: &ast::SourceDatabase,
file: &ast::File,
custom_fields: &[String],
) -> String {
let syntax_tree =
syn::parse2(generate_tokens(sources, file, custom_fields)).expect("Could not parse code");
prettyplease::unparse(&syntax_tree)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer;
use crate::ast;
use crate::parser::parse_inline;
use crate::test_utils::{assert_snapshot_eq, format_rust};
use paste::paste;
macro_rules! make_pdl_test {
($name:ident, $code:expr, $endianness:ident) => {
paste! {
#[test]
fn [< test_ $name _ $endianness >]() {
let name = stringify!($name);
let endianness = stringify!($endianness);
let code = format!("{endianness}_packets\n{}", $code);
let mut db = ast::SourceDatabase::new();
let file = parse_inline(&mut db, "test", code).unwrap();
let file = analyzer::analyze(&file).unwrap();
let actual_code = generate(&db, &file, &[]);
assert_snapshot_eq(
&format!("tests/generated/rust/{name}_{endianness}.rs"),
&format_rust(&actual_code),
);
}
}
};
}
macro_rules! test_pdl {
($name:ident, $code:expr $(,)?) => {
make_pdl_test!($name, $code, little_endian);
make_pdl_test!($name, $code, big_endian);
};
}
test_pdl!(packet_decl_empty, "packet Foo {}");
test_pdl!(packet_decl_8bit_scalar, " packet Foo { x: 8 }");
test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");
test_pdl!(
enum_declaration,
r#"
enum IncompleteTruncatedClosed : 3 {
A = 0,
B = 1,
}
enum IncompleteTruncatedOpen : 3 {
A = 0,
B = 1,
UNKNOWN = ..
}
enum IncompleteTruncatedClosedWithRange : 3 {
A = 0,
B = 1..6 {
X = 1,
Y = 2,
}
}
enum IncompleteTruncatedOpenWithRange : 3 {
A = 0,
B = 1..6 {
X = 1,
Y = 2,
},
UNKNOWN = ..
}
enum CompleteTruncated : 3 {
A = 0,
B = 1,
C = 2,
D = 3,
E = 4,
F = 5,
G = 6,
H = 7,
}
enum CompleteTruncatedWithRange : 3 {
A = 0,
B = 1..7 {
X = 1,
Y = 2,
}
}
enum CompleteWithRange : 8 {
A = 0,
B = 1,
C = 2..255,
}
"#
);
test_pdl!(
custom_field_declaration,
r#"
// Still unsupported.
// custom_field Dynamic "dynamic"
// Should generate a type with From<u32> implementation.
custom_field ExactSize : 32 "exact_size"
// Should generate a type with TryFrom<u32> implementation.
custom_field TruncatedSize : 24 "truncated_size"
"#
);
test_pdl!(
packet_decl_simple_scalars,
r#"
packet Foo {
x: 8,
y: 16,
z: 24,
}
"#
);
test_pdl!(
packet_decl_complex_scalars,
r#"
packet Foo {
a: 3,
b: 8,
c: 5,
d: 24,
e: 12,
f: 4,
}
"#,
);
test_pdl!(
packet_decl_mask_scalar_value,
r#"
packet Foo {
a: 2,
b: 24,
c: 6,
}
"#,
);
test_pdl!(
struct_decl_complex_scalars,
r#"
struct Foo {
a: 3,
b: 8,
c: 5,
d: 24,
e: 12,
f: 4,
}
"#,
);
test_pdl!(packet_decl_8bit_enum, " enum Foo : 8 { A = 1, B = 2 } packet Bar { x: Foo }");
test_pdl!(packet_decl_24bit_enum, "enum Foo : 24 { A = 1, B = 2 } packet Bar { x: Foo }");
test_pdl!(packet_decl_64bit_enum, "enum Foo : 64 { A = 1, B = 2 } packet Bar { x: Foo }");
test_pdl!(
packet_decl_mixed_scalars_enums,
"
enum Enum7 : 7 {
A = 1,
B = 2,
}
enum Enum9 : 9 {
A = 1,
B = 2,
}
packet Foo {
x: Enum7,
y: 5,
z: Enum9,
w: 3,
}
"
);
test_pdl!(packet_decl_8bit_scalar_array, " packet Foo { x: 8[3] }");
test_pdl!(packet_decl_24bit_scalar_array, "packet Foo { x: 24[5] }");
test_pdl!(packet_decl_64bit_scalar_array, "packet Foo { x: 64[7] }");
test_pdl!(
packet_decl_8bit_enum_array,
"enum Foo : 8 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[3] }"
);
test_pdl!(
packet_decl_24bit_enum_array,
"enum Foo : 24 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[5] }"
);
test_pdl!(
packet_decl_64bit_enum_array,
"enum Foo : 64 { FOO_BAR = 1, BAZ = 2 } packet Bar { x: Foo[7] }"
);
test_pdl!(
packet_decl_array_dynamic_count,
"
packet Foo {
_count_(x): 5,
padding: 3,
x: 24[]
}
"
);
test_pdl!(
packet_decl_array_dynamic_size,
"
packet Foo {
_size_(x): 5,
padding: 3,
x: 24[]
}
"
);
test_pdl!(
packet_decl_array_unknown_element_width_dynamic_size,
"
struct Foo {
_count_(a): 40,
a: 16[],
}
packet Bar {
_size_(x): 40,
x: Foo[],
}
"
);
test_pdl!(
packet_decl_array_unknown_element_width_dynamic_count,
"
struct Foo {
_count_(a): 40,
a: 16[],
}
packet Bar {
_count_(x): 40,
x: Foo[],
}
"
);
test_pdl!(
packet_decl_array_with_padding,
"
struct Foo {
_count_(a): 40,
a: 16[],
}
packet Bar {
a: Foo[],
_padding_ [128],
}
"
);
test_pdl!(
packet_decl_array_dynamic_element_size,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[]
}
"
);
test_pdl!(
packet_decl_array_dynamic_element_size_dynamic_size,
"
struct Foo {
inner: 8[]
}
packet Bar {
_size_(x): 4,
_elementsize_(x): 4,
x: Foo[]
}
"
);
test_pdl!(
packet_decl_array_dynamic_element_size_dynamic_count,
"
struct Foo {
inner: 8[]
}
packet Bar {
_count_(x): 4,
_elementsize_(x): 4,
x: Foo[]
}
"
);
test_pdl!(
packet_decl_array_dynamic_element_size_static_count,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[4]
}
"
);
test_pdl!(
packet_decl_array_dynamic_element_size_static_count_1,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[1]
}
"
);
test_pdl!(
packet_decl_reserved_field,
"
packet Foo {
_reserved_: 40,
}
"
);
test_pdl!(
packet_decl_custom_field,
r#"
custom_field Bar1 : 24 "exact"
custom_field Bar2 : 32 "truncated"
packet Foo {
a: Bar1,
b: Bar2,
}
"#
);
test_pdl!(
packet_decl_fixed_scalar_field,
"
packet Foo {
_fixed_ = 7 : 7,
b: 57,
}
"
);
test_pdl!(
packet_decl_fixed_enum_field,
"
enum Enum7 : 7 {
A = 1,
B = 2,
}
packet Foo {
_fixed_ = A : Enum7,
b: 57,
}
"
);
test_pdl!(
packet_decl_payload_field_variable_size,
"
packet Foo {
a: 8,
_size_(_payload_): 8,
_payload_,
b: 16,
}
"
);
test_pdl!(
packet_decl_payload_field_unknown_size,
"
packet Foo {
a: 24,
_payload_,
}
"
);
test_pdl!(
packet_decl_payload_field_unknown_size_terminal,
"
packet Foo {
_payload_,
a: 24,
}
"
);
test_pdl!(
packet_decl_child_packets,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}
packet Foo {
a: 8,
b: Enum16,
_size_(_payload_): 8,
_payload_
}
packet Bar : Foo (a = 100) {
x: 8,
}
packet Baz : Foo (b = B) {
y: 16,
}
"
);
test_pdl!(
packet_decl_grand_children,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}
packet Parent {
foo: Enum16,
bar: Enum16,
baz: Enum16,
_size_(_payload_): 8,
_payload_
}
packet Child : Parent (foo = A) {
quux: Enum16,
_payload_,
}
packet GrandChild : Child (bar = A, quux = A) {
_body_,
}
packet GrandGrandChild : GrandChild (baz = A) {
_body_,
}
"
);
test_pdl!(
packet_decl_parent_with_no_payload,
"
enum Enum8 : 8 {
A = 0,
}
packet Parent {
v : Enum8,
}
packet Child : Parent (v = A) {
}
"
);
test_pdl!(
packet_decl_parent_with_alias_child,
"
enum Enum8 : 8 {
A = 0,
B = 1,
C = 2,
}
packet Parent {
v : Enum8,
_payload_,
}
packet AliasChild : Parent {
_payload_
}
packet NormalChild : Parent (v = A) {
}
packet NormalGrandChild1 : AliasChild (v = B) {
}
packet NormalGrandChild2 : AliasChild (v = C) {
_payload_
}
"
);
test_pdl!(
reserved_identifier,
"
packet Test {
type: 8,
}
"
);
test_pdl!(
payload_with_size_modifier,
"
packet Test {
_size_(_payload_): 8,
_payload_ : [+1],
}
"
);
test_pdl!(
struct_decl_child_structs,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}
struct Foo {
a: 8,
b: Enum16,
_size_(_payload_): 8,
_payload_
}
struct Bar : Foo (a = 100) {
x: 8,
}
struct Baz : Foo (b = B) {
y: 16,
}
"
);
test_pdl!(
struct_decl_grand_children,
"
enum Enum16 : 16 {
A = 1,
B = 2,
}
struct Parent {
foo: Enum16,
bar: Enum16,
baz: Enum16,
_size_(_payload_): 8,
_payload_
}
struct Child : Parent (foo = A) {
quux: Enum16,
_payload_,
}
struct GrandChild : Child (bar = A, quux = A) {
_body_,
}
struct GrandGrandChild : GrandChild (baz = A) {
_body_,
}
"
);
}