Skip to main content

commonware_codec_macros/
lib.rs

1//! Augment the development of [`commonware-codec`](https://docs.rs/commonware-codec) with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use proc_macro_crate::{crate_name, FoundCrate};
11use quote::quote;
12use syn::{
13    parenthesized, parse_macro_input, parse_quote, DeriveInput, Error, Generics, Ident, Type,
14    WhereClause, WherePredicate,
15};
16
17/// Resolves the path to the `commonware-codec` crate, accounting for renames and use within
18/// `commonware-codec` itself.
19fn codec_path() -> proc_macro2::TokenStream {
20    match crate_name("commonware-codec") {
21        Ok(FoundCrate::Itself) => quote!(crate),
22        Ok(FoundCrate::Name(name)) => {
23            let ident = Ident::new(&name, Span::call_site());
24            quote!(::#ident)
25        }
26        Err(_) => quote!(::commonware_codec),
27    }
28}
29
30/// Returns a where clause that preserves user predicates and adds one generated bound.
31fn where_clause_with(generics: &Generics, predicate: WherePredicate) -> WhereClause {
32    let mut generics = generics.clone();
33    generics.make_where_clause().predicates.push(predicate);
34    generics
35        .where_clause
36        .expect("make_where_clause should create a where clause")
37}
38
39/// Derives byte-array conversion impls for a fixed-size type.
40///
41/// Generates:
42/// - `TryFrom<[u8; SIZE]>` and `TryFrom<&[u8; SIZE]>`, or `From<[u8; SIZE]>` and
43///   `From<&[u8; SIZE]>` when `infallible` (decoding via `DecodeFixed`).
44/// - `TryFrom<&[u8]>`
45/// - `From<T> for [u8; SIZE]`
46/// - `From<&T> for [u8; SIZE]`
47///
48/// The type must implement `Read<Cfg = ()>` and `EncodeFixed`.
49///
50/// # Attributes
51///
52/// - `#[fixed_array(infallible)]`: emit `From<[u8; SIZE]>` instead of `TryFrom<[u8; SIZE]>`.
53///   The type's decode must never fail (any `[u8; SIZE]` is a valid value), since the generated
54///   `From` unwraps the `DecodeFixed` result.
55/// - `#[fixed_array(bytes([u8; N]))]`: required for any generic type (lifetime, type, or
56///   const). Stable Rust forbids a generic parameter inside the const expression
57///   `[u8; <T as FixedSize>::SIZE]`, so the byte array type must be named.
58#[proc_macro_derive(FixedArray, attributes(fixed_array))]
59pub fn fixed_array(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as DeriveInput);
61    let name = &input.ident;
62    let (impl_generics, ty_generics, _) = input.generics.split_for_impl();
63
64    let mut infallible = false;
65    let mut bytes_ty: Option<Type> = None;
66    for attr in &input.attrs {
67        if !attr.path().is_ident("fixed_array") {
68            continue;
69        }
70        let result = attr.parse_nested_meta(|meta| {
71            if meta.path.is_ident("infallible") {
72                infallible = true;
73                Ok(())
74            } else if meta.path.is_ident("bytes") {
75                let content;
76                parenthesized!(content in meta.input);
77                bytes_ty = Some(content.parse()?);
78                Ok(())
79            } else {
80                Err(meta.error("expected `infallible` or `bytes(...)`"))
81            }
82        });
83        if let Err(e) = result {
84            return e.to_compile_error().into();
85        }
86    }
87
88    // Stable Rust forbids any generic parameter (lifetime, type, or const) inside the const
89    // expression `<T as FixedSize>::SIZE`, so generic types must name the byte array type.
90    if !input.generics.params.is_empty() && bytes_ty.is_none() {
91        return Error::new_spanned(
92            &input.generics,
93            "generic types must name the byte array type: #[fixed_array(bytes([u8; N]))]",
94        )
95        .to_compile_error()
96        .into();
97    }
98
99    let codec = codec_path();
100    let bytes = bytes_ty.as_ref().map_or_else(
101        || quote!([u8; <#name as #codec::FixedSize>::SIZE]),
102        |ty| quote!(#ty),
103    );
104    let decode_fixed_where = where_clause_with(
105        &input.generics,
106        parse_quote!(#name #ty_generics: #codec::DecodeFixed),
107    );
108    let encode_fixed_where = where_clause_with(
109        &input.generics,
110        parse_quote!(#name #ty_generics: #codec::EncodeFixed),
111    );
112
113    let from_arrays = if infallible {
114        quote! {
115            impl #impl_generics core::convert::From<#bytes> for #name #ty_generics #decode_fixed_where {
116                fn from(bytes: #bytes) -> Self {
117                    <Self as #codec::DecodeFixed>::decode_fixed(bytes)
118                        .expect("infallible decode of fixed-size array")
119                }
120            }
121
122            impl #impl_generics core::convert::From<&#bytes> for #name #ty_generics #decode_fixed_where {
123                fn from(bytes: &#bytes) -> Self {
124                    <Self as core::convert::From<#bytes>>::from(*bytes)
125                }
126            }
127        }
128    } else {
129        quote! {
130            impl #impl_generics core::convert::TryFrom<#bytes> for #name #ty_generics #decode_fixed_where {
131                type Error = #codec::Error;
132
133                fn try_from(bytes: #bytes) -> core::result::Result<Self, Self::Error> {
134                    <Self as #codec::DecodeFixed>::decode_fixed(bytes)
135                }
136            }
137
138            impl #impl_generics core::convert::TryFrom<&#bytes> for #name #ty_generics #decode_fixed_where {
139                type Error = #codec::Error;
140
141                fn try_from(bytes: &#bytes) -> core::result::Result<Self, Self::Error> {
142                    <Self as #codec::DecodeFixed>::decode_fixed(*bytes)
143                }
144            }
145        }
146    };
147
148    let expanded = quote! {
149        #from_arrays
150
151        impl #impl_generics core::convert::TryFrom<&[u8]> for #name #ty_generics #decode_fixed_where {
152            type Error = #codec::Error;
153
154            fn try_from(bytes: &[u8]) -> core::result::Result<Self, Self::Error> {
155                <Self as #codec::Decode>::decode_cfg(bytes, &())
156            }
157        }
158
159        impl #impl_generics core::convert::From<#name #ty_generics> for #bytes #encode_fixed_where {
160            fn from(value: #name #ty_generics) -> Self {
161                #codec::EncodeFixed::encode_fixed(&value)
162            }
163        }
164
165        impl #impl_generics core::convert::From<&#name #ty_generics> for #bytes #encode_fixed_where {
166            fn from(value: &#name #ty_generics) -> Self {
167                #codec::EncodeFixed::encode_fixed(value)
168            }
169        }
170    };
171
172    TokenStream::from(expanded)
173}