Skip to main content

fluentbase_codec_derive/
lib.rs

1//! Procedural macros for deriving `fluentbase_codec::Codec` on Rust structs.
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input,
7    parse_quote,
8    Data,
9    DeriveInput,
10    Fields,
11    GenericParam,
12    Ident,
13    Type,
14    WhereClause,
15    WherePredicate,
16};
17
18/// Holds information about a struct field
19struct FieldInfo {
20    ident: Ident,
21    ty: Type,
22}
23
24/// Represents a struct for which we are deriving Codec
25struct CodecStruct {
26    struct_name: Ident,
27    generics: syn::Generics,
28    fields: Vec<FieldInfo>,
29}
30
31impl CodecStruct {
32    /// Parse the DeriveInput to extract struct information
33    fn parse(ast: &DeriveInput) -> Self {
34        let data_struct = match &ast.data {
35            Data::Struct(s) => s,
36            _ => panic!("`Codec` can only be derived for structs"),
37        };
38
39        let named_fields = match &data_struct.fields {
40            Fields::Named(named_fields) => named_fields,
41            _ => panic!("`Codec` can only be derived for structs with named fields"),
42        };
43
44        let fields = named_fields
45            .named
46            .iter()
47            .map(|field| {
48                let ident = field.ident.as_ref().unwrap().clone();
49                let ty = field.ty.clone();
50                FieldInfo { ident, ty }
51            })
52            .collect();
53
54        CodecStruct {
55            struct_name: ast.ident.clone(),
56            generics: ast.generics.clone(),
57            fields,
58        }
59    }
60
61    /// Detect the crate path for the codec library
62    fn get_crate_path() -> TokenStream2 {
63        let crate_name = std::env::var("CARGO_PKG_NAME").unwrap_or_default();
64        if crate_name == "fluentbase-codec"
65            || crate_name == "fluentbase-sdk"
66            || crate_name == "fluentbase-types"
67            || crate_name == "fluentbase-runtime"
68        {
69            quote! { ::fluentbase_codec }
70        } else {
71            quote! { ::fluentbase_sdk::codec }
72        }
73    }
74
75    /// Prepare generics by adding necessary type and const parameters
76    fn prepare_generics(&self, original_generics: &syn::Generics) -> syn::Generics {
77        let mut generics = original_generics.clone();
78        let crate_path = Self::get_crate_path();
79
80        // Check if B and ALIGN parameters already exist
81        let needs_b = !generics
82            .params
83            .iter()
84            .any(|p| matches!(p, GenericParam::Type(t) if t.ident == "B"));
85        let needs_align = !generics
86            .params
87            .iter()
88            .any(|p| matches!(p, GenericParam::Const(c) if c.ident == "ALIGN"));
89
90        // Add them if needed
91        if needs_b {
92            generics
93                .params
94                .push(parse_quote!(B: #crate_path::byteorder::ByteOrder));
95        }
96        if needs_align {
97            generics.params.push(parse_quote!(const ALIGN: usize));
98        }
99
100        generics
101    }
102
103    /// Add where clause predicates for the Encoder trait bound on each field
104    fn add_encoder_bounds(
105        &self,
106        generics: &syn::Generics,
107        sol_mode: bool,
108        is_static: bool,
109    ) -> WhereClause {
110        let crate_path = Self::get_crate_path();
111
112        // Create bounds for each field requiring Encoder implementation
113        let encoder_bounds: Vec<WherePredicate> = self
114            .fields
115            .iter()
116            .map(|field| {
117                let ty = &field.ty;
118                parse_quote!(#ty: #crate_path::Encoder<B, ALIGN, {#sol_mode}, {#is_static}>)
119            })
120            .collect();
121
122        // Add them to existing where clause or create a new one
123        if let Some(mut where_clause) = generics.where_clause.clone() {
124            where_clause.predicates.extend(encoder_bounds);
125            where_clause
126        } else {
127            parse_quote!(where #(#encoder_bounds),*)
128        }
129    }
130
131    /// Generate expression for checking if a field type is dynamic
132    fn generate_is_dynamic_expr(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
133        let crate_path = Self::get_crate_path();
134
135        let is_dynamic_expr = self.fields.iter().map(|field| {
136            let ty = &field.ty;
137            quote! {
138                <#ty as #crate_path::Encoder<B, ALIGN, {#sol_mode}, {#is_static}>>::IS_DYNAMIC
139            }
140        });
141
142        quote! {
143            false #( || #is_dynamic_expr)*
144        }
145    }
146
147    /// Generate expression for calculating header size
148    fn generate_header_size_expr(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
149        let crate_path = Self::get_crate_path();
150
151        let header_sizes = self.fields.iter().map(|field| {
152            let ty = &field.ty;
153            if sol_mode {
154                quote! {
155                    <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE
156                }
157            } else {
158                quote! {
159                    #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE)
160                }
161            }
162        });
163
164        quote! {
165            0 #( + #header_sizes)*
166        }
167    }
168
169    /// Generate code for the aligned_header_size expression
170    fn generate_aligned_header_size(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
171        let crate_path = Self::get_crate_path();
172
173        if sol_mode {
174            let sizes = self.fields.iter().map(|field| {
175                let ty = &field.ty;
176                let ts = quote! {
177                    <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>
178                };
179                quote! {
180                    if #ts ::IS_DYNAMIC {
181                        32
182                    } else {
183                        #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE)
184                    }
185                }
186            });
187            quote! { 0 #( + #sizes)* }
188        } else {
189            quote! { <Self as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE }
190        }
191    }
192
193    /// Generate encode implementation for fields
194    fn generate_encode_fields(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
195        let crate_path = Self::get_crate_path();
196
197        let encode_fields = self.fields.iter().map(|field| {
198            let ident = &field.ident;
199            let ty = &field.ty;
200
201            if sol_mode {
202                quote! {
203                    if <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::IS_DYNAMIC {
204                        <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::encode(&self.#ident, &mut tail, tail_offset)?;
205                        tail_offset += #crate_path::align_up::<ALIGN>(4);
206                    } else {
207                        <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::encode(&self.#ident, &mut tail, tail_offset)?;
208                        tail_offset += #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE);
209                    }
210                }
211            } else {
212                quote! {
213                    <#ty as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::encode(&self.#ident, buf, current_offset)?;
214                    current_offset += #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE);
215                }
216            }
217        });
218
219        quote! { #(#encode_fields)* }
220    }
221
222    /// Generate decode implementation for fields
223    fn generate_decode_fields(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
224        let crate_path = Self::get_crate_path();
225
226        let decode_fields = self.fields.iter().map(|field| {
227            let ident = &field.ident;
228            let ty = &field.ty;
229
230            if sol_mode {
231                quote! {
232                    let #ident = <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::decode(&mut tmp, current_offset)?;
233                    current_offset += if <#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::IS_DYNAMIC {
234                        32
235                    } else {
236                        #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE)
237                    };
238                }
239            } else {
240                quote! {
241                    let #ident = <#ty as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::decode(buf, current_offset)?;
242                    current_offset += #crate_path::align_up::<ALIGN>(<#ty as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE);
243                }
244            }
245        });
246
247        quote! { #(#decode_fields)* }
248    }
249
250    /// Generate encode method implementation
251    fn generate_encode_impl(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
252        let crate_path = Self::get_crate_path();
253        let encode_fields = self.generate_encode_fields(sol_mode, is_static);
254        let aligned_header_size = self.generate_aligned_header_size(sol_mode, is_static);
255
256        if sol_mode {
257            quote! {
258                let aligned_offset = #crate_path::align_up::<ALIGN>(offset);
259                let is_dynamic = <Self as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::IS_DYNAMIC;
260                let aligned_header_size = #aligned_header_size;
261
262                let mut tail = if is_dynamic {
263                    let buf_len = buf.len();
264                    let offset = if buf_len != 0 { buf_len } else { 32 };
265                    #crate_path::write_u32_aligned::<B, ALIGN>(buf, aligned_offset, offset as u32);
266                    if buf.len() < aligned_header_size + offset {
267                        buf.resize(aligned_header_size + offset, 0);
268                    }
269                    buf.split_off(offset)
270                } else {
271                    if buf.len() < aligned_offset + aligned_header_size {
272                        buf.resize(aligned_offset + aligned_header_size, 0);
273                    }
274                    buf.split_off(aligned_offset)
275                };
276                let mut tail_offset = 0;
277
278                #encode_fields
279
280                buf.unsplit(tail);
281                Ok(())
282            }
283        } else {
284            quote! {
285                let mut current_offset = #crate_path::align_up::<ALIGN>(offset);
286                let header_size = <Self as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE;
287
288                if buf.len() < current_offset + header_size {
289                    buf.resize(current_offset + header_size, 0);
290                }
291
292                #encode_fields
293                Ok(())
294            }
295        }
296    }
297
298    /// Generate decode method implementation
299    fn generate_decode_impl(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
300        let crate_path = Self::get_crate_path();
301        let decode_fields = self.generate_decode_fields(sol_mode, is_static);
302        let struct_name = &self.struct_name;
303
304        // Get field identifiers for struct initialization
305        let struct_initialization = self.fields.iter().map(|field| {
306            let ident = &field.ident;
307            quote! { #ident }
308        });
309
310        let decode_body = if sol_mode {
311            quote! {
312                let mut aligned_offset = #crate_path::align_up::<ALIGN>(offset);
313
314                let mut tmp = if <Self as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::IS_DYNAMIC {
315                    let offset = #crate_path::read_u32_aligned::<B, ALIGN>(&buf.chunk(), aligned_offset)? as usize;
316                    &buf.chunk()[offset..]
317                } else {
318                    &buf.chunk()[aligned_offset..]
319                };
320
321                let mut current_offset = 0;
322
323                #decode_fields
324            }
325        } else {
326            quote! {
327                let mut current_offset = #crate_path::align_up::<ALIGN>(offset);
328                #decode_fields
329            }
330        };
331
332        quote! {
333            #decode_body
334
335            Ok(#struct_name {
336                #( #struct_initialization ),*
337            })
338        }
339    }
340
341    /// Generate partial_decode method implementation for struct data
342    fn generate_partial_decode_impl(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
343        let crate_path = Self::get_crate_path();
344
345        if sol_mode {
346            quote! {
347                // For Solidity ABI encoding
348                let aligned_offset = #crate_path::align_up::<ALIGN>(offset);
349
350                if <Self as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::IS_DYNAMIC {
351                    // For dynamic structs, read the offset pointer
352                    let data_offset = #crate_path::read_u32_aligned::<B, ALIGN>(&buffer.chunk(), aligned_offset)? as usize;
353                    // Return the actual data location and the header size
354                    Ok((data_offset, <Self as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE))
355                } else {
356                    // For static structs, return current offset and header size
357                    Ok((aligned_offset, <Self as #crate_path::Encoder<B, ALIGN, {true}, {#is_static}>>::HEADER_SIZE))
358                }
359            }
360        } else {
361            quote! {
362                // For Compact ABI encoding
363                let aligned_offset = #crate_path::align_up::<ALIGN>(offset);
364                // Return the current offset and the struct's header size
365                Ok((aligned_offset, <Self as #crate_path::Encoder<B, ALIGN, {false}, {#is_static}>>::HEADER_SIZE))
366            }
367        }
368    }
369
370    /// Generate the complete trait implementation for a specific mode and static/dynamic setting
371    fn generate_impl(&self, sol_mode: bool, is_static: bool) -> TokenStream2 {
372        let struct_name = &self.struct_name;
373        let crate_path = Self::get_crate_path();
374
375        let generics = self.prepare_generics(&self.generics);
376        let where_clause = self.add_encoder_bounds(&generics, sol_mode, is_static);
377        let (impl_generics, ty_generics, _) = generics.split_for_impl();
378
379        let has_custom_generics = !self.generics.params.is_empty();
380        let struct_name_with_ty = if has_custom_generics {
381            quote! { #struct_name #ty_generics }
382        } else {
383            quote! { #struct_name }
384        };
385
386        let header_size = self.generate_header_size_expr(sol_mode, is_static);
387        let is_dynamic = self.generate_is_dynamic_expr(sol_mode, is_static);
388
389        let encode_impl = self.generate_encode_impl(sol_mode, is_static);
390        let decode_impl = self.generate_decode_impl(sol_mode, is_static);
391        let partial_decode_impl = self.generate_partial_decode_impl(sol_mode, is_static);
392
393        quote! {
394            impl #impl_generics #crate_path::Encoder<B, ALIGN, {#sol_mode}, {#is_static}>
395                for #struct_name_with_ty
396                #where_clause
397            {
398                const HEADER_SIZE: usize = #header_size;
399                const IS_DYNAMIC: bool = #is_dynamic;
400
401                fn encode(&self, buf: &mut #crate_path::bytes::BytesMut, offset: usize) -> Result<(), #crate_path::CodecError> {
402                    #encode_impl
403                }
404
405                fn decode(buf: &impl #crate_path::bytes::Buf, offset: usize) -> Result<Self, #crate_path::CodecError> {
406                    #decode_impl
407                }
408
409                fn partial_decode(buffer: &impl #crate_path::bytes::Buf, offset: usize) -> Result<(usize, usize), #crate_path::CodecError> {
410                    #partial_decode_impl
411                }
412            }
413        }
414    }
415}
416
417impl ToTokens for CodecStruct {
418    fn to_tokens(&self, tokens: &mut TokenStream2) {
419        let sol_impl_static = self.generate_impl(true, true);
420        let sol_impl_dynamic = self.generate_impl(true, false);
421        let wasm_impl_static = self.generate_impl(false, true);
422        let wasm_impl_dynamic = self.generate_impl(false, false);
423
424        tokens.extend(quote! {
425            #sol_impl_static
426            #sol_impl_dynamic
427            #wasm_impl_static
428            #wasm_impl_dynamic
429        });
430    }
431}
432
433/// Derive macro for implementing Codec trait for structs
434#[proc_macro_derive(Codec, attributes(codec))]
435pub fn codec_macro_derive(input: TokenStream) -> TokenStream {
436    let ast = parse_macro_input!(input as DeriveInput);
437    let codec_struct = CodecStruct::parse(&ast);
438    quote! {
439        #codec_struct
440    }
441    .into()
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;
    use proc_macro2::TokenStream;
    use syn::parse_quote;

    /// Runs the derive over `input` and returns the expansion pretty-printed as Rust
    /// source, so snapshots stay readable.
    fn get_generated_code(input: TokenStream) -> String {
        let derive_input: DeriveInput = syn::parse2(input).unwrap();
        let expanded = CodecStruct::parse(&derive_input).into_token_stream();
        let file: syn::File = syn::parse2(expanded).unwrap();
        prettyplease::unparse(&file)
    }

    #[test]
    fn test_simple_struct() {
        let input: TokenStream = parse_quote! {
            #[derive(Codec, Default, Debug, PartialEq)]
            struct TestStruct {
                bool_val: bool,
                bytes_val: Bytes,
                vec_val: Vec<u32>,
            }
        };
        assert_snapshot!("simple_struct", get_generated_code(input));
    }

    #[test]
    fn test_generic_struct() {
        let input: TokenStream = parse_quote! {
            #[derive(Codec, Default, Debug, PartialEq)]
            struct GenericStruct<T>
            where
                T: Clone + Default,
            {
                field1: T,
                field2: Vec<T>,
            }
        };
        assert_snapshot!("generic_struct", get_generated_code(input));
    }

    #[test]
    fn test_single_field_struct() {
        let input: TokenStream = parse_quote! {
            #[derive(Codec, Default, Debug, PartialEq)]
            struct SingleFieldStruct {
                value: u64,
            }
        };
        assert_snapshot!("single_field_struct", get_generated_code(input));
    }

    #[test]
    fn test_empty_struct() {
        let input: TokenStream = parse_quote! {
            #[derive(Codec, Default, Debug, PartialEq)]
            struct EmptyStruct {}
        };
        assert_snapshot!("empty_struct", get_generated_code(input));
    }
}