redoubt_codec_derive/
lib.rs

1// Copyright (c) 2025-2026 Federico Hoerth <memparanoid@gmail.com>
2// SPDX-License-Identifier: GPL-3.0-only
3// See LICENSE in the repository root for full license text.
4
5//! Procedural macros for redoubt-codec.
6//!
7//! ## License
8//!
9//! GPL-3.0-only
10
11// Only run unit tests on architectures where insta (-> sha2 -> cpufeatures) compiles
12#[cfg(all(
13    test,
14    any(
15        target_arch = "x86_64",
16        target_arch = "x86",
17        target_arch = "aarch64",
18        target_arch = "loongarch64"
19    )
20))]
21mod tests;
22
23use proc_macro::TokenStream;
24use proc_macro_crate::{FoundCrate, crate_name};
25use proc_macro2::{Span, TokenStream as TokenStream2};
26use quote::quote;
27use syn::{Attribute, Data, DeriveInput, Fields, Ident, Index, LitStr, Meta, parse_macro_input};
28
29/// Derives `BytesRequired`, `Encode`, and `Decode` for a struct.
30///
31/// # Attributes
32///
33/// - `#[codec(default)]` on a field: Skip encoding/decoding, use `Default::default()`
34#[proc_macro_derive(RedoubtCodec, attributes(codec))]
35pub fn derive_redoubt_codec(input: TokenStream) -> TokenStream {
36    let input = parse_macro_input!(input as DeriveInput);
37    expand(input).unwrap_or_else(|e| e).into()
38}
39
40/// Find the root crate path from a list of candidates.
41/// Candidates can be crate names like "redoubt-codec" or paths like "redoubt::codec".
42pub(crate) fn find_root_with_candidates(candidates: &[&'static str]) -> TokenStream2 {
43    for &candidate in candidates {
44        // Check if candidate contains "::" (path syntax like "redoubt::codec")
45        if let Some((crate_part, path_part)) = candidate.split_once("::") {
46            match crate_name(crate_part) {
47                Ok(FoundCrate::Itself) => {
48                    let path: TokenStream2 = path_part.parse().unwrap_or_else(|_| quote!());
49                    return quote!(crate::#path);
50                }
51                Ok(FoundCrate::Name(name)) => {
52                    let crate_id = Ident::new(&name, Span::call_site());
53                    let path: TokenStream2 = path_part.parse().unwrap_or_else(|_| quote!());
54                    return quote!(#crate_id::#path);
55                }
56                Err(_) => continue,
57            }
58        } else {
59            match crate_name(candidate) {
60                Ok(FoundCrate::Itself) => return quote!(crate),
61                Ok(FoundCrate::Name(name)) => {
62                    let id = Ident::new(&name, Span::call_site());
63                    return quote!(#id);
64                }
65                Err(_) => continue,
66            }
67        }
68    }
69
70    let msg = "RedoubtCodec: could not find redoubt-codec or redoubt-codec-core. Add redoubt-codec to Cargo.toml.";
71    let lit = LitStr::new(msg, Span::call_site());
72    quote! { compile_error!(#lit); }
73}
74
75/// Checks if a field has the `#[codec(default)]` attribute.
76fn has_codec_default(attrs: &[Attribute]) -> bool {
77    attrs.iter().any(|attr| {
78        matches!(&attr.meta, Meta::List(meta_list)
79            if meta_list.path.is_ident("codec")
80            && meta_list.tokens.to_string().contains("default"))
81    })
82}
83
84fn expand(input: DeriveInput) -> Result<TokenStream2, TokenStream2> {
85    let struct_name = &input.ident;
86    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
87
88    let root = find_root_with_candidates(&["redoubt-codec-core", "redoubt-codec", "redoubt::codec"]);
89
90    // Get fields
91    let fields: Vec<(usize, &syn::Field)> = match &input.data {
92        Data::Struct(data) => match &data.fields {
93            Fields::Named(named) => named.named.iter().enumerate().collect(),
94            Fields::Unnamed(unnamed) => unnamed.unnamed.iter().enumerate().collect(),
95            Fields::Unit => vec![],
96        },
97        _ => {
98            return Err(syn::Error::new_spanned(
99                &input.ident,
100                "RedoubtCodec can only be derived for structs.",
101            )
102            .to_compile_error());
103        }
104    };
105
106    // Generate field references (filter out fields with #[codec(default)])
107    let (immut_refs, mut_refs): (Vec<TokenStream2>, Vec<TokenStream2>) = fields
108        .iter()
109        .filter(|(_, f)| !has_codec_default(&f.attrs))
110        .map(|(i, f)| {
111            if let Some(ident) = &f.ident {
112                (quote! { &self.#ident }, quote! { &mut self.#ident })
113            } else {
114                let idx = Index::from(*i);
115                (quote! { &self.#idx }, quote! { &mut self.#idx })
116            }
117        })
118        .unzip();
119
120    let len = immut_refs.len();
121    let len_lit = syn::LitInt::new(&len.to_string(), Span::call_site());
122
123    let output = quote! {
124        impl #impl_generics #root::BytesRequired for #struct_name #ty_generics #where_clause {
125            fn encode_bytes_required(&self) -> Result<usize, #root::OverflowError> {
126                let fields: [&dyn #root::BytesRequired; #len_lit] = [
127                    #( #root::collections::helpers::to_bytes_required_dyn_ref(#immut_refs) ),*
128                ];
129                #root::collections::helpers::bytes_required_sum(fields.into_iter())
130            }
131        }
132
133        impl #impl_generics #root::Encode for #struct_name #ty_generics #where_clause {
134            fn encode_into(&mut self, buf: &mut #root::RedoubtCodecBuffer) -> Result<(), #root::EncodeError> {
135                let fields: [&mut dyn #root::EncodeZeroize; #len_lit] = [
136                    #( #root::collections::helpers::to_encode_zeroize_dyn_mut(#mut_refs) ),*
137                ];
138                #root::collections::helpers::encode_fields(fields.into_iter(), buf)
139            }
140        }
141
142        impl #impl_generics #root::Decode for #struct_name #ty_generics #where_clause {
143            fn decode_from(&mut self, buf: &mut &mut [u8]) -> Result<(), #root::DecodeError> {
144                let fields: [&mut dyn #root::DecodeZeroize; #len_lit] = [
145                    #( #root::collections::helpers::to_decode_zeroize_dyn_mut(#mut_refs) ),*
146                ];
147                #root::collections::helpers::decode_fields(fields.into_iter(), buf)
148            }
149        }
150    };
151
152    Ok(output)
153}