#![doc(
html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
html_favicon_url = "https://commonware.xyz/favicon.ico"
)]
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::quote;
use syn::{
parenthesized, parse_macro_input, parse_quote, DeriveInput, Error, Generics, Ident, Type,
WhereClause, WherePredicate,
};
/// Resolves the path that generated code should use to refer to `commonware-codec`.
///
/// - When `proc_macro_crate` reports that the crate being compiled is `commonware-codec`
///   itself, the path is `crate`.
/// - When it reports a name for the dependency, the path is `::<name>`.
/// - If the lookup fails, the path falls back to `::commonware_codec`.
fn codec_path() -> proc_macro2::TokenStream {
    let found = match crate_name("commonware-codec") {
        Ok(found) => found,
        Err(_) => return quote!(::commonware_codec),
    };
    match found {
        FoundCrate::Itself => quote!(crate),
        FoundCrate::Name(dependency) => {
            let root = Ident::new(&dependency, Span::call_site());
            quote!(::#root)
        }
    }
}
/// Returns the where clause of `generics` with `predicate` appended.
///
/// `generics` itself is not modified; a copy is extended instead. If the original has no
/// where clause, a new one is created holding only `predicate`. Otherwise the existing
/// predicates are kept ahead of the new one.
fn where_clause_with(generics: &Generics, predicate: WherePredicate) -> WhereClause {
    let mut extended = generics.clone();
    let clause = extended.make_where_clause();
    clause.predicates.push(predicate);
    extended
        .where_clause
        .take()
        .expect("make_where_clause should create a where clause")
}
#[proc_macro_derive(FixedArray, attributes(fixed_array))]
pub fn fixed_array(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let (impl_generics, ty_generics, _) = input.generics.split_for_impl();
let mut infallible = false;
let mut bytes_ty: Option<Type> = None;
for attr in &input.attrs {
if !attr.path().is_ident("fixed_array") {
continue;
}
let result = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("infallible") {
infallible = true;
Ok(())
} else if meta.path.is_ident("bytes") {
let content;
parenthesized!(content in meta.input);
bytes_ty = Some(content.parse()?);
Ok(())
} else {
Err(meta.error("expected `infallible` or `bytes(...)`"))
}
});
if let Err(e) = result {
return e.to_compile_error().into();
}
}
if !input.generics.params.is_empty() && bytes_ty.is_none() {
return Error::new_spanned(
&input.generics,
"generic types must name the byte array type: #[fixed_array(bytes([u8; N]))]",
)
.to_compile_error()
.into();
}
let codec = codec_path();
let bytes = bytes_ty.as_ref().map_or_else(
|| quote!([u8; <#name as #codec::FixedSize>::SIZE]),
|ty| quote!(#ty),
);
let decode_fixed_where = where_clause_with(
&input.generics,
parse_quote!(#name #ty_generics: #codec::DecodeFixed),
);
let encode_fixed_where = where_clause_with(
&input.generics,
parse_quote!(#name #ty_generics: #codec::EncodeFixed),
);
let from_arrays = if infallible {
quote! {
impl #impl_generics core::convert::From<#bytes> for #name #ty_generics #decode_fixed_where {
fn from(bytes: #bytes) -> Self {
<Self as #codec::DecodeFixed>::decode_fixed(bytes)
.expect("infallible decode of fixed-size array")
}
}
impl #impl_generics core::convert::From<&#bytes> for #name #ty_generics #decode_fixed_where {
fn from(bytes: &#bytes) -> Self {
<Self as core::convert::From<#bytes>>::from(*bytes)
}
}
}
} else {
quote! {
impl #impl_generics core::convert::TryFrom<#bytes> for #name #ty_generics #decode_fixed_where {
type Error = #codec::Error;
fn try_from(bytes: #bytes) -> core::result::Result<Self, Self::Error> {
<Self as #codec::DecodeFixed>::decode_fixed(bytes)
}
}
impl #impl_generics core::convert::TryFrom<&#bytes> for #name #ty_generics #decode_fixed_where {
type Error = #codec::Error;
fn try_from(bytes: &#bytes) -> core::result::Result<Self, Self::Error> {
<Self as #codec::DecodeFixed>::decode_fixed(*bytes)
}
}
}
};
let expanded = quote! {
#from_arrays
impl #impl_generics core::convert::TryFrom<&[u8]> for #name #ty_generics #decode_fixed_where {
type Error = #codec::Error;
fn try_from(bytes: &[u8]) -> core::result::Result<Self, Self::Error> {
<Self as #codec::Decode>::decode_cfg(bytes, &())
}
}
impl #impl_generics core::convert::From<#name #ty_generics> for #bytes #encode_fixed_where {
fn from(value: #name #ty_generics) -> Self {
#codec::EncodeFixed::encode_fixed(&value)
}
}
impl #impl_generics core::convert::From<&#name #ty_generics> for #bytes #encode_fixed_where {
fn from(value: &#name #ty_generics) -> Self {
#codec::EncodeFixed::encode_fixed(value)
}
}
};
TokenStream::from(expanded)
}