injective_std_derive/
lib.rs

1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput};
6
/// Builds a closure that recognises one `key = value` entry of an attribute argument list.
///
/// The produced closure receives the token trees of a single comma-separated entry. It returns
/// the value token re-emitted through `quote!` only when the entry is exactly three tokens:
/// an identifier equal to `$key`, a `=` punctuation token, and a token of the
/// `TokenTree::$value_type` variant (e.g. `Literal` or `Ident`). Every other shape yields `None`.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tt| match &tt[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if key == $key && eq.as_char() == '=' =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
26pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(input as DeriveInput);
28    let ident = input.ident;
29
30    let type_url = get_type_url(&input.attrs);
31
32    // `EncodeError` always indicates that a message failed to encode because the
33    // provided buffer had insufficient capacity. Message encoding is otherwise
34    // infallible.
35
36    let (query_request_conversion, cosmwasm_query) = if get_attr("proto_query", &input.attrs).is_some() {
37        let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
38        let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
39
40        let query_request_conversion = quote! {
41            impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
42                fn from(msg: #ident) -> Self {
43                    cosmwasm_std::QueryRequest::<Q>::Stargate {
44                        path: #path.to_string(),
45                        data: msg.into(),
46                    }
47                }
48            }
49        };
50
51        let cosmwasm_query = quote! {
52            pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
53                querier.query::<#res>(&self.into())
54            }
55        };
56
57        (query_request_conversion, cosmwasm_query)
58    } else {
59        (quote!(), quote!())
60    };
61
62    (quote! {
63        impl #ident {
64            pub const TYPE_URL: &'static str = #type_url;
65            #cosmwasm_query
66        }
67
68        #query_request_conversion
69
70        impl From<#ident> for cosmwasm_std::Binary {
71            fn from(msg: #ident) -> Self {
72                let mut bytes = Vec::new();
73                prost::Message::encode(&msg, &mut bytes)
74                    .expect("Message encoding must be infallible");
75
76                cosmwasm_std::Binary::new(bytes)
77            }
78        }
79
80        impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
81            fn from(msg: #ident) -> Self {
82                cosmwasm_std::CosmosMsg::<T>::Stargate {
83                    type_url: #type_url.to_string(),
84                    value: msg.into(),
85                }
86            }
87        }
88
89        impl TryFrom<cosmwasm_std::Binary> for #ident {
90            type Error = cosmwasm_std::StdError;
91
92            fn try_from(binary: cosmwasm_std::Binary) -> Result<Self, Self::Error> {
93                use ::prost::Message;
94                Self::decode(&binary[..]).map_err(|e| {
95                    cosmwasm_std::StdError::parse_err(
96                        stringify!(#ident).to_string(),
97                        format!(
98                            "Unable to decode binary: \n  - base64: {}\n  - bytes array: {:?}\n\n{:?}",
99                            binary,
100                            binary.to_vec(),
101                            e
102                        ),
103                    )
104                })
105            }
106        }
107
108        impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
109            type Error = cosmwasm_std::StdError;
110
111            fn try_from(result: cosmwasm_std::SubMsgResult) -> Result<Self, Self::Error> {
112                result
113                    .into_result()
114                    .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
115                    .data
116                    .ok_or_else(|| cosmwasm_std::StdError::not_found(
117                        "cosmwasm_std::SubMsgResult::<T>".to_string()
118                    ))?
119                    .try_into()
120            }
121        }
122    })
123    .into()
124}
125
126fn get_type_url(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
127    let proto_message = get_attr("proto_message", attrs).and_then(|a| a.parse_meta().ok());
128
129    if let Some(syn::Meta::List(meta)) = proto_message.clone() {
130        match meta.nested[0].clone() {
131            syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => {
132                if meta.path.is_ident("type_url") {
133                    match meta.lit {
134                        syn::Lit::Str(s) => quote!(#s),
135                        _ => proto_message_attr_error(meta.lit),
136                    }
137                } else {
138                    proto_message_attr_error(meta.path)
139                }
140            }
141            t => proto_message_attr_error(t),
142        }
143    } else {
144        proto_message_attr_error(proto_message)
145    }
146}
147
148fn get_query_attrs<F>(attrs: &[syn::Attribute], f: F) -> proc_macro2::TokenStream
149where
150    F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>,
151{
152    let proto_query = get_attr("proto_query", attrs);
153
154    if let Some(attr) = proto_query {
155        if attr.tokens.clone().into_iter().count() != 1 {
156            return proto_query_attr_error(proto_query);
157        }
158
159        if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
160            let kv_groups = group
161                .stream()
162                .into_iter()
163                .chunk_by(|t| if let TokenTree::Punct(punct) = t { punct.as_char() != ',' } else { true });
164            let mut key_values: Vec<Vec<TokenTree>> = vec![];
165
166            for (non_sep, g) in &kv_groups {
167                if non_sep {
168                    key_values.push(g.collect());
169                }
170            }
171
172            return key_values.iter().find_map(f).unwrap_or_else(|| proto_query_attr_error(proto_query));
173        }
174
175        proto_query_attr_error(proto_query)
176    } else {
177        proto_query_attr_error(proto_query)
178    }
179}
180
181fn get_attr<'a>(attr_ident: &str, attrs: &'a [syn::Attribute]) -> Option<&'a syn::Attribute> {
182    attrs
183        .iter()
184        .find(|&attr| attr.path.segments.len() == 1 && attr.path.segments[0].ident == attr_ident)
185}
186
187fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
188    syn::Error::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`").to_compile_error()
189}
190
191fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
192    syn::Error::new_spanned(tokens, "expected `proto_query(path = \"...\", response_type = ...)`").to_compile_error()
193}