kelk_derive/
lib.rs

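//! Kelk-derive provides the procedural macros used by Wasm contracts on the
//! [Pactus](https://pactus.org/) blockchain: the `kelk_entry` attribute macro for
//! contract entry points and the `Codec` derive macro for packed byte encoding.
//!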
3use proc_macro2::TokenStream;
4use quote::{quote, quote_spanned, ToTokens};
5use std::str::FromStr;
6use syn::{
7    parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed,
8    FieldsUnnamed, GenericParam, Generics, Ident, Index, Type, TypeParamBound,
9};
10
11/// The attribute macro to inject the code at the beginning of entry functions
12/// for the Wasm contract actor.
13///
14/// It can be added to the contract's instantiate, process and query functions
15/// like this:
16/// ```
17/// use kelk::kelk_entry;
18/// use kelk::context::Context;
19///
20/// type InstantiateMsg = ();
21/// type ProcessMsg = ();
22/// type QueryMsg = ();
23///
24/// enum Error {};
25///
26/// #[kelk_entry]
27/// pub fn instantiate(ctx: Context, msg: InstantiateMsg) -> Result<(), Error> {
28///    unimplemented!();
29/// }
30///
31/// #[kelk_entry]
32/// pub fn process(ctx: Context, msg: ProcessMsg) -> Result<(), Error> {
33///   unimplemented!();
34/// }
35///
36/// #[kelk_entry]
37/// pub fn query(ctx: Context, msg: QueryMsg) -> Result<(), Error> {
38///   unimplemented!();
39/// }
40/// ```
41///
42/// where `InstantiateMsg`, `ProcessMsg`, and `QueryMsg` are contract defined
43/// types that implement CBOR encoding.
44#[proc_macro_attribute]
45pub fn kelk_entry(
46    _attr: proc_macro::TokenStream,
47    mut item: proc_macro::TokenStream,
48) -> proc_macro::TokenStream {
49    let cloned = item.clone();
50    let function = parse_macro_input!(cloned as syn::ItemFn);
51    let name = function.sig.ident.to_string();
52
53    let method = match name.as_ref() {
54        "instantiate" => "create",
55        "process" => "load",
56        "query" => "load",
57        _ => {
58            return proc_macro::TokenStream::from(quote! {
59                compile_error!("entry function should be either \"instantiate\", \"process\", or \"query\""),
60            })
61        }
62    };
63
64    let gen_code = format!(
65        r##"
66        #[cfg(target_arch = "wasm32")]
67        mod __wasm_export_{name} {{
68            #[no_mangle]
69            extern "C" fn {name}(msg_ptr: u64) -> u64 {{
70                let ctx = kelk::context::OwnedContext {{
71                    storage: kelk::storage::Storage::{method}(kelk::alloc::boxed::Box::new(kelk::Kelk::new()))
72                        .unwrap(),
73                    blockchain: kelk::blockchain::Blockchain::new(kelk::alloc::boxed::Box::new(
74                        kelk::Kelk::new(),
75                    )),
76                }};
77
78                kelk::do_{name}(&super::{name}, ctx.as_ref(), msg_ptr)
79            }}
80        }}
81    "##,
82    );
83
84    let entry = proc_macro::TokenStream::from_str(&gen_code).unwrap();
85    item.extend(entry);
86    item
87}
88
89/// Derives `Codec` trait for the given `struct`.
90///
91/// # Examples
92///
93/// ```
94/// use kelk::Codec;
95/// use kelk::storage::codec::Codec;
96///
97/// #[derive(Codec)]
98/// struct Test {
99///     a: u32,
100/// }
101///
102/// assert_eq!(Test::PACKED_LEN, 4);
103/// ```
104#[proc_macro_derive(Codec)]
105pub fn derive_codec(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
106    // Parse the input tokens into a syntax tree.
107    let input = parse_macro_input!(input as DeriveInput);
108
109    // Used in the quasi-quotation below as `#name`.
110    let name = input.ident;
111
112    // Add a bound `T: Codec` to every type parameter T.
113    let generics = add_trait_bounds(input.generics, parse_quote!(Codec));
114    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
115
116    // Generate an expression to sum up the heap size of each field.
117    let packed_len_body = packed_len_body(&input.data);
118    let (to_bytes_body, from_bytes_body) = codec_body(&input.data);
119
120    let expanded = quote! {
121        impl #impl_generics Codec for #name #ty_generics #where_clause {
122
123            const PACKED_LEN: u32 = #packed_len_body;
124
125            #[inline]
126            fn to_bytes(&self, bytes: &mut [u8]) {
127                debug_assert_eq!(bytes.len(), Self::PACKED_LEN as usize);
128
129                #to_bytes_body
130            }
131
132            #[inline]
133            fn from_bytes(bytes: &[u8]) -> Self {
134                debug_assert_eq!(bytes.len(), Self::PACKED_LEN as usize);
135
136                Self { #from_bytes_body }
137            }
138        }
139    };
140
141    // Hand the output tokens back to the compiler.
142    proc_macro::TokenStream::from(expanded)
143}
144
145fn packed_len_body(data: &Data) -> TokenStream {
146    match *data {
147        Data::Struct(ref data) => {
148            match data.fields {
149                Fields::Named(ref fields) => {
150                    // Expands to an expression like
151                    //
152                    //     0 + <self.x as Codec>::PACKED_LEN + <self.y as Codec>::PACKED_LEN
153                    let recurse = fields.named.iter().map(|f| {
154                        let ty = &f.ty;
155                        quote_spanned! {f.span()=>
156                            <#ty as Codec>::PACKED_LEN
157                        }
158                    });
159
160                    quote! {
161                        0  #(+ #recurse)*
162                    }
163                }
164                Fields::Unnamed(ref fields) => {
165                    // Expands to an expression like
166                    //
167                    //     0 + <self.0 as Codec>::PACKED_LEN + <self.1 as Codec>::PACKED_LEN
168                    let recurse = fields.unnamed.iter().map(|f| {
169                        let ty = &f.ty;
170                        quote_spanned! {f.span()=>
171                            <#ty as Codec>::PACKED_LEN
172                        }
173                    });
174                    quote! {
175                        0 #(+ #recurse)*
176                    }
177                }
178                Fields::Unit => {
179                    // Unit structs cannot own more than 0 bytes of heap memory.
180                    quote!(0)
181                }
182            }
183        }
184        Data::Enum(_) | Data::Union(_) => unimplemented!(),
185    }
186}
187
188// Add a bound `T: trait_bound` to every type parameter T.
189fn add_trait_bounds(mut generics: Generics, trait_bound: TypeParamBound) -> Generics {
190    for param in &mut generics.params {
191        if let GenericParam::Type(ref mut type_param) = *param {
192            type_param.bounds.push(trait_bound.clone());
193        }
194    }
195    generics
196}
197
198fn codec_body(data: &Data) -> (TokenStream, TokenStream) {
199    // this also contains `bytes` variable
200    match *data {
201        Data::Struct(ref data) => {
202            match data.fields {
203                //  Normal struct: named fields
204                Fields::Named(FieldsNamed { ref named, .. }) => {
205                    //  Collect references to all the field names. A precondition of
206                    //  reaching this code path is that all fields HAVE names, so it
207                    //  is safe to have an unreachable trap in the None condition.
208                    let names: Vec<(&Type, &Ident)> = named
209                        .iter()
210                        .map(|f| (&f.ty, f.ident.as_ref().unwrap()))
211                        .collect();
212                    codegen_struct(&names)
213                }
214                //  Tuple struct: unnamed fields
215                Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) => {
216                    let mut nums: Vec<(&Type, Index)> = Vec::new();
217                    for (i, f) in unnamed.into_iter().enumerate() {
218                        nums.push((&f.ty, i.into()));
219                    }
220                    codegen_struct(&nums)
221                }
222
223                Fields::Unit => {
224                    // Unit structs cannot own more than 0 bytes of heap memory.
225                    (quote!(0), quote!(0))
226                }
227            }
228        }
229        Data::Enum(_) | Data::Union(_) => unimplemented!(),
230    }
231}
232
233fn codegen_struct<T: ToTokens>(fields: &[(&Type, T)]) -> (TokenStream, TokenStream) {
234    let mut beg_offset = quote! { 0 };
235    let mut recurse_to_bytes = vec![];
236    let mut recurse_from_bytes = vec![];
237
238    for field in fields.iter() {
239        let ty = field.0;
240        let name = &field.1;
241        let struct_size = quote! { <#ty as Codec>::PACKED_LEN as usize};
242        let end_offset = quote! { #beg_offset + #struct_size };
243        let bytes_slice = quote! { bytes[#beg_offset..#end_offset] };
244
245        recurse_to_bytes.push(quote! {
246            Codec::to_bytes(&self.#name, &mut #bytes_slice);
247        });
248
249        recurse_from_bytes.push(quote! {
250            #name: Codec::from_bytes(& #bytes_slice),
251        });
252
253        beg_offset = quote! { #beg_offset + #struct_size };
254    }
255
256    (
257        quote! {
258            #(#recurse_to_bytes)*
259        },
260        quote! {
261            #(#recurse_from_bytes)*
262        },
263    )
264}