cw_orch_contract_derive/
lib.rs1#![recursion_limit = "128"]
2
3use syn::{Expr, Token};
4use syn::{__private::TokenStream2, parse_macro_input, Fields, GenericArgument, Item, Path};
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8
9use quote::quote;
10
11use syn::{punctuated::Punctuated, token::Comma};
12
13use syn::parse::{Parse, ParseStream};
14
15mod kw {
16 syn::custom_keyword!(id);
17}
18struct InterfaceInput {
21 expressions: Punctuated<Path, Comma>,
22 _kw_id: Option<kw::id>,
23 _eq_token: Option<Token![=]>,
24 default_id: Option<Expr>,
25}
26
27impl Parse for InterfaceInput {
29 fn parse(input: ParseStream) -> syn::Result<Self> {
30 let mut expressions: Punctuated<Path, Comma> = Punctuated::new();
31
32 while let Ok(path) = input.parse() {
33 expressions.push(path);
34 let _: Option<Token![,]> = input.parse().ok();
35
36 if input.peek(kw::id) {
38 break;
39 }
40 }
41 let kw_id: Option<kw::id> = input.parse().map_err(|_| {
43 syn::Error::new(
44 input.span(),
45 "The 5th argument of the macro should be of the format `id=my_contract_id`",
46 )
47 })?;
48 let eq_token: Option<Token![=]> = input.parse().map_err(|_| {
49 syn::Error::new(
50 input.span(),
51 "The 5th argument of the macro should be of the format `id=my_contract_id`",
52 )
53 })?;
54 let default_id: Option<Expr> = input.parse().ok();
55 Ok(Self {
56 expressions,
57 _kw_id: kw_id,
58 _eq_token: eq_token,
59 default_id,
60 })
61 }
62}
63
64fn get_generics_from_path(p: &Path) -> Punctuated<GenericArgument, Comma> {
66 let mut generics = Punctuated::new();
67
68 for segment in p.segments.clone() {
69 if let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments {
70 for arg in generic_args.args.clone() {
71 generics.push(arg);
72 }
73 }
74 }
75
76 generics
77}
78
79#[proc_macro_attribute]
144pub fn interface(attrs: TokenStream, input: TokenStream) -> TokenStream {
145 let mut item = parse_macro_input!(input as syn::Item);
146
147 let attributes = parse_macro_input!(attrs as InterfaceInput);
149
150 let types_in_order = attributes.expressions;
151 let default_id = attributes.default_id;
152
153 if types_in_order.len() != 4 {
154 panic!("Expected four endpoint types (InstantiateMsg, ExecuteMsg, QueryMsg, MigrateMsg). Use cosmwasm_std::Empty if not implemented.")
155 }
156
157 let Item::Struct(cw_orch_struct) = &mut item else {
158 panic!("Only works on structs");
159 };
160 let Fields::Unit = &mut cw_orch_struct.fields else {
161 panic!("Struct must be unit-struct");
162 };
163
164 let init = types_in_order[0].clone();
165 let exec = types_in_order[1].clone();
166 let query = types_in_order[2].clone();
167 let migrate = types_in_order[3].clone();
168
169 let all_generics: Punctuated<GenericArgument, Comma> = types_in_order
171 .iter()
172 .flat_map(get_generics_from_path)
173 .collect();
174 let all_phantom_markers: Vec<TokenStream2> = all_generics
176 .iter()
177 .map(|t| {
178 quote!(
179 ::std::marker::PhantomData<#t>
180 )
181 })
182 .collect();
183
184 let all_phantom_marker_values: Vec<TokenStream2> = all_generics
185 .iter()
186 .map(|_| quote!(::std::marker::PhantomData::default()))
187 .collect();
188
189 let all_debug_serialize: Vec<TokenStream2> = all_generics
191 .iter()
192 .map(|t| {
193 quote!(
194 #t: ::std::fmt::Debug + ::serde::Serialize
195 )
196 })
197 .collect();
198 let all_debug_serialize = if !all_debug_serialize.is_empty() {
199 quote!(where #(#all_debug_serialize,)*)
200 } else {
201 quote!()
202 };
203
204 let name = cw_orch_struct.ident.clone();
205 let default_num = if let Some(id_expr) = default_id {
206 quote!(
207 impl <Chain, #all_generics> #name<Chain, #all_generics> {
208 pub fn new(chain: Chain) -> Self {
209 Self(
210 ::cw_orch::core::contract::Contract::new(#id_expr, chain)
211 , #(#all_phantom_marker_values,)*)
212 }
213 }
214 )
215 } else {
216 quote!(
217 impl <Chain, #all_generics> #name<Chain, #all_generics> {
218 pub fn new(contract_id: impl ToString, chain: Chain) -> Self {
219 Self(
220 ::cw_orch::core::contract::Contract::new(contract_id, chain)
221 , #(#all_phantom_marker_values,)*)
222 }
223 }
224 )
225 };
226 let struct_def = quote!(
227 #[cfg(not(target_arch = "wasm32"))]
228 #[derive(
229 ::std::clone::Clone,
230 )]
231 pub struct #name<Chain, #all_generics>(::cw_orch::core::contract::Contract<Chain>, #(#all_phantom_markers,)*);
232
233 #[cfg(target_arch = "wasm32")]
234 #[derive(
235 ::std::clone::Clone,
236 )]
237 pub struct #name;
238
239 #[cfg(not(target_arch = "wasm32"))]
240 #default_num
241
242 #[cfg(not(target_arch = "wasm32"))]
243 impl<Chain: ::cw_orch::core::environment::ChainState, #all_generics> ::cw_orch::core::contract::interface_traits::ContractInstance<Chain> for #name<Chain, #all_generics> {
244 fn as_instance(&self) -> &::cw_orch::core::contract::Contract<Chain> {
245 &self.0
246 }
247 fn as_instance_mut(&mut self) -> &mut ::cw_orch::core::contract::Contract<Chain> {
248 &mut self.0
249 }
250 }
251
252 #[cfg(not(target_arch = "wasm32"))]
253 impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::InstantiableContract for #name<Chain, #all_generics> #all_debug_serialize {
254 type InstantiateMsg = #init;
255 }
256
257 #[cfg(not(target_arch = "wasm32"))]
258 impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::ExecutableContract for #name<Chain, #all_generics> #all_debug_serialize {
259 type ExecuteMsg = #exec;
260 }
261
262 #[cfg(not(target_arch = "wasm32"))]
263 impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::QueryableContract for #name<Chain, #all_generics> #all_debug_serialize {
264 type QueryMsg = #query;
265 }
266
267 #[cfg(not(target_arch = "wasm32"))]
268 impl<Chain, #all_generics> ::cw_orch::core::contract::interface_traits::MigratableContract for #name<Chain, #all_generics> #all_debug_serialize {
269 type MigrateMsg = #migrate;
270 }
271 );
272 struct_def.into()
273}