// Unit tests for this derive crate. They are compiled only for test builds
// whose target architecture is one of those listed in the `any(...)` below.
#[cfg(all(
    test,
    any(
        target_arch = "aarch64",
        target_arch = "loongarch64",
        target_arch = "x86",
        target_arch = "x86_64"
    )
))]
mod tests;

23use proc_macro::TokenStream;
24use proc_macro_crate::{FoundCrate, crate_name};
25use proc_macro2::{Span, TokenStream as TokenStream2};
26use quote::quote;
27use syn::{Attribute, Data, DeriveInput, Fields, Ident, Index, LitStr, Meta, parse_macro_input};
28
29#[proc_macro_derive(RedoubtCodec, attributes(codec))]
35pub fn derive_redoubt_codec(input: TokenStream) -> TokenStream {
36 let input = parse_macro_input!(input as DeriveInput);
37 expand(input).unwrap_or_else(|e| e).into()
38}
39
40pub(crate) fn find_root_with_candidates(candidates: &[&'static str]) -> TokenStream2 {
43 for &candidate in candidates {
44 if let Some((crate_part, path_part)) = candidate.split_once("::") {
46 match crate_name(crate_part) {
47 Ok(FoundCrate::Itself) => {
48 let path: TokenStream2 = path_part.parse().unwrap_or_else(|_| quote!());
49 return quote!(crate::#path);
50 }
51 Ok(FoundCrate::Name(name)) => {
52 let crate_id = Ident::new(&name, Span::call_site());
53 let path: TokenStream2 = path_part.parse().unwrap_or_else(|_| quote!());
54 return quote!(#crate_id::#path);
55 }
56 Err(_) => continue,
57 }
58 } else {
59 match crate_name(candidate) {
60 Ok(FoundCrate::Itself) => return quote!(crate),
61 Ok(FoundCrate::Name(name)) => {
62 let id = Ident::new(&name, Span::call_site());
63 return quote!(#id);
64 }
65 Err(_) => continue,
66 }
67 }
68 }
69
70 let msg = "RedoubtCodec: could not find redoubt-codec or redoubt-codec-core. Add redoubt-codec to Cargo.toml.";
71 let lit = LitStr::new(msg, Span::call_site());
72 quote! { compile_error!(#lit); }
73}
74
75fn has_codec_default(attrs: &[Attribute]) -> bool {
77 attrs.iter().any(|attr| {
78 matches!(&attr.meta, Meta::List(meta_list)
79 if meta_list.path.is_ident("codec")
80 && meta_list.tokens.to_string().contains("default"))
81 })
82}
83
84fn expand(input: DeriveInput) -> Result<TokenStream2, TokenStream2> {
85 let struct_name = &input.ident;
86 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
87
88 let root = find_root_with_candidates(&["redoubt-codec-core", "redoubt-codec", "redoubt::codec"]);
89
90 let fields: Vec<(usize, &syn::Field)> = match &input.data {
92 Data::Struct(data) => match &data.fields {
93 Fields::Named(named) => named.named.iter().enumerate().collect(),
94 Fields::Unnamed(unnamed) => unnamed.unnamed.iter().enumerate().collect(),
95 Fields::Unit => vec![],
96 },
97 _ => {
98 return Err(syn::Error::new_spanned(
99 &input.ident,
100 "RedoubtCodec can only be derived for structs.",
101 )
102 .to_compile_error());
103 }
104 };
105
106 let (immut_refs, mut_refs): (Vec<TokenStream2>, Vec<TokenStream2>) = fields
108 .iter()
109 .filter(|(_, f)| !has_codec_default(&f.attrs))
110 .map(|(i, f)| {
111 if let Some(ident) = &f.ident {
112 (quote! { &self.#ident }, quote! { &mut self.#ident })
113 } else {
114 let idx = Index::from(*i);
115 (quote! { &self.#idx }, quote! { &mut self.#idx })
116 }
117 })
118 .unzip();
119
120 let len = immut_refs.len();
121 let len_lit = syn::LitInt::new(&len.to_string(), Span::call_site());
122
123 let output = quote! {
124 impl #impl_generics #root::BytesRequired for #struct_name #ty_generics #where_clause {
125 fn encode_bytes_required(&self) -> Result<usize, #root::OverflowError> {
126 let fields: [&dyn #root::BytesRequired; #len_lit] = [
127 #( #root::collections::helpers::to_bytes_required_dyn_ref(#immut_refs) ),*
128 ];
129 #root::collections::helpers::bytes_required_sum(fields.into_iter())
130 }
131 }
132
133 impl #impl_generics #root::Encode for #struct_name #ty_generics #where_clause {
134 fn encode_into(&mut self, buf: &mut #root::RedoubtCodecBuffer) -> Result<(), #root::EncodeError> {
135 let fields: [&mut dyn #root::EncodeZeroize; #len_lit] = [
136 #( #root::collections::helpers::to_encode_zeroize_dyn_mut(#mut_refs) ),*
137 ];
138 #root::collections::helpers::encode_fields(fields.into_iter(), buf)
139 }
140 }
141
142 impl #impl_generics #root::Decode for #struct_name #ty_generics #where_clause {
143 fn decode_from(&mut self, buf: &mut &mut [u8]) -> Result<(), #root::DecodeError> {
144 let fields: [&mut dyn #root::DecodeZeroize; #len_lit] = [
145 #( #root::collections::helpers::to_decode_zeroize_dyn_mut(#mut_refs) ),*
146 ];
147 #root::collections::helpers::decode_fields(fields.into_iter(), buf)
148 }
149 }
150 };
151
152 Ok(output)
153}