persistence_std_derive/
lib.rs1use itertools::Itertools;
2use proc_macro::TokenStream;
3use proc_macro2::TokenTree;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput};
6
/// Builds a closure that recognises a single `key = value` entry of an attribute.
///
/// The produced closure receives the tokens of one comma-separated entry and yields the
/// value's tokens when the entry is exactly three token trees: an identifier equal to `$key`,
/// a `=` punctuation, and a `TokenTree::$value_type` (e.g. `Literal` or `Ident`). Any other
/// shape — including values made of several tokens such as `a::B` — yields `None`.
macro_rules! match_kv_attr {
    ($key:expr, $value_type:tt) => {
        |tokens| match &tokens[..] {
            [TokenTree::Ident(key), TokenTree::Punct(eq), TokenTree::$value_type(value)]
                if key == $key && eq.as_char() == '=' =>
            {
                Some(quote!(#value))
            }
            _ => None,
        }
    };
}
24
25#[proc_macro_derive(CosmwasmExt, attributes(proto_message, proto_query))]
26pub fn derive_cosmwasm_ext(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 let ident = input.ident;
29
30 let type_url = get_type_url(&input.attrs);
31
32 let (query_request_conversion, cosmwasm_query) = if get_attr("proto_query", &input.attrs)
37 .is_some()
38 {
39 let path = get_query_attrs(&input.attrs, match_kv_attr!("path", Literal));
40 let res = get_query_attrs(&input.attrs, match_kv_attr!("response_type", Ident));
41
42 let query_request_conversion = quote! {
43 impl <Q: cosmwasm_std::CustomQuery> From<#ident> for cosmwasm_std::QueryRequest<Q> {
44 fn from(msg: #ident) -> Self {
45 cosmwasm_std::QueryRequest::<Q>::Stargate {
46 path: #path.to_string(),
47 data: msg.into(),
48 }
49 }
50 }
51 };
52
53 let cosmwasm_query = quote! {
54 pub fn query(self, querier: &cosmwasm_std::QuerierWrapper<impl cosmwasm_std::CustomQuery>) -> cosmwasm_std::StdResult<#res> {
55 querier.query::<#res>(&self.into())
56 }
57 };
58
59 (query_request_conversion, cosmwasm_query)
60 } else {
61 (quote!(), quote!())
62 };
63
64 (quote! {
65 impl #ident {
66 pub const TYPE_URL: &'static str = #type_url;
67 #cosmwasm_query
68
69 pub fn to_proto_bytes(&self) -> Vec<u8> {
70 let mut bytes = Vec::new();
71 prost::Message::encode(self, &mut bytes)
72 .expect("Message encoding must be infallible");
73 bytes
74 }
75 pub fn to_any(&self) -> crate::shim::Any {
76 crate::shim::Any {
77 type_url: Self::TYPE_URL.to_string(),
78 value: self.to_proto_bytes(),
79 }
80 }
81 }
82
83 #query_request_conversion
84
85 impl From<#ident> for cosmwasm_std::Binary {
86 fn from(msg: #ident) -> Self {
87 cosmwasm_std::Binary(msg.to_proto_bytes())
88 }
89 }
90
91 impl<T> From<#ident> for cosmwasm_std::CosmosMsg<T> {
92 fn from(msg: #ident) -> Self {
93 cosmwasm_std::CosmosMsg::<T>::Stargate {
94 type_url: #type_url.to_string(),
95 value: msg.into(),
96 }
97 }
98 }
99
100 impl TryFrom<cosmwasm_std::Binary> for #ident {
101 type Error = cosmwasm_std::StdError;
102
103 fn try_from(binary: cosmwasm_std::Binary) -> ::std::result::Result<Self, Self::Error> {
104 use ::prost::Message;
105 Self::decode(&binary[..]).map_err(|e| {
106 cosmwasm_std::StdError::parse_err(
107 stringify!(#ident),
108 format!(
109 "Unable to decode binary: \n - base64: {}\n - bytes array: {:?}\n\n{:?}",
110 binary,
111 binary.to_vec(),
112 e
113 )
114 )
115 })
116 }
117 }
118
119 impl TryFrom<cosmwasm_std::SubMsgResult> for #ident {
120 type Error = cosmwasm_std::StdError;
121
122 fn try_from(result: cosmwasm_std::SubMsgResult) -> ::std::result::Result<Self, Self::Error> {
123 result
124 .into_result()
125 .map_err(|e| cosmwasm_std::StdError::generic_err(e))?
126 .data
127 .ok_or_else(|| cosmwasm_std::StdError::not_found("cosmwasm_std::SubMsgResult::<T>"))?
128 .try_into()
129 }
130 }
131 })
132 .into()
133}
134
135fn get_type_url(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
136 let proto_message = get_attr("proto_message", attrs).and_then(|a| a.parse_meta().ok());
137
138 if let Some(syn::Meta::List(meta)) = proto_message.clone() {
139 match meta.nested[0].clone() {
140 syn::NestedMeta::Meta(syn::Meta::NameValue(meta)) => {
141 if meta.path.is_ident("type_url") {
142 match meta.lit {
143 syn::Lit::Str(s) => quote!(#s),
144 _ => proto_message_attr_error(meta.lit),
145 }
146 } else {
147 proto_message_attr_error(meta.path)
148 }
149 }
150 t => proto_message_attr_error(t),
151 }
152 } else {
153 proto_message_attr_error(proto_message)
154 }
155}
156
157fn get_query_attrs<F>(attrs: &[syn::Attribute], f: F) -> proc_macro2::TokenStream
158where
159 F: FnMut(&Vec<TokenTree>) -> Option<proc_macro2::TokenStream>,
160{
161 let proto_query = get_attr("proto_query", attrs);
162
163 if let Some(attr) = proto_query {
164 if attr.tokens.clone().into_iter().count() != 1 {
165 return proto_query_attr_error(proto_query);
166 }
167
168 if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
169 let kv_groups = group.stream().into_iter().group_by(|t| {
170 if let TokenTree::Punct(punct) = t {
171 punct.as_char() != ','
172 } else {
173 true
174 }
175 });
176 let mut key_values: Vec<Vec<TokenTree>> = vec![];
177
178 for (non_sep, g) in &kv_groups {
179 if non_sep {
180 key_values.push(g.collect());
181 }
182 }
183
184 return key_values
185 .iter()
186 .find_map(f)
187 .unwrap_or_else(|| proto_query_attr_error(proto_query));
188 }
189
190 proto_query_attr_error(proto_query)
191 } else {
192 proto_query_attr_error(proto_query)
193 }
194}
195
196fn get_attr<'a>(attr_ident: &str, attrs: &'a [syn::Attribute]) -> Option<&'a syn::Attribute> {
197 attrs
198 .iter()
199 .find(|&attr| attr.path.segments.len() == 1 && attr.path.segments[0].ident == attr_ident)
200}
201
202fn proto_message_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
203 syn::Error::new_spanned(tokens, "expected `proto_message(type_url = \"...\")`")
204 .to_compile_error()
205}
206
207fn proto_query_attr_error<T: quote::ToTokens>(tokens: T) -> proc_macro2::TokenStream {
208 syn::Error::new_spanned(
209 tokens,
210 "expected `proto_query(path = \"...\", response_type = ...)`",
211 )
212 .to_compile_error()
213}