1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3use proc_macro::TokenStream;
6use quote::quote;
7use syn::{Data, DeriveInput, Fields, parse_macro_input};
8
9#[proc_macro_derive(Encode)]
29pub fn derive_wire_encode(input: TokenStream) -> TokenStream {
30 let input = parse_macro_input!(input as DeriveInput);
31 match impl_wire_encode(&input) {
32 Ok(tokens) => tokens.into(),
33 Err(err) => err.to_compile_error().into(),
34 }
35}
36
37#[proc_macro_derive(Decode)]
57pub fn derive_wire_decode(input: TokenStream) -> TokenStream {
58 let input = parse_macro_input!(input as DeriveInput);
59 match impl_wire_decode(&input) {
60 Ok(tokens) => tokens.into(),
61 Err(err) => err.to_compile_error().into(),
62 }
63}
64
65fn named_fields(input: &DeriveInput) -> syn::Result<&syn::FieldsNamed> {
68 let name = &input.ident;
69 match &input.data {
70 Data::Struct(data) => match &data.fields {
71 Fields::Named(named) => Ok(named),
72 _ => Err(syn::Error::new_spanned(
73 name,
74 "Encode / Decode can only be derived for structs with named fields",
75 )),
76 },
77 Data::Enum(_) => Err(syn::Error::new_spanned(
78 name,
79 "Encode / Decode cannot be derived for enums",
80 )),
81 Data::Union(_) => Err(syn::Error::new_spanned(
82 name,
83 "Encode / Decode cannot be derived for unions",
84 )),
85 }
86}
87
88fn reject_generics(input: &DeriveInput) -> syn::Result<()> {
91 if !input.generics.params.is_empty() {
92 return Err(syn::Error::new_spanned(
93 &input.generics,
94 "Encode / Decode cannot be derived for generic structs",
95 ));
96 }
97 Ok(())
98}
99
100fn impl_wire_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
103 reject_generics(input)?;
104 let name = &input.ident;
105 let fields = named_fields(input)?;
106
107 let encode_stmts: Vec<_> = fields
108 .named
109 .iter()
110 .map(|f| {
111 let ident = f.ident.as_ref().unwrap();
112 quote! {
113 conduit_core::Encode::encode(&self.#ident, buf);
114 }
115 })
116 .collect();
117
118 let size_terms: Vec<_> = fields
119 .named
120 .iter()
121 .map(|f| {
122 let ident = f.ident.as_ref().unwrap();
123 quote! {
124 conduit_core::Encode::encode_size(&self.#ident)
125 }
126 })
127 .collect();
128
129 let size_expr = if size_terms.is_empty() {
131 quote! { 0 }
132 } else {
133 let first = &size_terms[0];
134 let rest = &size_terms[1..];
135 quote! { #first #(+ #rest)* }
136 };
137
138 Ok(quote! {
139 impl conduit_core::Encode for #name {
140 fn encode(&self, buf: &mut Vec<u8>) {
141 #(#encode_stmts)*
142 }
143
144 fn encode_size(&self) -> usize {
145 #size_expr
146 }
147 }
148 })
149}
150
151fn impl_wire_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
154 reject_generics(input)?;
155 let name = &input.ident;
156 let fields = named_fields(input)?;
157
158 let decode_stmts: Vec<_> = fields
159 .named
160 .iter()
161 .map(|f| {
162 let ident = f.ident.as_ref().unwrap();
163 quote! {
164 let (#ident, __n) = conduit_core::Decode::decode(&__data[__offset..])?;
165 __offset += __n;
166 }
167 })
168 .collect();
169
170 let field_names: Vec<_> = fields
171 .named
172 .iter()
173 .map(|f| f.ident.as_ref().unwrap())
174 .collect();
175
176 Ok(quote! {
177 impl conduit_core::Decode for #name {
178 fn decode(__data: &[u8]) -> Option<(Self, usize)> {
179 let mut __offset = 0usize;
180 #(#decode_stmts)*
181 Some((Self { #(#field_names),* }, __offset))
182 }
183 }
184 })
185}