obce_codegen/
definition.rs1use itertools::Itertools;
23use proc_macro2::TokenStream;
24use quote::quote;
25use syn::{
26 parse2,
27 parse_quote,
28 Error,
29 FnArg,
30 ItemTrait,
31 Lit,
32 Meta,
33 NestedMeta,
34 ReturnType,
35 TraitItem,
36 TraitItemMethod,
37};
38
39use crate::{
40 format_err_spanned,
41 types::AttributeArgs,
42 utils::{
43 into_u16,
44 into_u32,
45 AttributeParser,
46 },
47};
48
49struct TraitAttrs {
50 id: u16,
51}
52
53impl TraitAttrs {
54 fn new<'a, I: IntoIterator<Item = &'a NestedMeta>>(trait_item: &ItemTrait, iter: I) -> Result<Self, Error> {
55 let id = find_id(iter)?.unwrap_or_else(|| into_u16(&trait_item.ident));
56
57 Ok(Self { id })
58 }
59}
60
61struct Method {
62 id: u16,
63 hash: u32,
64 input_tokens: TokenStream,
65 output_tokens: TokenStream,
66}
67
68impl Method {
69 fn new(method_item: &mut TraitItemMethod) -> Result<Self, Error> {
70 if let Some(default) = &method_item.default {
71 return Err(format_err_spanned!(
72 default,
73 "default implementation is not supported in chain extensions"
74 ))
75 }
76
77 let (obce_attrs, other_attrs) = method_item.attrs.iter().cloned().split_attrs()?;
78
79 method_item.attrs = other_attrs;
80
81 let id = find_id(obce_attrs.iter())
82 .transpose()
83 .unwrap_or_else(|| Ok(into_u16(&method_item.sig.ident)))?;
84
85 let hash = into_u32(&method_item.sig.ident);
86
87 let input_tys = method_item.sig.inputs.iter().filter_map(|input| {
88 if let FnArg::Typed(pat) = input {
89 Some(&*pat.ty)
90 } else {
91 None
92 }
93 });
94
95 let output_tokens = if let ReturnType::Type(_, ty) = &method_item.sig.output {
96 quote!(#ty)
97 } else {
98 quote!(())
99 };
100
101 Ok(Self {
102 id,
103 hash,
104 input_tokens: quote! {
105 (#(#input_tys),*)
106 },
107 output_tokens,
108 })
109 }
110
111 fn fill_with_ink_data(&self, trait_attrs: &TraitAttrs, method_item: &mut TraitItemMethod) {
112 let Method {
113 id,
114 input_tokens,
115 output_tokens,
116 ..
117 } = self;
118
119 let input_bound = parse_quote! {
120 #input_tokens: ::scale::Encode
121 };
122
123 let output_bound = parse_quote! {
124 #output_tokens: ::scale::Decode
125 };
126
127 if let Some(where_clause) = &mut method_item.sig.generics.where_clause {
128 where_clause.predicates.push(input_bound);
129 where_clause.predicates.push(output_bound);
130 } else {
131 method_item.sig.generics.where_clause = Some(parse_quote! {
132 where #input_bound, #output_bound
133 });
134 }
135
136 let input_bindings = method_item.sig.inputs.iter().filter_map(|input| {
137 if let FnArg::Typed(pat) = input {
138 Some(&*pat.pat)
139 } else {
140 None
141 }
142 });
143
144 let trait_id = (trait_attrs.id as u32) << 16;
146 let id_for_call = trait_id | (*id as u32);
147
148 method_item.default = Some(parse_quote! {{
149 ::obce::ink_lang::env::chain_extension::ChainExtensionMethod::build(#id_for_call)
150 .input::<#input_tokens>()
151 .output::<#output_tokens, false>()
152 .ignore_error_code()
153 .call(&(#(#input_bindings),*))
154 }});
155 }
156}
157
158pub fn generate(attrs: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
159 let mut trait_item: ItemTrait = parse2(input)?;
160
161 let trait_attrs = TraitAttrs::new(&trait_item, parse2::<AttributeArgs>(attrs)?.iter())?;
162
163 let trait_id = trait_attrs.id;
164 let trait_name = &trait_item.ident;
165
166 let (impls, types, where_clause) = trait_item.generics.split_for_impl();
167
168 let methods: Vec<_> = trait_item
169 .items
170 .iter_mut()
171 .map(|item| {
172 if let TraitItem::Method(method) = item {
173 Method::new(method)
174 } else {
175 Err(format_err_spanned!(
176 item,
177 "only methods are supported in trait definitions"
178 ))
179 }
180 })
181 .try_collect()?;
182
183 if let Some(id) = methods.iter().map(|Method { id, .. }| id).duplicates().next() {
184 return Err(format_err_spanned!(
185 trait_item,
186 "found duplicated method identifier: {id}",
187 ))
188 }
189
190 let method_descriptions = methods.iter().map(
191 |Method {
192 id,
193 hash,
194 input_tokens,
195 output_tokens,
196 ..
197 }| {
198 quote! {
199 impl #impls ::obce::codegen::MethodDescription<#hash> for dyn #trait_name #types #where_clause {
200 const ID: ::core::primitive::u16 = #id;
201 type Input = #input_tokens;
202 type Output = #output_tokens;
203 }
204 }
205 },
206 );
207
208 let mut ink_trait_item = trait_item.clone();
209
210 ink_trait_item
211 .items
212 .iter_mut()
213 .zip(methods.iter())
214 .for_each(|(item, method)| {
215 if let TraitItem::Method(method_item) = item {
216 method.fill_with_ink_data(&trait_attrs, method_item);
217 } else {
218 unreachable!("only methods are present here")
222 }
223 });
224
225 Ok(quote! {
226 impl #impls ::obce::codegen::ExtensionDescription for dyn #trait_name #types #where_clause {
227 const ID: ::core::primitive::u16 = #trait_id;
228 }
229
230 #(#method_descriptions)*
231
232 #[cfg(feature = "substrate")]
233 #trait_item
234
235 #[cfg(feature = "ink")]
236 #ink_trait_item
237 })
238}
239
240fn find_id<'a, I: IntoIterator<Item = &'a NestedMeta>>(iter: I) -> Result<Option<u16>, Error> {
241 iter.into_iter()
242 .find_map(|arg| {
243 match arg {
244 NestedMeta::Meta(Meta::NameValue(value)) if value.path.is_ident("id") => {
245 Some(match &value.lit {
246 Lit::Int(lit_int) => lit_int.base10_parse::<u16>(),
247 Lit::Str(lit_str) => Ok(into_u16(lit_str.value())),
248 _ => Err(format_err_spanned!(value, "id should be integer or string")),
249 })
250 }
251 _ => None,
252 }
253 })
254 .transpose()
255}