cw_orch_contract_derive/
lib.rs

1#![recursion_limit = "128"]
2
3use syn::{Expr, Token};
4use syn::{__private::TokenStream2, parse_macro_input, Fields, GenericArgument, Item, Path};
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8
9use quote::quote;
10
11use syn::{punctuated::Punctuated, token::Comma};
12
13use syn::parse::{Parse, ParseStream};
14
15mod kw {
16    syn::custom_keyword!(id);
17}
18// This is used to parse the types into a list of types separated by Commas
19// and default contract id if provided by "id = $expr"
20struct InterfaceInput {
21    expressions: Punctuated<Path, Comma>,
22    _kw_id: Option<kw::id>,
23    _eq_token: Option<Token![=]>,
24    default_id: Option<Expr>,
25}
26
27// Implement the `Parse` trait for your input struct
28impl Parse for InterfaceInput {
29    fn parse(input: ParseStream) -> syn::Result<Self> {
30        let mut expressions: Punctuated<Path, Comma> = Punctuated::new();
31
32        while let Ok(path) = input.parse() {
33            expressions.push(path);
34            let _: Option<Token![,]> = input.parse().ok();
35
36            // If we found id = break
37            if input.peek(kw::id) {
38                break;
39            }
40        }
41        // Parse if there is any
42        let kw_id: Option<kw::id> = input.parse().map_err(|_| {
43            syn::Error::new(
44                input.span(),
45                "The 5th argument of the macro should be of the format `id=my_contract_id`",
46            )
47        })?;
48        let eq_token: Option<Token![=]> = input.parse().map_err(|_| {
49            syn::Error::new(
50                input.span(),
51                "The 5th argument of the macro should be of the format `id=my_contract_id`",
52            )
53        })?;
54        let default_id: Option<Expr> = input.parse().ok();
55        Ok(Self {
56            expressions,
57            _kw_id: kw_id,
58            _eq_token: eq_token,
59            default_id,
60        })
61    }
62}
63
64// Gets the generics associated with a type
65fn get_generics_from_path(p: &Path) -> Punctuated<GenericArgument, Comma> {
66    let mut generics = Punctuated::new();
67
68    for segment in p.segments.clone() {
69        if let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments {
70            for arg in generic_args.args.clone() {
71                generics.push(arg);
72            }
73        }
74    }
75
76    generics
77}
78
79/**
80Procedural macro to generate a cw-orchestrator interface
81
82## Example
83
84```ignore
85#[interface(
86    cw20_base::msg::InstantiateMsg,
87    cw20_base::msg::ExecuteMsg,
88    cw20_base::msg::QueryMsg,
89    cw20_base::msg::MigrateMsg
90)]
91pub struct Cw20;
92```
93This generated the following code:
94
95```ignore
96
97// This struct represents the interface to the contract.
98pub struct Cw20<Chain>(::cw_orch::core::contract::Contract<Chain>);
99
100impl <Chain> Cw20<Chain> {
101    /// Constructor for the contract interface
102     pub fn new(contract_id: impl ToString, chain: Chain) -> Self {
103        Self(
104            ::cw_orch::core::contract::Contract::new(contract_id, chain)
105        )
106    }
107}
108
109// Traits for signaling cw-orchestrator with what messages to call the contract's entry points.
110impl <Chain> ::cw_orch::core::contract::interface_traits::InstantiableContract for Cw20<Chain> {
111    type InstantiateMsg = InstantiateMsg;
112}
113impl <Chain> ::cw_orch::core::contract::interface_traits::ExecutableContract for Cw20<Chain> {
114    type ExecuteMsg = ExecuteMsg;
115}
116// ... other entry point & upload traits
117```
118
119## Linking the interface to its source code
120
121The interface can be linked to its source code by implementing the `Uploadable` trait for the interface.
122
123```ignore
124use cw_orch::prelude::*;
125
126impl <Chain> Uploadable for Cw20<Chain> {
127    fn wrapper() -> <Mock as cw_orch::TxHandler>::ContractSource {
128        Box::new(
129            ContractWrapper::new_with_empty(
130                cw20_base::contract::execute,
131                cw20_base::contract::instantiate,
132                cw20_base::contract::query,
133            )
134            .with_migrate(cw20_base::contract::migrate),
135        )
136    }
137
138    fn wasm(_chain: &ChainInfoOwned) -> <Daemon as cw_orch::TxHandler>::ContractSource {
139        WasmPath::new("path/to/cw20.wasm").unwrap()
140    }
141}
142*/
143#[proc_macro_attribute]
144pub fn interface(attrs: TokenStream, input: TokenStream) -> TokenStream {
145    let mut item = parse_macro_input!(input as syn::Item);
146
147    // Try to parse the attributes to a
148    let attributes = parse_macro_input!(attrs as InterfaceInput);
149
150    let types_in_order = attributes.expressions;
151    let default_id = attributes.default_id;
152
153    if types_in_order.len() != 4 {
154        panic!("Expected four endpoint types (InstantiateMsg, ExecuteMsg, QueryMsg, MigrateMsg). Use cosmwasm_std::Empty if not implemented.")
155    }
156
157    let Item::Struct(cw_orch_struct) = &mut item else {
158        panic!("Only works on structs");
159    };
160    let Fields::Unit = &mut cw_orch_struct.fields else {
161        panic!("Struct must be unit-struct");
162    };
163
164    let init = types_in_order[0].clone();
165    let exec = types_in_order[1].clone();
166    let query = types_in_order[2].clone();
167    let migrate = types_in_order[3].clone();
168
169    // We create all generics for all types
170    let all_generics: Punctuated<GenericArgument, Comma> = types_in_order
171        .iter()
172        .flat_map(get_generics_from_path)
173        .collect();
174    // We create all phantom markers because else types are unused
175    let all_phantom_markers: Vec<TokenStream2> = all_generics
176        .iter()
177        .map(|t| {
178            quote!(
179                ::std::marker::PhantomData<#t>
180            )
181        })
182        .collect();
183
184    let all_phantom_marker_values: Vec<TokenStream2> = all_generics
185        .iter()
186        .map(|_| quote!(::std::marker::PhantomData::default()))
187        .collect();
188
189    // We create necessary Debug + Serialize traits
190    let all_debug_serialize: Vec<TokenStream2> = all_generics
191        .iter()
192        .map(|t| {
193            quote!(
194                #t: ::std::fmt::Debug + ::serde::Serialize
195            )
196        })
197        .collect();
198    let all_debug_serialize = if !all_debug_serialize.is_empty() {
199        quote!(where #(#all_debug_serialize,)*)
200    } else {
201        quote!()
202    };
203
204    let name = cw_orch_struct.ident.clone();
205    let default_num = if let Some(id_expr) = default_id {
206        quote!(
207            impl <Chain, #all_generics> #name<Chain, #all_generics> {
208                pub fn new(chain: Chain) -> Self {
209                    Self(
210                        ::cw_orch::core::contract::Contract::new(#id_expr, chain)
211                    , #(#all_phantom_marker_values,)*)
212                }
213            }
214        )
215    } else {
216        quote!(
217            impl <Chain, #all_generics> #name<Chain, #all_generics> {
218                pub fn new(contract_id: impl ToString, chain: Chain) -> Self {
219                    Self(
220                        ::cw_orch::core::contract::Contract::new(contract_id, chain)
221                    , #(#all_phantom_marker_values,)*)
222                }
223            }
224        )
225    };
226    let struct_def = quote!(
227        #[cfg(not(target_arch = "wasm32"))]
228        #[derive(
229            ::std::clone::Clone,
230        )]
231        pub struct #name<Chain, #all_generics>(::cw_orch::core::contract::Contract<Chain>, #(#all_phantom_markers,)*);
232
233        #[cfg(target_arch = "wasm32")]
234        #[derive(
235            ::std::clone::Clone,
236        )]
237        pub struct #name;
238
239        #[cfg(not(target_arch = "wasm32"))]
240        #default_num
241
242        #[cfg(not(target_arch = "wasm32"))]
243        impl<Chain: ::cw_orch::core::environment::ChainState, #all_generics> ::cw_orch::core::contract::interface_traits::ContractInstance<Chain> for #name<Chain, #all_generics> {
244            fn as_instance(&self) -> &::cw_orch::core::contract::Contract<Chain> {
245                &self.0
246            }
247            fn as_instance_mut(&mut self) -> &mut ::cw_orch::core::contract::Contract<Chain> {
248                &mut self.0
249            }
250        }
251
252        #[cfg(not(target_arch = "wasm32"))]
253        impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::InstantiableContract for #name<Chain, #all_generics> #all_debug_serialize {
254            type InstantiateMsg = #init;
255        }
256
257        #[cfg(not(target_arch = "wasm32"))]
258        impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::ExecutableContract for #name<Chain, #all_generics> #all_debug_serialize {
259            type ExecuteMsg = #exec;
260        }
261
262        #[cfg(not(target_arch = "wasm32"))]
263        impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::QueryableContract for #name<Chain, #all_generics> #all_debug_serialize {
264            type QueryMsg = #query;
265        }
266
267        #[cfg(not(target_arch = "wasm32"))]
268        impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::MigratableContract for #name<Chain, #all_generics> #all_debug_serialize {
269            type MigrateMsg = #migrate;
270        }
271    );
272    struct_def.into()
273}