pryzm_std_derive/
lib.rs

1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput};
6
/// Builds a closure that recognises a single `key = value` token triple.
///
/// `$key` is the expected identifier text and `$value_type` is the name of the
/// `TokenTree` variant the value must be (e.g. `Literal` or `Ident`). The
/// closure yields the value re-emitted as a token stream when the slice is
/// exactly `<key> = <value>`, and `None` for anything else.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tt| match &tt[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if (key == $key) && (eq.as_char() == '=') =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25
26#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
27pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    let ident = input.ident;
30
31    let type_url = get_type_url(&input.attrs);
32
33    // `EncodeError` always indicates that a message failed to encode because the
34    // provided buffer had insufficient capacity. Message encoding is otherwise
35    // infallible.
36
37    let (query_request_conversion, cosmwasm_query) = if get_attr("proto_query", &input.attrs)
38        .is_some()
39    {
40        let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
41        let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
42
43        let query_request_conversion = quote! {
44            impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
45                fn from(msg: #ident) -> Self {
46                    cosmwasm_std::QueryRequest::<Q>::Stargate {
47                        path: #path.to_string(),
48                        data: msg.into(),
49                    }
50                }
51            }
52        };
53
54        let cosmwasm_query = quote! {
55            pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
56                querier.query::<#res>(&self.into())
57            }
58        };
59
60        (query_request_conversion, cosmwasm_query)
61    } else {
62        (quote!(), quote!())
63    };
64
65    (quote! {
66        impl #ident {
67            pub const TYPE_URL: &'static str = #type_url;
68            #cosmwasm_query
69
70            pub fn to_proto_bytes(&self) -> Vec<u8> {
71                let mut bytes = Vec::new();
72                prost::Message::encode(self, &mut bytes)
73                    .expect("Message encoding must be infallible");
74                bytes
75            }
76            pub fn to_any(&self) -> crate::shim::Any {
77                crate::shim::Any {
78                    type_url: Self::TYPE_URL.to_string(),
79                    value: self.to_proto_bytes(),
80                }
81            }
82        }
83
84        #query_request_conversion
85
86        impl From<#ident> for cosmwasm_std::Binary {
87            fn from(msg: #ident) -> Self {
88                cosmwasm_std::Binary(msg.to_proto_bytes())
89            }
90        }
91
92        impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
93            fn from(msg: #ident) -> Self {
94                cosmwasm_std::CosmosMsg::<T>::Stargate {
95                    type_url: #type_url.to_string(),
96                    value: msg.into(),
97                }
98            }
99        }
100
101        impl TryFrom<cosmwasm_std::Binary> for #ident {
102            type Error = cosmwasm_std::StdError;
103
104            fn try_from(binary: cosmwasm_std::Binary) -> ::std::result::Result<Self, Self::Error> {
105                use ::prost::Message;
106                Self::decode(&binary[..]).map_err(|e| {
107                    cosmwasm_std::StdError::parse_err(
108                        stringify!(#ident),
109                        format!(
110                            "Unable to decode binary: \n  - base64: {}\n  - bytes array: {:?}\n\n{:?}",
111                            binary,
112                            binary.to_vec(),
113                            e
114                        )
115                    )
116                })
117            }
118        }
119
120        impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
121            type Error = cosmwasm_std::StdError;
122
123            fn try_from(result: cosmwasm_std::SubMsgResult) -> ::std::result::Result<Self, Self::Error> {
124                result
125                    .into_result()
126                    .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
127                    .data
128                    .ok_or_else(|| cosmwasm_std::StdError::not_found("cosmwasm_std::SubMsgResult::<T>"))?
129                    .try_into()
130            }
131        }
132    }).into()
133}
134
135
136fn get_type_url(attrs: &Vec<syn::Attribute>) -> proc_macro2::TokenStream {
137    let proto_message = get_attr("proto_message", attrs).and_then(|a| a.parse_meta().ok());
138
139    if let Some(syn::Meta::List(meta)) = proto_message.clone() {
140        match meta.nested[0].clone() {
141            syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => {
142                if meta.path.is_ident("type_url") {
143                    match meta.lit {
144                        syn::Lit::Str(s) => quote!(#s),
145                        _ => proto_message_attr_error(meta.lit),
146                    }
147                } else {
148                    proto_message_attr_error(meta.path)
149                }
150            }
151            t => proto_message_attr_error(t),
152        }
153    } else {
154        proto_message_attr_error(proto_message)
155    }
156}
157
158fn get_query_attrs<F>(attrs: &Vec<syn::Attribute>, f: F) -> proc_macro2::TokenStream
159where
160    F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>,
161{
162    let proto_query = get_attr("proto_query", attrs);
163
164    if let Some(attr) = proto_query {
165        if attr.tokens.clone().into_iter().count() != 1 {
166            return proto_query_attr_error(proto_query);
167        }
168
169        if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
170            let kv_groups = group.stream().into_iter().group_by(|t| {
171                if let TokenTree::Punct(punct) = t {
172                    punct.as_char() != ','
173                } else {
174                    true
175                }
176            });
177            let mut key_values: Vec<Vec<TokenTree>> = vec![];
178
179            for (non_sep, g) in &kv_groups {
180                if non_sep {
181                    key_values.push(g.collect());
182                }
183            }
184
185            return key_values
186                .iter()
187                .find_map(f)
188                .unwrap_or_else(|| proto_query_attr_error(proto_query));
189        }
190
191        proto_query_attr_error(proto_query)
192    } else {
193        proto_query_attr_error(proto_query)
194    }
195}
196
197fn get_attr<'a>(attr_ident: &str, attrs: &'a Vec<syn::Attribute>) -> Option<&'a syn::Attribute> {
198    attrs.iter().find(|&attr| attr.path.segments.len() == 1 && attr.path.segments[0].ident == attr_ident)
199}
200
201fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
202    syn::Error::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`")
203        .to_compile_error()
204}
205
206fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
207    syn::Error::new_spanned(
208        tokens,
209        "expected `proto_query(path = \"...\", response_type = ...)`",
210    )
211    .to_compile_error()
212}