1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use std::env;
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_quote,
7 punctuated::Punctuated,
8 ItemFn, Token,
9};
10
/// Evaluates a `Result`, yielding the `Ok` value or returning early from the
/// enclosing function with `error.into_compile_error()`.
///
/// Used by functions that return a `TokenStream`, so parse and validation
/// failures turn into `compile_error!` output instead of a panic.
macro_rules! maybe {
    ($result:expr) => {{
        match { $result } {
            Err(error) => return error.into_compile_error(),
            Ok(value) => value,
        }
    }};
}
19
20struct Options {
21 crate_path: syn::Path,
22}
23
24impl Default for Options {
25 fn default() -> Self {
26 Self {
27 crate_path: parse_quote!(::cosmwasm_std),
28 }
29 }
30}
31
32impl Parse for Options {
33 fn parse(input: ParseStream) -> syn::Result<Self> {
34 let mut ret = Self::default();
35 let attrs = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
36
37 for kv in attrs {
38 if kv.path.is_ident("crate") {
39 let path_as_string: syn::LitStr = syn::parse2(kv.value.to_token_stream())?;
40 ret.crate_path = path_as_string.parse()?;
41 } else {
42 return Err(syn::Error::new_spanned(kv, "Unknown attribute"));
43 }
44 }
45
46 Ok(ret)
47 }
48}
49
50#[proc_macro_attribute]
52pub fn entry_point(
53 attr: proc_macro::TokenStream,
54 item: proc_macro::TokenStream,
55) -> proc_macro::TokenStream {
56 entry_point_impl(attr.into(), item.into()).into()
57}
58
59fn expand_attributes(func: &mut ItemFn) -> syn::Result<TokenStream> {
60 let attributes = std::mem::take(&mut func.attrs);
61 let mut stream = TokenStream::new();
62 for attribute in attributes {
63 if !attribute.path().is_ident("migrate_version") {
64 func.attrs.push(attribute);
65 continue;
66 }
67
68 if func.sig.ident != "migrate" {
69 return Err(syn::Error::new_spanned(
70 &attribute,
71 "you only want to add this attribute to your migrate function",
72 ));
73 }
74
75 let version: syn::Expr = attribute.parse_args()?;
76 if !(matches!(version, syn::Expr::Lit(_)) || matches!(version, syn::Expr::Path(_))) {
77 return Err(syn::Error::new_spanned(
78 &attribute,
79 "Expected `u64` or `path::to::constant` in the migrate_version attribute",
80 ));
81 }
82
83 stream = quote! {
84 #stream
85
86 const _: () = {
87 #[allow(unused)]
88 #[doc(hidden)]
89 #[cfg(target_arch = "wasm32")]
90 #[link_section = "cw_migrate_version"]
91 static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);
94
95 #[allow(unused)]
96 #[doc(hidden)]
97 const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
98 let mut result: [u8; N] = [0; N];
99 let mut index = N;
100 while index > 0 {
101 let digit: u8 = (version%10) as u8;
102 result[index-1] = digit + b'0';
103 version /= 10;
104 index -= 1;
105 }
106 result
107 }
108
109 #[allow(unused)]
110 #[doc(hidden)]
111 const fn version_size(version: u64) -> usize {
112 if version > 0 {
113 (version.ilog10()+1) as usize
114 } else {
115 panic!("Contract migrate version should be greater than 0.")
116 }
117 }
118 };
119 };
120 }
121
122 Ok(stream)
123}
124
125fn expand_bindings(crate_path: &syn::Path, mut function: syn::ItemFn) -> TokenStream {
126 let attribute_code = maybe!(expand_attributes(&mut function));
127
128 let args = function.sig.inputs.len().saturating_sub(1);
130 let fn_name = &function.sig.ident;
131 let wasm_export = format_ident!("__wasm_export_{fn_name}");
132
133 if fn_name == "migrate_with_info" {
135 return syn::Error::new_spanned(
136 &function.sig.ident,
137 r#"To use the new migrate function signature, you should provide a "migrate" entry point with 4 arguments, not "migrate_with_info""#,
138 ).into_compile_error();
139 }
140
141 let do_call = if fn_name == "migrate" && args == 3 {
143 format_ident!("do_migrate_with_info")
144 } else {
145 format_ident!("do_{fn_name}")
146 };
147
148 let decl_args = (0..args).map(|item| format_ident!("ptr_{item}"));
149 let call_args = decl_args.clone();
150
151 quote! {
152 #attribute_code
153
154 #function
155
156 #[cfg(target_arch = "wasm32")]
157 mod #wasm_export { #[no_mangle]
159 extern "C" fn #fn_name(#( #decl_args : u32 ),*) -> u32 {
160 #crate_path::#do_call(&super::#fn_name, #( #call_args ),*)
161 }
162 }
163 }
164}
165
166fn entry_point_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
167 let mut function: syn::ItemFn = maybe!(syn::parse2(item));
168 let Options { crate_path } = maybe!(syn::parse2(attr));
169
170 if env::var("CARGO_PRIMARY_PACKAGE").is_ok() {
171 expand_bindings(&crate_path, function)
172 } else {
173 function
174 .attrs
175 .retain(|attr| !attr.path().is_ident("migrate_version"));
176
177 quote! { #function }
178 }
179}
180
#[cfg(test)]
mod test {
    use std::env;

    use proc_macro2::TokenStream;
    use quote::quote;

    use crate::entry_point_impl;

    /// Makes `entry_point_impl` behave as if it is compiling the primary
    /// package, so that it emits the wasm export bindings.
    fn setup_environment() {
        env::set_var("CARGO_PRIMARY_PACKAGE", "1");
    }

    /// Runs the macro on `code` with arguments `attr` and checks that the
    /// rendered output equals the rendering of `expected`.
    fn assert_expands_to(attr: TokenStream, code: TokenStream, expected: TokenStream) {
        let actual = entry_point_impl(attr, code);
        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// Expected output for `#[migrate_version(<version>)]` on the
    /// three-parameter `migrate` used by the tests below.
    fn expected_migrate_expansion(version: TokenStream) -> TokenStream {
        quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };

            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_migrate {
                #[no_mangle]
                extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                    ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                }
            }
        }
    }

    /// Expected output for the two-parameter `instantiate` used below, with
    /// the `do_instantiate` helper reached through `crate_path`.
    fn expected_instantiate_expansion(crate_path: TokenStream) -> TokenStream {
        quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response { }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_instantiate {
                #[no_mangle]
                extern "C" fn instantiate(ptr_0: u32) -> u32 {
                    #crate_path::do_instantiate(&super::instantiate, ptr_0)
                }
            }
        }
    }

    #[test]
    fn contract_migrate_version_on_non_migrate() {
        setup_environment();

        assert_expands_to(
            TokenStream::new(),
            quote! {
                #[migrate_version(42)]
                fn anything_else() -> Response {
                }
            },
            quote! {
                ::core::compile_error! { "you only want to add this attribute to your migrate function" }
            },
        );
    }

    #[test]
    fn contract_migrate_version_expansion() {
        setup_environment();

        assert_expands_to(
            TokenStream::new(),
            quote! {
                #[migrate_version(2)]
                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                }
            },
            expected_migrate_expansion(quote!(2)),
        );

        assert_expands_to(
            TokenStream::new(),
            quote! {
                #[entry_point]
                pub fn migrate_with_info(
                    deps: DepsMut,
                    env: Env,
                    msg: MigrateMsg,
                    migrate_info: MigrateInfo,
                ) -> Result<Response, ()> {
                }
            },
            quote! {
                ::core::compile_error! { "To use the new migrate function signature, you should provide a \"migrate\" entry point with 4 arguments, not \"migrate_with_info\"" }
            },
        );
    }

    #[test]
    fn contract_migrate_version_with_const_expansion() {
        setup_environment();

        assert_expands_to(
            TokenStream::new(),
            quote! {
                #[migrate_version(CONTRACT_VERSION)]
                fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
                }
            },
            expected_migrate_expansion(quote!(CONTRACT_VERSION)),
        );
    }

    #[test]
    fn default_expansion() {
        setup_environment();

        assert_expands_to(
            TokenStream::new(),
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response {
                }
            },
            expected_instantiate_expansion(quote!(::cosmwasm_std)),
        );
    }

    #[test]
    fn renamed_expansion() {
        setup_environment();

        assert_expands_to(
            quote!(crate = "::my_crate::cw_std"),
            quote! {
                fn instantiate(deps: DepsMut, env: Env) -> Response {
                }
            },
            expected_instantiate_expansion(quote!(::my_crate::cw_std)),
        );
    }
}