1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, Expr, ExprLit, Lit, Meta};
6
/// Builds a closure that recognises one `key = value` group from an attribute list.
///
/// The generated closure takes one comma-separated group of tokens and returns the
/// quoted value only when the group is exactly three tokens:
/// 1. an identifier equal to `$key`,
/// 2. a `=` punctuation,
/// 3. a token of the `TokenTree::$value_type` variant (e.g. `Literal` or `Ident`).
///
/// Any other shape yields `None`.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tt| match &tt[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if key == $key && eq.as_char() == '=' =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
26pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 let ident = input.ident;
29
30 let type_url = get_type_url(&input.attrs);
31
32 let (query_request_conversion, cosmwasm_query) = if get_attr("proto_query", &input.attrs)
37 .is_some()
38 {
39 let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
40 let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
41
42 let query_request_conversion = quote! {
43 impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
44 fn from(msg: #ident) -> Self {
45 cosmwasm_std::QueryRequest::<Q>::Grpc(cosmwasm_std::GrpcQuery{
46 path: #path.to_string(),
47 data: msg.into(),
48 })
49 }
50 }
51 };
52
53 let cosmwasm_query = quote! {
54 pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
55 let binary_result = querier.query_grpc(#path.to_string(), self.into())?;
56 let response_query = crate::types::tendermint::abci::ResponseQuery::try_from(binary_result)?;
57 #res::try_from(response_query.value)
58 }
59
60 pub fn mock_response<T: provwasm_common::MockableQuerier>(querier: &mut T, response: #res) {
61 querier.register_custom_query(#path.to_string(), Box::new(move |data| {
62 cosmwasm_std::SystemResult::Ok(cosmwasm_std::ContractResult::Ok(
63 cosmwasm_std::Binary::new(crate::types::tendermint::abci::ResponseQuery{
64 code: 0,
65 log: "".to_string(),
66 info: "".to_string(),
67 index: 0,
68 key: vec![],
69 value: response.to_proto_bytes(),
70 proof_ops: None,
71 height: 0,
72 codespace: "".to_string(),
73 }.to_proto_bytes())))
74 }))
75 }
76
77 pub fn mock_failed_response<T: provwasm_common::MockableQuerier>(querier: &mut T, error: String) {
78 querier.register_custom_query(#path.to_string(), Box::new(move |data| {
79 cosmwasm_std::SystemResult::Err(cosmwasm_std::SystemError::InvalidResponse {
80 error: error.clone(),
81 response: cosmwasm_std::Binary::default(),
82 })
83 }))
84 }
85 };
86
87 (query_request_conversion, cosmwasm_query)
88 } else {
89 (quote!(), quote!())
90 };
91
92 (quote! {
93 impl #ident {
94 pub const TYPE_URL: &'static str = #type_url;
95 #cosmwasm_query
96
97 pub fn to_proto_bytes(&self) -> Vec<u8> {
98 let mut bytes = Vec::new();
99 prost::Message::encode(self, &mut bytes)
100 .expect("Message encoding must be infallible");
101 bytes
102 }
103 pub fn to_any(&self) -> crate::shim::Any {
104 crate::shim::Any {
105 type_url: Self::TYPE_URL.to_string(),
106 value: self.to_proto_bytes(),
107 }
108 }
109 }
110
111 #query_request_conversion
112
113 impl From<#ident> for cosmwasm_std::Binary {
114 fn from(msg: #ident) -> Self {
115 cosmwasm_std::Binary::new(msg.to_proto_bytes())
116 }
117 }
118
119 impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
120 fn from(msg: #ident) -> Self {
121 cosmwasm_std::CosmosMsg::<T>::Any(cosmwasm_std::AnyMsg {
122 type_url: #type_url.to_string(),
123 value: msg.into(),
124 })
125 }
126 }
127
128 impl TryFrom<cosmwasm_std::Binary> for #ident {
129 type Error = cosmwasm_std::StdError;
130
131 fn try_from(binary: cosmwasm_std::Binary) -> ::std::result::Result<Self, Self::Error> {
132 use ::prost::Message;
133 Self::decode(&binary[..]).map_err(|e| {
134 cosmwasm_std::StdError::parse_err(
135 stringify!(#ident),
136 format!(
137 "Unable to decode binary: \n - base64: {}\n - bytes array: {:?}\n\n{:?}",
138 binary,
139 binary.to_vec(),
140 e
141 )
142 )
143 })
144 }
145 }
146
147 impl TryFrom<Vec<u8>> for #ident {
148 type Error = cosmwasm_std::StdError;
149
150 fn try_from(binary: Vec<u8>) -> ::std::result::Result<Self, Self::Error> {
151 use ::prost::Message;
152 Self::decode(&binary[..]).map_err(|e| {
153 cosmwasm_std::StdError::parse_err(
154 stringify!(#ident),
155 format!(
156 "Unable to decode binary:\n - bytes array: {:?}\n\n{:?}",
157 binary,
158 e
159 )
160 )
161 })
162 }
163 }
164
165 impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
166 type Error = cosmwasm_std::StdError;
167
168 fn try_from(result: cosmwasm_std::SubMsgResult) -> ::std::result::Result<Self, Self::Error> {
169 result
170 .into_result()
171 .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
172 .data
173 .ok_or_else(|| cosmwasm_std::StdError::not_found("cosmwasm_std::SubMsgResult::<T>"))?
174 .try_into()
175 }
176 }
177
178 impl TryFrom<crate::shim::Any> for #ident {
179 type Error = prost::DecodeError;
180
181 fn try_from(value: crate::shim::Any) -> ::std::result::Result<Self, Self::Error> {
182 prost::Message::decode(value.value.as_slice())
183 }
184 }
185
186 impl TryInto<crate::shim::Any> for #ident {
187 type Error = prost::EncodeError;
188
189 fn try_into(self) -> ::std::result::Result<crate::shim::Any, Self::Error> {
190 let value = prost::Message::encode_to_vec(&self);
191 Ok(crate::shim::Any {
192 type_url: <#ident>::TYPE_URL.to_string(),
193 value,
194 })
195 }
196 }
197 })
198 .into()
199}
200
201#[proc_macro_derive(SerdeEnumAsInt)]
202pub fn derive_serde_enum_as_int(input: TokenStream) -> TokenStream {
203 let input = parse_macro_input!(input as DeriveInput);
204 let ident = input.ident;
205 (quote! {
206 impl #ident {
207 pub fn serialize<S>(v: &i32, serializer: S) -> std::result::Result<S::Ok, S::Error>
208 where
209 S: serde::Serializer,
210 {
211 let enum_value = Self::try_from(*v);
212 match enum_value {
213 Ok(v) => serializer.serialize_str(v.as_str_name()),
214 Err(e) => Err(serde::ser::Error::custom(e)),
215 }
216 }
217
218 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<i32, D::Error>
219 where
220 D: serde::Deserializer<'de>,
221 {
222 use serde::de::Deserialize;
223 let s = String::deserialize(deserializer)?;
224 match Self::from_str_name(&s) {
225 Some(v) => Ok(v.into()),
226 None => Err(serde::de::Error::custom("unknown value")),
227 }
228 }
229
230 pub fn serialize_vec<S>(v: &Vec<i32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
231 where
232 S: serde::Serializer,
233 {
234 use serde::ser::SerializeTuple;
235
236 let mut enum_strs: Vec<&str> = Vec::new();
237 for ord in v {
238 let enum_value = Self::try_from(*ord);
240 match enum_value {
241 Ok(v) => {
242 enum_strs.push(v.as_str_name());
243 }
244 Err(e) => return Err(serde::ser::Error::custom(e)),
245 }
246 }
247 let mut seq = serializer.serialize_tuple(enum_strs.len())?;
248 for item in enum_strs {
249 seq.serialize_element(item)?;
250 }
251 seq.end()
252 }
253
254 fn deserialize_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<i32>, D::Error>
255 where
256 D: serde::Deserializer<'de>,
257 {
258 use serde::de::{Deserialize, Error};
259
260 let strs: Vec<String> = Vec::deserialize(deserializer)?;
261 let mut ords: Vec<i32> = Vec::new();
262 for str_name in strs {
263 let enum_value = Self::from_str_name(&str_name)
264 .ok_or_else(|| Error::custom(format!("unknown enum string: {}", str_name)))?;
265 ords.push(enum_value as i32);
266 }
267 Ok(ords)
268 }
269 }
270 })
271 .into()
272}
273
274fn get_type_url(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
275 let proto_message_attr = get_attr("proto_message", attrs);
276
277 if let Some(attr) = proto_message_attr {
278 let meta = &attr.meta;
279
280 if let Meta::List(meta_list) = meta {
281 let nested = meta_list.parse_args_with(
283 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
284 );
285
286 if let Ok(nested_metas) = nested {
287 if let Some(first) = nested_metas.first() {
288 if let Meta::NameValue(name_value) = first {
289 if name_value.path.is_ident("type_url") {
290 if let Expr::Lit(ExprLit {
291 lit: Lit::Str(s), ..
292 }) = &name_value.value
293 {
294 return quote!(#s);
295 } else {
296 return proto_message_attr_error(&name_value.value);
297 }
298 } else {
299 return proto_message_attr_error(&name_value.path);
300 }
301 } else {
302 return proto_message_attr_error(first);
303 }
304 }
305 }
306 }
307
308 proto_message_attr_error(attr)
309 } else {
310 proto_message_attr_error("proto_message attribute not found")
311 }
312}
313
314fn get_query_attrs<F>(attrs: &[syn::Attribute], f: F) -> proc_macro2::TokenStream
315where
316 F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>,
317{
318 let proto_query_attr = get_attr("proto_query", attrs);
319
320 if let Some(attr) = proto_query_attr {
321 let meta = &attr.meta;
322
323 if let Meta::List(meta_list) = meta {
324 let tokens = meta_list.tokens.clone();
325
326 if tokens.clone().into_iter().count() < 1 {
327 return proto_query_attr_error(attr);
328 }
329
330 let kv_groups = tokens.into_iter().chunk_by(|t| {
333 if let TokenTree::Punct(punct) = t {
334 punct.as_char() != ','
335 } else {
336 true
337 }
338 });
339 let mut key_values: Vec<Vec<TokenTree>> = vec![];
340
341 for (non_sep, g) in &kv_groups {
342 if non_sep {
343 key_values.push(g.collect());
344 }
345 }
346
347 return key_values
348 .iter()
349 .find_map(f)
350 .unwrap_or_else(|| proto_query_attr_error(attr));
351 }
352
353 proto_query_attr_error(attr)
354 } else {
355 proto_query_attr_error("proto_query attribute not found")
356 }
357}
358
359fn get_attr<'a>(attr_ident: &str, attrs: &'a [syn::Attribute]) -> Option<&'a syn::Attribute> {
360 attrs.iter().find(|&attr| {
361 attr.path().segments.len() == 1 && attr.path().segments[0].ident == attr_ident
362 })
363}
364
365fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
366 syn::Error::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`")
367 .to_compile_error()
368}
369
370fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
371 syn::Error::new_spanned(
372 tokens,
373 "expected `proto_query(path = \"...\", response_type = ...)`",
374 )
375 .to_compile_error()
376}