1use proc_macro2::TokenStream;
6use quote::{format_ident, quote, ToTokens};
7use std::env;
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_quote,
11 punctuated::Punctuated,
12 ItemFn, Token,
13};
14
/// Unwraps an `Ok` value, or makes the enclosing function return the error
/// converted with `into_compile_error()`.
///
/// Used by the expansion functions so that a parse failure turns into a
/// `compile_error!` token stream instead of a panic.
macro_rules! maybe {
    ($result:expr) => {{
        match $result {
            Ok(value) => value,
            Err(error) => return error.into_compile_error(),
        }
    }};
}
23
24struct Options {
25 crate_path: syn::Path,
26}
27
28impl Default for Options {
29 fn default() -> Self {
30 Self {
31 crate_path: parse_quote!(::cosmwasm_std),
32 }
33 }
34}
35
36impl Parse for Options {
37 fn parse(input: ParseStream) -> syn::Result<Self> {
38 let mut ret = Self::default();
39 let attrs = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
40
41 for kv in attrs {
42 if kv.path.is_ident("crate") {
43 let path_as_string: syn::LitStr = syn::parse2(kv.value.to_token_stream())?;
44 ret.crate_path = path_as_string.parse()?;
45 } else {
46 return Err(syn::Error::new_spanned(kv, "Unknown attribute"));
47 }
48 }
49
50 Ok(ret)
51 }
52}
53
54#[proc_macro_attribute]
56pub fn entry_point(
57 attr: proc_macro::TokenStream,
58 item: proc_macro::TokenStream,
59) -> proc_macro::TokenStream {
60 entry_point_impl(attr.into(), item.into()).into()
61}
62
63fn expand_attributes(func: &mut ItemFn) -> syn::Result<TokenStream> {
64 let attributes = std::mem::take(&mut func.attrs);
65 let mut stream = TokenStream::new();
66 for attribute in attributes {
67 if !attribute.path().is_ident("migrate_version") {
68 func.attrs.push(attribute);
69 continue;
70 }
71
72 if func.sig.ident != "migrate" {
73 return Err(syn::Error::new_spanned(
74 &attribute,
75 "you only want to add this attribute to your migrate function",
76 ));
77 }
78
79 let version: syn::Expr = attribute.parse_args()?;
80 if !(matches!(version, syn::Expr::Lit(_)) || matches!(version, syn::Expr::Path(_))) {
81 return Err(syn::Error::new_spanned(
82 &attribute,
83 "Expected `u64` or `path::to::constant` in the migrate_version attribute",
84 ));
85 }
86
87 stream = quote! {
88 #stream
89
90 const _: () = {
91 #[allow(unused)]
92 #[doc(hidden)]
93 #[cfg(target_arch = "wasm32")]
94 #[link_section = "cw_migrate_version"]
95 static __CW_MIGRATE_VERSION: [u8; version_size(#version)] = stringify_version(#version);
98
99 #[allow(unused)]
100 #[doc(hidden)]
101 const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
102 let mut result: [u8; N] = [0; N];
103 let mut index = N;
104 while index > 0 {
105 let digit: u8 = (version%10) as u8;
106 result[index-1] = digit + b'0';
107 version /= 10;
108 index -= 1;
109 }
110 result
111 }
112
113 #[allow(unused)]
114 #[doc(hidden)]
115 const fn version_size(version: u64) -> usize {
116 if version > 0 {
117 (version.ilog10()+1) as usize
118 } else {
119 panic!("Contract migrate version should be greater than 0.")
120 }
121 }
122 };
123 };
124 }
125
126 Ok(stream)
127}
128
129fn expand_bindings(crate_path: &syn::Path, mut function: syn::ItemFn) -> TokenStream {
130 let attribute_code = maybe!(expand_attributes(&mut function));
131
132 let args = function.sig.inputs.len().saturating_sub(1);
134 let fn_name = &function.sig.ident;
135 let wasm_export = format_ident!("__wasm_export_{fn_name}");
136
137 if fn_name == "migrate_with_info" {
139 return syn::Error::new_spanned(
140 &function.sig.ident,
141 r#"To use the new migrate function signature, you should provide a "migrate" entry point with 4 arguments, not "migrate_with_info""#,
142 ).into_compile_error();
143 }
144
145 let do_call = if fn_name == "migrate" && args == 3 {
147 format_ident!("do_migrate_with_info")
148 } else {
149 format_ident!("do_{fn_name}")
150 };
151
152 let decl_args = (0..args).map(|item| format_ident!("ptr_{item}"));
153 let call_args = decl_args.clone();
154
155 quote! {
156 #attribute_code
157
158 #function
159
160 #[cfg(target_arch = "wasm32")]
161 mod #wasm_export { #[no_mangle]
163 extern "C" fn #fn_name(#( #decl_args : u32 ),*) -> u32 {
164 #crate_path::#do_call(&super::#fn_name, #( #call_args ),*)
165 }
166 }
167 }
168}
169
170fn entry_point_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
171 let mut function: syn::ItemFn = maybe!(syn::parse2(item));
172 let Options { crate_path } = maybe!(syn::parse2(attr));
173
174 if env::var("CARGO_PRIMARY_PACKAGE").is_ok() {
175 expand_bindings(&crate_path, function)
176 } else {
177 function
178 .attrs
179 .retain(|attr| !attr.path().is_ident("migrate_version"));
180
181 quote! { #function }
182 }
183}
184
#[cfg(test)]
mod test {
    use std::env;
    use std::sync::Once;

    use proc_macro2::TokenStream;
    use quote::quote;

    use crate::entry_point_impl;

    /// Makes `entry_point_impl` take the binding-generating path.
    ///
    /// The test harness runs tests on parallel threads, and `env::set_var` is
    /// documented as unsafe to call while other threads may access the
    /// environment. Previously every test wrote the variable again, so it is
    /// now written exactly once for the whole process. Later callers block in
    /// `call_once` until that write has finished.
    fn setup_environment() {
        static SETUP: Once = Once::new();
        SETUP.call_once(|| env::set_var("CARGO_PRIMARY_PACKAGE", "1"));
    }

    /// `#[migrate_version]` on anything but `migrate` is a compile error.
    #[test]
    fn contract_migrate_version_on_non_migrate() {
        setup_environment();

        let code = quote! {
            #[migrate_version(42)]
            fn anything_else() -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            ::core::compile_error! { "you only want to add this attribute to your migrate function" }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// A literal migrate version is embedded, and `migrate_with_info` is rejected.
    #[test]
    fn contract_migrate_version_expansion() {
        setup_environment();

        let code = quote! {
            #[migrate_version(2)]
            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                static __CW_MIGRATE_VERSION: [u8; version_size(2)] = stringify_version(2);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };

            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_migrate {
                #[no_mangle]
                extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                    ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());

        let code = quote! {
            #[entry_point]
            pub fn migrate_with_info(
                deps: DepsMut,
                env: Env,
                msg: MigrateMsg,
                migrate_info: MigrateInfo,
            ) -> Result<Response, ()> {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            ::core::compile_error! { "To use the new migrate function signature, you should provide a \"migrate\" entry point with 4 arguments, not \"migrate_with_info\"" }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// A path to a constant is accepted as the migrate version.
    #[test]
    fn contract_migrate_version_with_const_expansion() {
        setup_environment();

        let code = quote! {
            #[migrate_version(CONTRACT_VERSION)]
            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            const _: () = {
                #[allow(unused)]
                #[doc(hidden)]
                #[cfg(target_arch = "wasm32")]
                #[link_section = "cw_migrate_version"]
                static __CW_MIGRATE_VERSION: [u8; version_size(CONTRACT_VERSION)] = stringify_version(CONTRACT_VERSION);

                #[allow(unused)]
                #[doc(hidden)]
                const fn stringify_version<const N: usize>(mut version: u64) -> [u8; N] {
                    let mut result: [u8; N] = [0; N];
                    let mut index = N;
                    while index > 0 {
                        let digit: u8 = (version%10) as u8;
                        result[index-1] = digit + b'0';
                        version /= 10;
                        index -= 1;
                    }
                    result
                }

                #[allow(unused)]
                #[doc(hidden)]
                const fn version_size(version: u64) -> usize {
                    if version > 0 {
                        (version.ilog10()+1) as usize
                    } else {
                        panic!("Contract migrate version should be greater than 0.")
                    }
                }
            };

            fn migrate(deps: DepsMut, env: Env, msg: MigrateMsg) -> Response {
            }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_migrate {
                #[no_mangle]
                extern "C" fn migrate(ptr_0: u32, ptr_1: u32) -> u32 {
                    ::cosmwasm_std::do_migrate(&super::migrate, ptr_0, ptr_1)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// Without attribute arguments, the export calls into `::cosmwasm_std`.
    #[test]
    fn default_expansion() {
        setup_environment();

        let code = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response {
            }
        };

        let actual = entry_point_impl(TokenStream::new(), code);
        let expected = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response { }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_instantiate {
                #[no_mangle]
                extern "C" fn instantiate(ptr_0: u32) -> u32 {
                    ::cosmwasm_std::do_instantiate(&super::instantiate, ptr_0)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// `crate = "..."` changes the path used to reach the dispatch helper.
    #[test]
    fn renamed_expansion() {
        setup_environment();

        let attribute = quote!(crate = "::my_crate::cw_std");
        let code = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response {
            }
        };

        let actual = entry_point_impl(attribute, code);
        let expected = quote! {
            fn instantiate(deps: DepsMut, env: Env) -> Response { }

            #[cfg(target_arch = "wasm32")]
            mod __wasm_export_instantiate {
                #[no_mangle]
                extern "C" fn instantiate(ptr_0: u32) -> u32 {
                    ::my_crate::cw_std::do_instantiate(&super::instantiate, ptr_0)
                }
            }
        };

        assert_eq!(actual.to_string(), expected.to_string());
    }
}