cosmwasm_derive/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use std::env;
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_quote,
7    punctuated::Punctuated,
8    ItemFn, Token,
9};
10
/// Unwraps an `Ok` value, or returns early from the enclosing function with the
/// error converted via its `into_compile_error()` method.
///
/// The enclosing function's return type must match what `into_compile_error()`
/// produces (in this file: `proc_macro2::TokenStream`), so errors surface as
/// `compile_error!` tokens instead of panics.
macro_rules! maybe {
    ($result:expr) => {{
        match $result {
            Ok(unwrapped) => unwrapped,
            Err(error) => return error.into_compile_error(),
        }
    }};
}
19
20struct Options {
21    crate_path: syn::Path,
22}
23
24impl Default for Options {
25    fn default() -> Self {
26        Self {
27            crate_path: parse_quote!(::cosmwasm_std),
28        }
29    }
30}
31
32impl Parse for Options {
33    fn parse(input: ParseStream) -> syn::Result<Self> {
34        let mut ret = Self::default();
35        let attrs = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
36
37        for kv in attrs {
38            if kv.path.is_ident("crate") {
39                let path_as_string: syn::LitStr = syn::parse2(kv.value.to_token_stream())?;
40                ret.crate_path = path_as_string.parse()?;
41            } else {
42                return Err(syn::Error::new_spanned(kv, "Unknown attribute"));
43            }
44        }
45
46        Ok(ret)
47    }
48}
49
50// function documented in cosmwasm-std
51#[proc_macro_attribute]
52pub fn entry_point(
53    attr: proc_macro::TokenStream,
54    item: proc_macro::TokenStream,
55) -> proc_macro::TokenStream {
56    entry_point_impl(attr.into(), item.into()).into()
57}
58
59fn expand_attributes(func: &mut ItemFn) -> syn::Result<TokenStream> {
60    let attributes = std::mem::take(&mut func.attrs);
61    let mut stream = TokenStream::new();
62    for attribute in attributes {
63        if !attribute.path().is_ident("migrate_version") {
64            func.attrs.push(attribute);
65            continue;
66        }
67
68        if func.sig.ident != "migrate" {
69            return Err(syn::Error::new_spanned(
70                &attribute,
71                "you only want to add this attribute to your migrate function",
72            ));
73        }
74
75        let version: syn::Expr = attribute.parse_args()?;
76        if !(matches!(version, syn::Expr::Lit(_)) || matches!(version, syn::Expr::Path(_))) {
77            return Err(syn::Error::new_spanned(
78                &attribute,
79                "Expected `u64` or `path::to::constant` in the migrate_version attribute",
80            ));
81        }
82
83        stream = quote! {
84            #stream
85
86            const _: () = {
87                #[allow(unused)]
88                #[doc(hidden)]
89                #[cfg(target_arch = "wasm32")]
90                #[link_section = "cw_migrate_version"]
91                /// This is an internal constant exported as a custom section denoting the contract migrate version.
92                /// The format and even the existence of this value is an implementation detail, DO NOT RELY ON THIS!
93                static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);
94
95                #[allow(unused)]
96                #[doc(hidden)]
97                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
98                    let mut result: [u8; N] = [0; N];
99                    let mut index = N;
100                    while index > 0 {
101                        let digit: u8 = (version%10) as u8;
102                        result[index-1] = digit + b'0';
103                        version /= 10;
104                        index -= 1;
105                    }
106                    result
107                }
108
109                #[allow(unused)]
110                #[doc(hidden)]
111                const fn version_size(version: u64) -> usize {
112                    if version > 0 {
113                        (version.ilog10()+1) as usize
114                    } else {
115                        panic!("Contract migrate version should be greater than 0.")
116                    }
117                }
118            };
119        };
120    }
121
122    Ok(stream)
123}
124
125fn expand_bindings(crate_path: &syn::Path, mut function: syn::ItemFn) -> TokenStream {
126    let attribute_code = maybe!(expand_attributes(&mut function));
127
128    // The first argument is `deps`, the rest is region pointers
129    let args = function.sig.inputs.len().saturating_sub(1);
130    let fn_name = &function.sig.ident;
131    let wasm_export = format_ident!("__wasm_export_{fn_name}");
132
133    // Prevent contract dev from using the wrong identifier for the do_migrate_with_info function
134    if fn_name == "migrate_with_info" {
135        return syn::Error::new_spanned(
136            &function.sig.ident,
137            r#"To use the new migrate function signature, you should provide a "migrate" entry point with 4 arguments, not "migrate_with_info""#,
138        ).into_compile_error();
139    }
140
141    // Migrate entry point can take 2 or 3 arguments (not counting deps)
142    let do_call = if fn_name == "migrate" && args == 3 {
143        format_ident!("do_migrate_with_info")
144    } else {
145        format_ident!("do_{fn_name}")
146    };
147
148    let decl_args = (0..args).map(|item| format_ident!("ptr_{item}"));
149    let call_args = decl_args.clone();
150
151    quote! {
152        #attribute_code
153
154        #function
155
156        #[cfg(target_arch = "wasm32")]
157        mod #wasm_export { // new module to avoid conflict of function name
158            #[no_mangle]
159            extern "C" fn #fn_name(#( #decl_args : u32 ),*) -> u32 {
160                #crate_path::#do_call(&super::#fn_name, #( #call_args ),*)
161            }
162        }
163    }
164}
165
166fn entry_point_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
167    let mut function: syn::ItemFn = maybe!(syn::parse2(item));
168    let Options { crate_path } = maybe!(syn::parse2(attr));
169
170    if env::var("CARGO_PRIMARY_PACKAGE").is_ok() {
171        expand_bindings(&crate_path, function)
172    } else {
173        function
174            .attrs
175            .retain(|attr| !attr.path().is_ident("migrate_version"));
176
177        quote! { #function }
178    }
179}
180
#[cfg(test)]
mod test {
    use std::env;

    use proc_macro2::TokenStream;
    use quote::quote;

    use crate::entry_point_impl;

    /// Makes `entry_point_impl` take its binding-generating path, which it only
    /// does when `CARGO_PRIMARY_PACKAGE` is set.
    fn setup_environment() {
        env::set_var("CARGO_PRIMARY_PACKAGE", "1");
    }

    /// Expands `item` with attribute arguments `attr` and asserts the result
    /// equals `expected`. Comparison uses the rendered token strings, so spans
    /// are ignored.
    fn assert_expansion(attr: TokenStream, item: TokenStream, expected: TokenStream) {
        setup_environment();
        let actual = entry_point_impl(attr, item);
        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// Custom-section code the macro should emit for `#[migrate_version(<version>)]`.
    fn expected_version_section(version: TokenStream) -> TokenStream {
        quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                /// This is an internal constant exported as a custom section denoting the contract migrate version.
                /// The format and even the existence of this value is an implementation detail, DO NOT RELY ON THIS!
                static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };
        }
    }

    #[test]
    fn contract_migrate_version_on_non_migrate() {
        assert_expansion(
            TokenStream::new(),
            quote! {
                #[migrate_version(42)]
                fn anything_else() -> Response {
                    // Logic here
                }
            },
            quote! {
                ::core::compile_error! { "you only want to add this attribute to your migrate function" }
            },
        );
    }

    #[test]
    fn contract_migrate_version_expansion() {
        let version_section = expected_version_section(quote!(2));
        assert_expansion(
            TokenStream::new(),
            quote! {
                #[migrate_version(2)]
                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                    // Logic here
                }
            },
            quote! {
                #version_section

                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                    // Logic here
                }

                #[cfg(target_arch = "wasm32")]
                mod __wasm_export_migrate {
                    #[no_mangle]
                    extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                        ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                    }
                }
            },
        );

        // A function literally named `migrate_with_info` must be rejected.
        assert_expansion(
            TokenStream::new(),
            quote! {
                #[entry_point]
                pub fn migrate_with_info(
                    deps: DepsMut,
                    env: Env,
                    msg: MigrateMsg,
                    migrate_info: MigrateInfo,
                ) -> Result<Response, ()> {
                    // Logic here
                }
            },
            quote! {
                ::core::compile_error! { "To use the new migrate function signature, you should provide a \"migrate\" entry point with 4 arguments, not \"migrate_with_info\"" }
            },
        );
    }

    #[test]
    fn contract_migrate_version_with_const_expansion() {
        let version_section = expected_version_section(quote!(CONTRACT_VERSION));
        assert_expansion(
            TokenStream::new(),
            quote! {
                #[migrate_version(CONTRACT_VERSION)]
                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                    // Logic here
                }
            },
            quote! {
                #version_section

                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                    // Logic here
                }

                #[cfg(target_arch = "wasm32")]
                mod __wasm_export_migrate {
                    #[no_mangle]
                    extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                        ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                    }
                }
            },
        );
    }

    #[test]
    fn default_expansion() {
        assert_expansion(
            TokenStream::new(),
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response {
                    // Logic here
                }
            },
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response { }

                #[cfg(target_arch = "wasm32")]
                mod __wasm_export_instantiate {
                    #[no_mangle]
                    extern "C" fn instantiate(ptr_0: u32) -> u32 {
                        ::cosmwasm_std::do_instantiate(&super::instantiate, ptr_0)
                    }
                }
            },
        );
    }

    #[test]
    fn renamed_expansion() {
        assert_expansion(
            quote!(crate = "::my_crate::cw_std"),
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response {
                    // Logic here
                }
            },
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response { }

                #[cfg(target_arch = "wasm32")]
                mod __wasm_export_instantiate {
                    #[no_mangle]
                    extern "C" fn instantiate(ptr_0: u32) -> u32 {
                        ::my_crate::cw_std::do_instantiate(&super::instantiate, ptr_0)
                    }
                }
            },
        );
    }
}