1use crate::{
23 format_err_spanned,
24 utils::{
25 into_u32,
26 AttributeParser,
27 InputBindings,
28 LitOrPath,
29 MetaUtils,
30 },
31};
32use itertools::Itertools;
33use proc_macro2::{
34 Ident,
35 TokenStream,
36};
37use quote::{
38 format_ident,
39 quote,
40 ToTokens,
41};
42use syn::{
43 parse::Parser,
44 parse2,
45 parse_quote,
46 parse_str,
47 punctuated::Punctuated,
48 Error,
49 Expr,
50 GenericArgument,
51 Generics,
52 ImplItem,
53 ItemImpl,
54 Lit,
55 Meta,
56 NestedMeta,
57 Path,
58 PathArguments,
59 Token,
60 Type,
61};
62use tuple::Map;
63
64pub fn generate(_attrs: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
65 let impl_item: ItemImpl = parse2(input).unwrap();
66
67 let mut original_implementation = impl_item.clone();
68
69 let method_items = original_implementation.items.iter_mut().filter_map(|item| {
70 if let ImplItem::Method(method_item) = item {
71 Some(method_item)
72 } else {
73 None
74 }
75 });
76
77 for method_item in method_items {
78 let (_, other_attrs) = method_item.attrs.iter().cloned().split_attrs()?;
79
80 method_item.attrs = other_attrs;
81 }
82
83 let chain_extension = chain_extension_trait_impl(impl_item)?;
84
85 Ok(quote! {
86 #original_implementation
88
89 #chain_extension
91 })
92}
93
/// Generates the `pallet_contracts` chain extension glue for the trait implementation `impl_item`.
///
/// The self type of `impl_item` must be `ExtensionContext<'a, E, T, Env, Extension>`. For the
/// `Extension` type, the output contains:
/// - an `::obce::substrate::CallableChainExtension<E, T, Env>` impl whose `call` matches
///   `env.func_id()` against each trait method's `MethodDescription::<hash>::ID`, decodes the
///   input, charges weight, invokes the method and writes the SCALE-encoded result back;
/// - a `ChainExtension<T>` impl whose generic `call<E>` forwards to the above with
///   `env.buf_in_buf_out()`;
/// - a `RegisteredChainExtension<T>` impl whose `ID` is the trait's
///   `::obce::codegen::ExtensionDescription::ID`.
///
/// Per method, the attributes classified by `split_attrs` as the macro's own drive weight
/// charging (`handle_weight_attribute`) and `ret_val` handling (`handle_ret_val_attribute`).
///
/// # Errors
///
/// Returns an error if the self type is not a valid `ExtensionContext`, if `impl_item` is not a
/// trait implementation, or if any method's attributes are malformed.
// `T`, `E` and `Env` locals are deliberately named after the generic parameters they hold.
#[allow(non_snake_case)]
fn chain_extension_trait_impl(mut impl_item: ItemImpl) -> Result<TokenStream, Error> {
    let context = ExtensionContext::try_from(&impl_item)?;

    // Path prefix for the `pallet_contracts::chain_extension` items re-exported by obce.
    let namespace = quote! { ::obce::substrate::pallet_contracts::chain_extension:: };

    let T = context.substrate;
    let E = context.env;
    let Env = context.obce_env;
    let extension = context.extension;

    // `CallableChainExtension` impl: everything except the context lifetime `'a`.
    let mut callable_generics = impl_item.generics.clone();
    callable_generics = filter_generics(callable_generics, &context.lifetime1);
    let (callable_impls, _, callable_where) = callable_generics.split_for_impl();

    // `ChainExtension` / `RegisteredChainExtension` impls: `E` becomes a generic of the `call`
    // method and `Env` is inferred there, so both are dropped from the impl-level generics
    // (together with any `where` predicate mentioning them).
    let mut main_generics = impl_item.generics.clone();
    main_generics = filter_generics(main_generics, &context.lifetime1);
    main_generics = filter_generics(main_generics, &E);
    main_generics = filter_generics(main_generics, &Env);
    let (main_impls, _, main_where) = main_generics.split_for_impl();

    // `where` clause for the generic `ChainExtension::call<E>` method: keeps the predicates on
    // `E` and additionally requires `E: Ext<T = T>`.
    let mut call_generics = impl_item.generics.clone();
    call_generics = filter_generics(call_generics, &context.lifetime1);
    call_generics = filter_generics(call_generics, &Env);

    if let Some(where_clause) = &mut call_generics.where_clause {
        where_clause.predicates.push(parse_quote! {
            #E: #namespace Ext<T = #T>
        });
    } else {
        call_generics.where_clause = Some(parse_quote! {
            where #E: #namespace Ext<T = #T>
        });
    }

    let (_, _, call_where) = call_generics.split_for_impl();

    // Only trait implementations are supported: the trait path is needed both for the method
    // calls and for the `dyn Trait` used to look up method/extension IDs.
    let trait_;
    let dyn_trait;
    if let Some((_, path, _)) = impl_item.trait_ {
        trait_ = path.clone();
        dyn_trait = quote! { dyn #path };
    } else {
        return Err(format_err_spanned!(impl_item, "expected impl trait block",))
    }

    // One `match` arm per method, keyed by the ID that `MethodDescription<hash>` assigns to it,
    // where `hash` is derived from the method name by `into_u32`.
    let methods: Vec<_> = impl_item
        .items
        .iter_mut()
        .filter_map(|item| {
            if let ImplItem::Method(method) = item {
                Some(method)
            } else {
                None
            }
        })
        .map(|method| {
            let (obce_attrs, other_attrs) = method.attrs.iter().cloned().split_attrs()?;

            method.attrs = other_attrs;

            let hash = into_u32(&method.sig.ident);
            let method_name = &method.sig.ident;

            let input_bindings = InputBindings::from_iter(&method.sig.inputs);
            let lhs_pat = input_bindings.lhs_pat(None);
            let call_params = input_bindings.iter_call_params();

            let (weight_tokens, pre_charge) = handle_weight_attribute(&input_bindings, obce_attrs.iter())?;
            let ret_val_tokens = handle_ret_val_attribute(obce_attrs.iter());

            // With `pre_charge`, weight is charged before the input is read and the charged
            // amount is handed to `ExtensionContext::new`; otherwise the input is decoded first
            // and the weight is charged afterwards.
            let (read_with_charge, pre_charge_arg) = if pre_charge {
                (
                    quote! {
                        let pre_charged = #weight_tokens;
                        let #lhs_pat = env.read_as_unbounded(len)?;
                    },
                    quote! {
                        Some(pre_charged)
                    },
                )
            } else {
                (
                    quote! {
                        let #lhs_pat = env.read_as_unbounded(len)?;
                        #weight_tokens;
                    },
                    quote! {
                        None
                    },
                )
            };

            Result::<_, Error>::Ok(quote! {
                <#dyn_trait as ::obce::codegen::MethodDescription<#hash>>::ID => {
                    #read_with_charge
                    let mut context = ::obce::substrate::ExtensionContext::new(self, env, #pre_charge_arg);
                    #[allow(clippy::unnecessary_mut_passed)]
                    let result = <_ as #trait_>::#method_name(
                        &mut context
                        #(, #call_params)*
                    );

                    let result = ::obce::to_critical_error!(result)?;
                    #ret_val_tokens
                    <_ as ::scale::Encode>::using_encoded(&result, |w| context.env.write(w, true, None))?;
                },
            })
        })
        .try_collect()?;

    Ok(quote! {
        impl #callable_impls ::obce::substrate::CallableChainExtension<#E, #T, #Env> for #extension
            #callable_where
        {
            fn call(&mut self, mut env: #Env) -> ::core::result::Result<
                #namespace RetVal,
                ::obce::substrate::CriticalError
            > {
                let len = env.in_len();

                match env.func_id() {
                    #(#methods)*
                    _ => ::core::result::Result::Err(::obce::substrate::CriticalError::Other(
                        "InvalidFunctionId"
                    ))?,
                };

                Ok(#namespace RetVal::Converging(0))
            }
        }

        impl #main_impls #namespace ChainExtension<#T> for #extension #main_where {
            fn call<#E>(&mut self, env: #namespace Environment<#E, #namespace InitState>)
                -> ::core::result::Result<#namespace RetVal, ::obce::substrate::CriticalError>
                #call_where
            {
                <#extension as ::obce::substrate::CallableChainExtension<#E, #T, _>>::call(
                    self, env.buf_in_buf_out()
                )
            }
        }

        impl #main_impls #namespace RegisteredChainExtension<#T> for #extension #main_where {
            const ID: ::core::primitive::u16 = <#dyn_trait as ::obce::codegen::ExtensionDescription>::ID;
        }
    })
}
246
/// The generic arguments of an `ExtensionContext<'a, E, T, Env, Extension>` self type, stored in
/// declaration order.
struct ExtensionContext {
    /// First argument (`'a`); removed from the generics of every generated impl.
    lifetime1: GenericArgument,
    /// Second argument (`E`); bound by `E: Ext<T = T>` on the generated `ChainExtension::call`.
    env: GenericArgument,
    /// Third argument (`T`); the type parameter of the generated `ChainExtension<T>` impls.
    substrate: GenericArgument,
    /// Fourth argument (`Env`); the environment type of `CallableChainExtension<E, T, Env>`.
    obce_env: GenericArgument,
    /// Fifth argument (`Extension`); the type the chain extension impls are generated for.
    extension: GenericArgument,
}
259
260impl TryFrom<&ItemImpl> for ExtensionContext {
261 type Error = Error;
262
263 fn try_from(impl_item: &ItemImpl) -> Result<Self, Self::Error> {
264 let Type::Path(path) = impl_item.self_ty.as_ref() else {
265 return Err(format_err_spanned!(
266 impl_item,
267 "the type should be `ExtensionContext`"
268 ));
269 };
270
271 let Some(extension) = path.path.segments.last() else {
272 return Err(format_err_spanned!(
273 path,
274 "the type should be `ExtensionContext`"
275 ));
276 };
277
278 let PathArguments::AngleBracketed(generic_args) = &extension.arguments else {
279 return Err(format_err_spanned!(
280 path,
281 "`ExtensionContext` should have 5 generics as `<'a, E, T, Env, Extension>`"
282 ));
283 };
284
285 let (lifetime1, env, substrate, obce_env, extension) =
286 generic_args.args.iter().cloned().tuples().exactly_one().map_err(|_| {
287 format_err_spanned!(
288 generic_args,
289 "`ExtensionContext` should have 5 generics as `<'a, E, T, Env, Extension>`"
290 )
291 })?;
292
293 Ok(ExtensionContext {
294 lifetime1,
295 env,
296 substrate,
297 obce_env,
298 extension,
299 })
300 }
301}
302
303fn filter_generics(mut generics: Generics, filter: &GenericArgument) -> Generics {
304 let filter: Vec<_> = filter
305 .to_token_stream()
306 .into_iter()
307 .map(|token| token.to_string())
308 .collect();
309 generics.params = generics
310 .params
311 .clone()
312 .into_iter()
313 .filter(|param| {
314 let param: Vec<_> = param
315 .to_token_stream()
316 .into_iter()
317 .map(|token| token.to_string())
318 .collect();
319 !is_subsequence(¶m, &filter)
320 })
321 .collect();
322
323 if let Some(where_clause) = &mut generics.where_clause {
324 where_clause.predicates = where_clause
325 .predicates
326 .clone()
327 .into_iter()
328 .filter(|predicate| {
329 let predicate: Vec<_> = predicate
330 .to_token_stream()
331 .into_iter()
332 .map(|token| token.to_string())
333 .collect();
334 !is_subsequence(&predicate, &filter)
335 })
336 .collect();
337 }
338
339 generics
340}
341
/// Reports whether `search` appears as a contiguous run of elements inside `src`.
///
/// An empty `search` matches every `src`, including an empty one. A `search` longer than `src`
/// never matches.
fn is_subsequence<T: PartialEq + core::fmt::Debug>(src: &[T], search: &[T]) -> bool {
    // `windows(0)` panics, so the empty needle is answered up front.
    if search.is_empty() {
        return true
    }

    src.windows(search.len()).any(|window| window == search)
}
354
355fn handle_ret_val_attribute<'a, I: IntoIterator<Item = &'a NestedMeta>>(iter: I) -> Option<TokenStream> {
356 let should_handle = iter.into_iter().any(|attr| {
357 if let NestedMeta::Meta(Meta::Path(path)) = attr {
358 if let Some(ident) = path.get_ident() {
359 return ident == "ret_val"
360 }
361 }
362
363 false
364 });
365
366 should_handle.then(|| {
367 quote! {
368 if let Err(error) = result {
369 if let Ok(ret_val) = error.try_into() {
370 return Ok(ret_val)
371 }
372 }
373 }
374 })
375}
376
377fn handle_weight_attribute<'a, I: IntoIterator<Item = &'a NestedMeta>>(
378 input_bindings: &InputBindings,
379 iter: I,
380) -> Result<(Option<TokenStream>, bool), Error> {
381 let weight_params = iter.into_iter().find_map(|attr| {
382 let NestedMeta::Meta(Meta::List(list)) = attr else {
383 return None;
384 };
385
386 let Some(ident) = list.path.get_ident() else {
387 return None
388 };
389
390 (ident == "weight").then_some((&list.nested, ident))
391 });
392
393 if let Some((weight_params, weight_ident)) = weight_params {
394 match weight_params.iter().find_by_name("dispatch") {
395 Some((LitOrPath::Lit(Lit::Str(dispatch_path)), ident)) => {
396 let args = match weight_params.iter().find_by_name("args") {
397 Some((LitOrPath::Lit(Lit::Str(args)), _)) => Some(args.value()),
398 None => None,
399 Some((_, ident)) => {
400 return Err(format_err_spanned!(
401 ident,
402 "`args` attribute should contain a comma-separated expression list"
403 ))
404 }
405 };
406
407 return Ok((
408 Some(handle_dispatch_weight(
409 ident,
410 input_bindings,
411 &dispatch_path.value(),
412 args.as_deref(),
413 )?),
414 false,
415 ))
416 }
417 Some((_, ident)) => {
418 return Err(format_err_spanned!(
419 ident,
420 "`dispatch` attribute should contain a pallet method path"
421 ))
422 }
423 None => {}
424 };
425
426 match weight_params.iter().find_by_name("expr") {
427 Some((LitOrPath::Lit(Lit::Str(expr)), _)) => {
428 let pre_charge = matches!(
429 weight_params.iter().find_by_name("pre_charge"),
430 Some((LitOrPath::Path, _))
431 );
432
433 return Ok((
434 Some(handle_expr_weight(input_bindings, &expr.value(), pre_charge)?),
435 pre_charge,
436 ))
437 }
438 Some((_, ident)) => {
439 return Err(format_err_spanned!(
440 ident,
441 "`expr` attribute should contain an expression that returns `Weight`"
442 ))
443 }
444 None => {}
445 }
446
447 Err(format_err_spanned!(
448 weight_ident,
449 r#"either "dispatch" or "expr" attributes are expected"#
450 ))
451 } else {
452 Ok((None, false))
453 }
454}
455
456fn handle_expr_weight(input_bindings: &InputBindings, expr: &str, pre_charge: bool) -> Result<TokenStream, Error> {
457 let expr = parse_str::<Expr>(expr)?;
458
459 let raw_map = if pre_charge {
460 quote! {}
461 } else {
462 input_bindings.raw_special_mapping()
463 };
464
465 Ok(quote! {{
466 #[allow(unused_variables)]
467 #raw_map
468 env.charge_weight(#expr)?
469 }})
470}
471
472fn handle_dispatch_weight(
473 ident: &Ident,
474 input_bindings: &InputBindings,
475 dispatch_path: &str,
476 args: Option<&str>,
477) -> Result<TokenStream, Error> {
478 let segments = parse_str::<Path>(dispatch_path)?.segments.into_iter();
479 let segments_len = segments.len();
480
481 if segments_len < 3 {
482 return Err(format_err_spanned!(
483 ident,
484 "dispatch path should contain at least three segments"
485 ))
486 }
487
488 let (pallet_ns, _, method_name) = segments
489 .enumerate()
490 .group_by(|(idx, _)| if *idx < segments_len - 2 { 0 } else { *idx })
491 .into_iter()
492 .map(|(_, group)| group.map(|(_, segment)| segment))
493 .next_tuple::<(_, _, _)>()
494 .unwrap()
495 .map(Punctuated::<_, Token![::]>::from_iter);
496
497 let dispatch_args = if let Some(args) = args {
498 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
499 parser.parse_str(args)?.to_token_stream()
500 } else {
501 let raw_call_params = input_bindings.iter_raw_call_params();
502
503 quote! {
505 #(*#raw_call_params,)*
506 }
507 };
508
509 let call_variant_name = format_ident!("new_call_variant_{}", method_name.last().unwrap().ident);
510
511 let raw_map = input_bindings.raw_special_mapping();
512
513 Ok(quote! {{
514 #[allow(unused_variables)]
515 #raw_map
516 let __call_variant = &#pallet_ns ::Call::<T>::#call_variant_name(#dispatch_args);
517 let __dispatch_info = <#pallet_ns ::Call<T> as ::obce::substrate::frame_support::dispatch::GetDispatchInfo>::get_dispatch_info(__call_variant);
518 env.charge_weight(__dispatch_info.weight)?
519 }})
520}