provwasm_proc_macro/
lib.rs

1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput};
6
// Builds a closure for `get_query_attrs` that recognises exactly one three-token
// `key = value` run: an `Ident` equal to `$key`, a `=` punct, then a single
// `TokenTree::$value_type` token (e.g. `Literal` for `path = "..."`, `Ident` for
// `response_type = Foo`). Any other run shape (including multi-token values such as
// `a::B`) yields `None`; a match yields the value re-quoted as a token stream.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tt| match &tt[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if key == $key && eq.as_char() == '=' =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
26pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(input as DeriveInput);
28    let ident = input.ident;
29
30    let type_url = get_type_url(&input.attrs);
31
32    // `EncodeError` always indicates that a message failed to encode because the
33    // provided buffer had insufficient capacity. Message encoding is otherwise
34    // infallible.
35
36    let (query_request_conversion, cosmwasm_query) = if get_attr("proto_query", &input.attrs)
37        .is_some()
38    {
39        let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
40        let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
41
42        let query_request_conversion = quote! {
43            impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
44                fn from(msg: #ident) -> Self {
45                    cosmwasm_std::QueryRequest::<Q>::Grpc(cosmwasm_std::GrpcQuery{
46                        path: #path.to_string(),
47                        data: msg.into(),
48                    })
49                }
50            }
51        };
52
53        let cosmwasm_query = quote! {
54            pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
55                let binary_result = querier.query_grpc(#path.to_string(), self.into())?;
56                let response_query = crate::types::tendermint::abci::ResponseQuery::try_from(binary_result)?;
57                #res::try_from(response_query.value)
58            }
59
60            pub fn mock_response<T: provwasm_common::MockableQuerier>(querier: &mut T, response: #res) {
61                querier.register_custom_query(#path.to_string(), Box::new(move |data| {
62                    cosmwasm_std::SystemResult::Ok(cosmwasm_std::ContractResult::Ok(
63                        cosmwasm_std::Binary::new(crate::types::tendermint::abci::ResponseQuery{
64                            code: 0,
65                            log: "".to_string(),
66                            info: "".to_string(),
67                            index: 0,
68                            key: vec![],
69                            value: response.to_proto_bytes(),
70                            proof_ops: None,
71                            height: 0,
72                            codespace: "".to_string(),
73                        }.to_proto_bytes())))
74                }))
75            }
76
77            pub fn mock_failed_response<T: provwasm_common::MockableQuerier>(querier: &mut T, error: String) {
78                querier.register_custom_query(#path.to_string(), Box::new(move |data| {
79                    cosmwasm_std::SystemResult::Err(cosmwasm_std::SystemError::InvalidResponse {
80                        error: error.clone(),
81                        response: cosmwasm_std::Binary::default(),
82                    })
83                }))
84            }
85        };
86
87        (query_request_conversion, cosmwasm_query)
88    } else {
89        (quote!(), quote!())
90    };
91
92    (quote! {
93        impl #ident {
94            pub const TYPE_URL: &'static str = #type_url;
95            #cosmwasm_query
96
97            pub fn to_proto_bytes(&self) -> Vec<u8> {
98                let mut bytes = Vec::new();
99                prost::Message::encode(self, &mut bytes)
100                    .expect("Message encoding must be infallible");
101                bytes
102            }
103            pub fn to_any(&self) -> crate::shim::Any {
104                crate::shim::Any {
105                    type_url: Self::TYPE_URL.to_string(),
106                    value: self.to_proto_bytes(),
107                }
108            }
109        }
110
111        #query_request_conversion
112
113        impl From<#ident> for cosmwasm_std::Binary {
114            fn from(msg: #ident) -> Self {
115                cosmwasm_std::Binary::new(msg.to_proto_bytes())
116            }
117        }
118
119        impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
120            fn from(msg: #ident) -> Self {
121                cosmwasm_std::CosmosMsg::<T>::Any(cosmwasm_std::AnyMsg {
122                    type_url: #type_url.to_string(),
123                    value: msg.into(),
124                })
125            }
126        }
127
128        impl TryFrom<cosmwasm_std::Binary> for #ident {
129            type Error = cosmwasm_std::StdError;
130
131            fn try_from(binary: cosmwasm_std::Binary) -> ::std::result::Result<Self, Self::Error> {
132                use ::prost::Message;
133                Self::decode(&binary[..]).map_err(|e| {
134                    cosmwasm_std::StdError::parse_err(
135                        stringify!(#ident),
136                        format!(
137                            "Unable to decode binary: \n  - base64: {}\n  - bytes array: {:?}\n\n{:?}",
138                            binary,
139                            binary.to_vec(),
140                            e
141                        )
142                    )
143                })
144            }
145        }
146
147        impl TryFrom<Vec<u8>> for #ident {
148            type Error = cosmwasm_std::StdError;
149
150            fn try_from(binary: Vec<u8>) -> ::std::result::Result<Self, Self::Error> {
151                use ::prost::Message;
152                Self::decode(&binary[..]).map_err(|e| {
153                    cosmwasm_std::StdError::parse_err(
154                        stringify!(#ident),
155                        format!(
156                            "Unable to decode binary:\n  - bytes array: {:?}\n\n{:?}",
157                            binary,
158                            e
159                        )
160                    )
161                })
162            }
163        }
164
165        impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
166            type Error = cosmwasm_std::StdError;
167
168            fn try_from(result: cosmwasm_std::SubMsgResult) -> ::std::result::Result<Self, Self::Error> {
169                result
170                    .into_result()
171                    .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
172                    .data
173                    .ok_or_else(|| cosmwasm_std::StdError::not_found("cosmwasm_std::SubMsgResult::<T>"))?
174                    .try_into()
175            }
176        }
177
178        impl TryFrom<crate::shim::Any> for #ident {
179            type Error = prost::DecodeError;
180
181            fn try_from(value: crate::shim::Any) -> ::std::result::Result<Self, Self::Error> {
182                prost::Message::decode(value.value.as_slice())
183            }
184        }
185
186        impl TryInto<crate::shim::Any> for #ident {
187            type Error = prost::EncodeError;
188
189            fn try_into(self) -> ::std::result::Result<crate::shim::Any, Self::Error> {
190                let value = prost::Message::encode_to_vec(&self);
191                Ok(crate::shim::Any {
192                    type_url: <#ident>::TYPE_URL.to_string(),
193                    value,
194                })
195            }
196        }
197    })
198        .into()
199}
200
201#[proc_macro_derive(SerdeEnumAsInt)]
202pub fn derive_serde_enum_as_int(input: TokenStream) -> TokenStream {
203    let input = parse_macro_input!(input as DeriveInput);
204    let ident = input.ident;
205    (quote! {
206        impl #ident {
207            pub fn serialize<S>(v: &i32, serializer: S) -> std::result::Result<S::Ok, S::Error>
208            where
209                S: serde::Serializer,
210            {
211                let enum_value = Self::try_from(*v);
212                match enum_value {
213                    Ok(v) => serializer.serialize_str(v.as_str_name()),
214                    Err(e) => Err(serde::ser::Error::custom(e)),
215                }
216            }
217
218            pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<i32, D::Error>
219            where
220                D: serde::Deserializer<'de>,
221            {
222                use serde::de::Deserialize;
223                let s = String::deserialize(deserializer)?;
224                match Self::from_str_name(&s) {
225                    Some(v) => Ok(v.into()),
226                    None => Err(serde::de::Error::custom("unknown value")),
227                }
228            }
229
230            pub fn serialize_vec<S>(v: &Vec<i32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
231            where
232                S: serde::Serializer,
233            {
234                use serde::ser::SerializeTuple;
235
236                let mut enum_strs: Vec<&str> = Vec::new();
237                for ord in v {
238                    // let enum_value = Self::try_from(*ord);
239                    let enum_value = Self::try_from(*ord);
240                    match enum_value {
241                        Ok(v) => {
242                            enum_strs.push(v.as_str_name());
243                        }
244                        Err(e) => return Err(serde::ser::Error::custom(e)),
245                    }
246                }
247                let mut seq = serializer.serialize_tuple(enum_strs.len())?;
248                for item in enum_strs {
249                    seq.serialize_element(item)?;
250                }
251                seq.end()
252            }
253
254            fn deserialize_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<i32>, D::Error>
255            where
256                D: serde::Deserializer<'de>,
257            {
258                use serde::de::{Deserialize, Error};
259
260                let strs: Vec<String> = Vec::deserialize(deserializer)?;
261                let mut ords: Vec<i32> = Vec::new();
262                for str_name in strs {
263                    let enum_value = Self::from_str_name(&str_name)
264                        .ok_or_else(|| Error::custom(format!("unknown enum string: {}", str_name)))?;
265                    ords.push(enum_value as i32);
266                }
267                Ok(ords)
268            }
269        }
270    })
271        .into()
272}
273
274fn get_type_url(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
275    let proto_message = get_attr("proto_message", attrs).and_then(|a| a.parse_meta().ok());
276
277    if let Some(syn::Meta::List(meta)) = proto_message.clone() {
278        match meta.nested[0].clone() {
279            syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => {
280                if meta.path.is_ident("type_url") {
281                    match meta.lit {
282                        syn::Lit::Str(s) => quote!(#s),
283                        _ => proto_message_attr_error(meta.lit),
284                    }
285                } else {
286                    proto_message_attr_error(meta.path)
287                }
288            }
289            t => proto_message_attr_error(t),
290        }
291    } else {
292        proto_message_attr_error(proto_message)
293    }
294}
295
296fn get_query_attrs<F>(attrs: &[syn::Attribute], f: F) -> proc_macro2::TokenStream
297where
298    F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>,
299{
300    let proto_query = get_attr("proto_query", attrs);
301
302    if let Some(attr) = proto_query {
303        if attr.tokens.clone().into_iter().count() != 1 {
304            return proto_query_attr_error(proto_query);
305        }
306
307        if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
308            let kv_groups = group.stream().into_iter().chunk_by(|t| {
309                if let TokenTree::Punct(punct) = t {
310                    punct.as_char() != ','
311                } else {
312                    true
313                }
314            });
315            let mut key_values: Vec<Vec<TokenTree>> = vec![];
316
317            for (non_sep, g) in &kv_groups {
318                if non_sep {
319                    key_values.push(g.collect());
320                }
321            }
322
323            return key_values
324                .iter()
325                .find_map(f)
326                .unwrap_or_else(|| proto_query_attr_error(proto_query));
327        }
328
329        proto_query_attr_error(proto_query)
330    } else {
331        proto_query_attr_error(proto_query)
332    }
333}
334
335fn get_attr<'a>(attr_ident: &str, attrs: &'a [syn::Attribute]) -> Option<&'a syn::Attribute> {
336    attrs
337        .iter()
338        .find(|&attr| attr.path.segments.len() == 1 && attr.path.segments[0].ident == attr_ident)
339}
340
341fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
342    syn::Error::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`")
343        .to_compile_error()
344}
345
346fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
347    syn::Error::new_spanned(
348        tokens,
349        "expected `proto_query(path = \"...\", response_type = ...)`",
350    )
351    .to_compile_error()
352}