1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{ parse_macro_input, DeriveInput };
6
/// Builds a closure suitable for `get_query_attrs` that recognises one `key = value` entry.
///
/// The generated closure inspects the tokens of a single comma-separated entry and yields
/// the value (re-quoted as a token stream) only when the entry is exactly three tokens:
/// an identifier equal to `$key`, a `=` punct, and a `TokenTree::$value_type` (for example
/// `Literal` or `Ident`). Any other shape yields `None`.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tt| match &tt[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if key == $key && eq.as_char() == '=' =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
26pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 let ident = input.ident;
29
30 let type_url = get_type_url(&input.attrs);
31
32 let (query_request_conversion, cosmwasm_query, path_token) = if
37 get_attr("proto_query", &input.attrs).is_some()
38 {
39 let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
40 let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
41 let query_request_conversion =
42 quote! {
43 impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
44 fn from(msg: #ident) -> Self {
45 cosmwasm_std::QueryRequest::<Q>::Grpc(cosmwasm_std::GrpcQuery {
46 path: #path.to_string(),
47 data: msg.into(),
48 })
49 }
50 }
51 };
52
53 let cosmwasm_query =
54 quote! {
55 pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
56 use prost::Message;
57 let resp = #res::decode(
58 querier.query_grpc(
59 #path.to_string(),
60 self.to_proto_bytes().into(),
61 )?
62 .as_slice(),
63 );
64 match resp {
65 Err(e) => Err(cosmwasm_std::StdError::generic_err(format!(
66 "Can't decode item: {}",
67 e
68 ))),
69 Ok(data) => Ok(data),
70 }
71 }
72 };
73
74 let path_token = quote! {
75 pub const PATH: &'static str = #path;
76 };
77
78 (query_request_conversion, cosmwasm_query, path_token)
79 } else {
80 (quote!(), quote!(), quote!())
81 };
82
83 (
84 quote! {
85 impl #ident {
86 pub const TYPE_URL: &'static str = #type_url;
87
88 #path_token
89
90 #cosmwasm_query
91
92 pub fn to_proto_bytes(&self) -> Vec<u8> {
93 let mut bytes = Vec::new();
94 prost::Message::encode(self, &mut bytes)
95 .expect("Message encoding must be infallible");
96 bytes
97 }
98 pub fn to_any(&self) -> crate::shim::Any {
99 crate::shim::Any {
100 type_url: Self::TYPE_URL.to_string(),
101 value: self.to_proto_bytes(),
102 }
103 }
104 }
105
106 #query_request_conversion
107
108 impl From<#ident> for cosmwasm_std::Binary {
109 fn from(msg: #ident) -> Self {
110 cosmwasm_std::Binary::new(msg.to_proto_bytes())
111 }
112 }
113
114 impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
115 fn from(msg: #ident) -> Self {
116 cosmwasm_std::CosmosMsg::<T>::Any(cosmwasm_std::AnyMsg {
117 type_url: #type_url.to_string(),
118 value: msg.into(),
119 })
120 }
121 }
122
123 impl TryFrom<cosmwasm_std::Binary> for #ident {
124 type Error = cosmwasm_std::StdError;
125
126 fn try_from(binary: cosmwasm_std::Binary) -> ::std::result::Result<Self, Self::Error> {
127 use ::prost::Message;
128 Self::decode(&binary[..]).map_err(|e| {
129 cosmwasm_std::StdError::parse_err(
130 stringify!(#ident),
131 format!(
132 "Unable to decode binary: \n - base64: {}\n - bytes array: {:?}\n\n{:?}",
133 binary,
134 binary.to_vec(),
135 e
136 )
137 )
138 })
139 }
140 }
141
142 impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
143 type Error = cosmwasm_std::StdError;
144
145 fn try_from(result: cosmwasm_std::SubMsgResult) -> ::std::result::Result<Self, Self::Error> {
146 result
147 .into_result()
148 .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
149 .data
150 .ok_or_else(|| cosmwasm_std::StdError::not_found("cosmwasm_std::SubMsgResult::<T>"))?
151 .try_into()
152 }
153 }
154 }
155 ).into()
156}
157
158fn get_type_url(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
159 let proto_message = get_attr("proto_message", attrs).and_then(|a| a.parse_meta().ok());
160
161 if let Some(syn::Meta::List(meta)) = proto_message.clone() {
162 match meta.nested[0].clone() {
163 syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => {
164 if meta.path.is_ident("type_url") {
165 match meta.lit {
166 syn::Lit::Str(s) => quote!(#s),
167 _ => proto_message_attr_error(meta.lit),
168 }
169 } else {
170 proto_message_attr_error(meta.path)
171 }
172 }
173 t => proto_message_attr_error(t),
174 }
175 } else {
176 proto_message_attr_error(proto_message)
177 }
178}
179
180fn get_query_attrs<F>(attrs: &[syn::Attribute], f: F) -> proc_macro2::TokenStream
181 where F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>
182{
183 let proto_query = get_attr("proto_query", attrs);
184
185 if let Some(attr) = proto_query {
186 if attr.tokens.clone().into_iter().count() != 1 {
187 return proto_query_attr_error(proto_query);
188 }
189
190 if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
191 let kv_groups = group
192 .stream()
193 .into_iter()
194 .chunk_by(|t| {
195 if let TokenTree::Punct(punct) = t { punct.as_char() != ',' } else { true }
196 });
197 let mut key_values: Vec<Vec<TokenTree>> = vec![];
198
199 for (non_sep, g) in &kv_groups {
200 if non_sep {
201 key_values.push(g.collect());
202 }
203 }
204
205 return key_values
206 .iter()
207 .find_map(f)
208 .unwrap_or_else(|| proto_query_attr_error(proto_query));
209 }
210
211 proto_query_attr_error(proto_query)
212 } else {
213 proto_query_attr_error(proto_query)
214 }
215}
216
217fn get_attr<'a>(attr_ident: &str, attrs: &'a [syn::Attribute]) -> Option<&'a syn::Attribute> {
218 attrs
219 .iter()
220 .find(|&attr| attr.path.segments.len() == 1 && attr.path.segments[0].ident == attr_ident)
221}
222
223fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
224 syn::Error
225 ::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`")
226 .to_compile_error()
227}
228
229fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
230 syn::Error
231 ::new_spanned(tokens, "expected `proto_query(path = \"...\", response_type = ...)`")
232 .to_compile_error()
233}