1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use std::env;
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_quote,
7 punctuated::Punctuated,
8 ItemFn, Token,
9};
10
/// Unwraps a `Result`, or returns early from the enclosing function with the error
/// converted through `into_compile_error()`.
///
/// For a `syn::Result`, that early return is a `proc_macro2::TokenStream` containing a
/// `compile_error!` invocation. The enclosing function must therefore return that type.
macro_rules! maybe {
    ($result:expr) => {{
        let outcome = $result;
        match outcome {
            Ok(value) => value,
            Err(error) => return error.into_compile_error(),
        }
    }};
}
19
20struct Options {
21 crate_path: syn::Path,
22}
23
24impl Default for Options {
25 fn default() -> Self {
26 Self {
27 crate_path: parse_quote!(::cosmwasm_std),
28 }
29 }
30}
31
32impl Parse for Options {
33 fn parse(input: ParseStream) -> syn::Result<Self> {
34 let mut ret = Self::default();
35 let attrs = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
36
37 for kv in attrs {
38 if kv.path.is_ident("crate") {
39 let path_as_string: syn::LitStr = syn::parse2(kv.value.to_token_stream())?;
40 ret.crate_path = path_as_string.parse()?;
41 } else {
42 return Err(syn::Error::new_spanned(kv, "Unknown attribute"));
43 }
44 }
45
46 Ok(ret)
47 }
48}
49
50#[proc_macro_attribute]
52pub fn entry_point(
53 attr: proc_macro::TokenStream,
54 item: proc_macro::TokenStream,
55) -> proc_macro::TokenStream {
56 entry_point_impl(attr.into(), item.into()).into()
57}
58
59fn expand_attributes(func: &mut ItemFn) -> syn::Result<TokenStream> {
60 let attributes = std::mem::take(&mut func.attrs);
61 let mut stream = TokenStream::new();
62 for attribute in attributes {
63 if !attribute.path().is_ident("migrate_version") {
64 func.attrs.push(attribute);
65 continue;
66 }
67
68 if func.sig.ident != "migrate" {
69 return Err(syn::Error::new_spanned(
70 &attribute,
71 "you only want to add this attribute to your migrate function",
72 ));
73 }
74
75 let version: syn::Expr = attribute.parse_args()?;
76 if !(matches!(version, syn::Expr::Lit(_)) || matches!(version, syn::Expr::Path(_))) {
77 return Err(syn::Error::new_spanned(
78 &attribute,
79 "Expected `u64` or `path::to::constant` in the migrate_version attribute",
80 ));
81 }
82
83 stream = quote! {
84 #stream
85
86 const _: () = {
87 #[allow(unused)]
88 #[doc(hidden)]
89 #[cfg(target_arch = "wasm32")]
90 #[link_section = "cw_migrate_version"]
91 static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);
94
95 #[allow(unused)]
96 #[doc(hidden)]
97 const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
98 let mut result: [u8; N] = [0; N];
99 let mut index = N;
100 while index > 0 {
101 let digit: u8 = (version%10) as u8;
102 result[index-1] = digit + b'0';
103 version /= 10;
104 index -= 1;
105 }
106 result
107 }
108
109 #[allow(unused)]
110 #[doc(hidden)]
111 const fn version_size(version: u64) -> usize {
112 if version > 0 {
113 (version.ilog10()+1) as usize
114 } else {
115 panic!("Contract migrate version should be greater than 0.")
116 }
117 }
118 };
119 };
120 }
121
122 Ok(stream)
123}
124
125fn expand_bindings(crate_path: &syn::Path, mut function: syn::ItemFn) -> TokenStream {
126 let attribute_code = maybe!(expand_attributes(&mut function));
127
128 let args = function.sig.inputs.len().saturating_sub(1);
130 let fn_name = &function.sig.ident;
131 let wasm_export = format_ident!("__wasm_export_{fn_name}");
132
133 let do_call = if fn_name == "migrate" && args == 3 {
135 format_ident!("do_migrate_with_info")
136 } else {
137 format_ident!("do_{fn_name}")
138 };
139
140 let decl_args = (0..args).map(|item| format_ident!("ptr_{item}"));
141 let call_args = decl_args.clone();
142
143 quote! {
144 #attribute_code
145
146 #function
147
148 #[cfg(target_arch = "wasm32")]
149 mod #wasm_export { #[no_mangle]
151 extern "C" fn #fn_name(#( #decl_args : u32 ),*) -> u32 {
152 #crate_path::#do_call(&super::#fn_name, #( #call_args ),*)
153 }
154 }
155 }
156}
157
158fn entry_point_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
159 let mut function: syn::ItemFn = maybe!(syn::parse2(item));
160 let Options { crate_path } = maybe!(syn::parse2(attr));
161
162 if env::var("CARGO_PRIMARY_PACKAGE").is_ok() {
163 expand_bindings(&crate_path, function)
164 } else {
165 function
166 .attrs
167 .retain(|attr| !attr.path().is_ident("migrate_version"));
168
169 quote! { #function }
170 }
171}
172
// Unit tests for `entry_point_impl`. Each test compares the `to_string()` rendering of
// the macro output with a hand-written `quote!` expectation, so the expected tokens must
// match the generated code exactly.
#[cfg(test)]
mod test {
    use std::env;

    use proc_macro2::TokenStream;
    use quote::quote;

    use crate::entry_point_impl;

    // Bindings are only generated when `CARGO_PRIMARY_PACKAGE` is set, so every test
    // sets it first.
    fn setup_environment() {
        env::set_var("CARGO_PRIMARY_PACKAGE", "1");
    }

    // `#[migrate_version]` on a function not named `migrate` must expand to a compile error.
    #[test]
    fn contract_migrate_version_on_non_migrate() {
        setup_environment();

        let code = quote! {
            #[migrate_version(42)]
            fn anything_else() -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            ::core::compile_error! { "you only want to add this attribute to your migrate function" }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    // An integer-literal version produces the link-section item, then the stripped
    // function, then the export module (3 params -> 2 pointers -> `do_migrate`).
    #[test]
    fn contract_migrate_version_expansion() {
        setup_environment();

        let code = quote! {
            #[migrate_version(2)]
            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                static __CW_MIGRATE_VERSION: [u8; version_size(2)] = stringify_version(2);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };

            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_migrate {
                #[no_mangle]
                extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                    ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    // A path (constant) version is interpolated verbatim into the generated item.
    #[test]
    fn contract_migrate_version_with_const_expansion() {
        setup_environment();

        let code = quote! {
            #[migrate_version(CONTRACT_VERSION)]
            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                static __CW_MIGRATE_VERSION: [u8; version_size(CONTRACT_VERSION)] = stringify_version(CONTRACT_VERSION);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };

            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_migrate {
                #[no_mangle]
                extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                    ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    // Without attribute arguments the helpers are reached through `::cosmwasm_std`.
    #[test]
    fn default_expansion() {
        setup_environment();

        let code = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response { }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_instantiate {
                #[no_mangle]
                extern "C" fn instantiate(ptr_0: u32) -> u32 {
                    ::cosmwasm_std::do_instantiate(&super::instantiate, ptr_0)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    // `crate = "..."` replaces the helper path prefix.
    #[test]
    fn renamed_expansion() {
        setup_environment();

        let attribute = quote!(crate = "::my_crate::cw_std");
        let code = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response {
            }
        };

        let actual = entry_point_impl(attribute, code);
        let expected = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response { }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_instantiate {
                #[no_mangle]
                extern "C" fn instantiate(ptr_0: u32) -> u32 {
                    ::my_crate::cw_std::do_instantiate(&super::instantiate, ptr_0)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }
}