obce_codegen/
implementation.rs

1// Copyright (c) 2012-2022 Supercolony
2//
3// Permission is hereby granted, free of charge, to any person obtaining
4// a copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including
6// without limitation the rights to use, copy, modify, merge, publish,
7// distribute, sublicense, and/or sell copies of the Software, and to
8// permit persons to whom the Software is furnished to do so, subject to
9// the following conditions:
10//
11// The above copyright notice and this permission notice shall be
12// included in all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
22use crate::{
23    format_err_spanned,
24    utils::{
25        into_u32,
26        AttributeParser,
27        InputBindings,
28        LitOrPath,
29        MetaUtils,
30    },
31};
32use itertools::Itertools;
33use proc_macro2::{
34    Ident,
35    TokenStream,
36};
37use quote::{
38    format_ident,
39    quote,
40    ToTokens,
41};
42use syn::{
43    parse::Parser,
44    parse2,
45    parse_quote,
46    parse_str,
47    punctuated::Punctuated,
48    Error,
49    Expr,
50    GenericArgument,
51    Generics,
52    ImplItem,
53    ItemImpl,
54    Lit,
55    Meta,
56    NestedMeta,
57    Path,
58    PathArguments,
59    Token,
60    Type,
61};
62use tuple::Map;
63
64pub fn generate(_attrs: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
65    let impl_item: ItemImpl = parse2(input).unwrap();
66
67    let mut original_implementation = impl_item.clone();
68
69    let method_items = original_implementation.items.iter_mut().filter_map(|item| {
70        if let ImplItem::Method(method_item) = item {
71            Some(method_item)
72        } else {
73            None
74        }
75    });
76
77    for method_item in method_items {
78        let (_, other_attrs) = method_item.attrs.iter().cloned().split_attrs()?;
79
80        method_item.attrs = other_attrs;
81    }
82
83    let chain_extension = chain_extension_trait_impl(impl_item)?;
84
85    Ok(quote! {
86        // Implementation of the trait for `ExtensionContext` with main logic.
87        #original_implementation
88
89        // Implementation of `ChainExtension` from `contract-pallet`
90        #chain_extension
91    })
92}
93
94#[allow(non_snake_case)]
95fn chain_extension_trait_impl(mut impl_item: ItemImpl) -> Result<TokenStream, Error> {
96    let context = ExtensionContext::try_from(&impl_item)?;
97
98    let namespace = quote! { ::obce::substrate::pallet_contracts::chain_extension:: };
99
100    let T = context.substrate;
101    let E = context.env;
102    let Env = context.obce_env;
103    let extension = context.extension;
104
105    let mut callable_generics = impl_item.generics.clone();
106    callable_generics = filter_generics(callable_generics, &context.lifetime1);
107    let (callable_impls, _, callable_where) = callable_generics.split_for_impl();
108
109    let mut main_generics = impl_item.generics.clone();
110    main_generics = filter_generics(main_generics, &context.lifetime1);
111    main_generics = filter_generics(main_generics, &E);
112    main_generics = filter_generics(main_generics, &Env);
113    let (main_impls, _, main_where) = main_generics.split_for_impl();
114
115    let mut call_generics = impl_item.generics.clone();
116    call_generics = filter_generics(call_generics, &context.lifetime1);
117    call_generics = filter_generics(call_generics, &Env);
118
119    // User is not required to use `Ext` trait for testing, so we automatically
120    // add `Ext` bound when generating "production" code.
121    if let Some(where_clause) = &mut call_generics.where_clause {
122        where_clause.predicates.push(parse_quote! {
123            #E: #namespace Ext<T = #T>
124        });
125    } else {
126        call_generics.where_clause = Some(parse_quote! {
127            where #E: #namespace Ext<T = #T>
128        });
129    }
130
131    let (_, _, call_where) = call_generics.split_for_impl();
132
133    let trait_;
134    let dyn_trait;
135    if let Some((_, path, _)) = impl_item.trait_ {
136        trait_ = path.clone();
137        dyn_trait = quote! { dyn #path };
138    } else {
139        return Err(format_err_spanned!(impl_item, "expected impl trait block",))
140    }
141
142    let methods: Vec<_> = impl_item
143        .items
144        .iter_mut()
145        .filter_map(|item| {
146            if let ImplItem::Method(method) = item {
147                Some(method)
148            } else {
149                None
150            }
151        })
152        .map(|method| {
153            let (obce_attrs, other_attrs) = method.attrs.iter().cloned().split_attrs()?;
154
155            method.attrs = other_attrs;
156
157            let hash = into_u32(&method.sig.ident);
158            let method_name = &method.sig.ident;
159
160            let input_bindings = InputBindings::from_iter(&method.sig.inputs);
161            let lhs_pat = input_bindings.lhs_pat(None);
162            let call_params = input_bindings.iter_call_params();
163
164            let (weight_tokens, pre_charge) = handle_weight_attribute(&input_bindings, obce_attrs.iter())?;
165            let ret_val_tokens = handle_ret_val_attribute(obce_attrs.iter());
166
167            let (read_with_charge, pre_charge_arg) = if pre_charge {
168                (
169                    quote! {
170                        let pre_charged = #weight_tokens;
171                        let #lhs_pat = env.read_as_unbounded(len)?;
172                    },
173                    quote! {
174                        Some(pre_charged)
175                    },
176                )
177            } else {
178                (
179                    quote! {
180                        let #lhs_pat = env.read_as_unbounded(len)?;
181                        #weight_tokens;
182                    },
183                    quote! {
184                        None
185                    },
186                )
187            };
188
189            Result::<_, Error>::Ok(quote! {
190                <#dyn_trait as ::obce::codegen::MethodDescription<#hash>>::ID => {
191                    #read_with_charge
192                    let mut context = ::obce::substrate::ExtensionContext::new(self, env, #pre_charge_arg);
193                    #[allow(clippy::unnecessary_mut_passed)]
194                    let result = <_ as #trait_>::#method_name(
195                        &mut context
196                        #(, #call_params)*
197                    );
198
199                    // If result is `Result` and `Err` is critical, return from the `call`.
200                    // Otherwise, try to convert result to RetVal, and return it or encode the result into the buffer.
201                    let result = ::obce::to_critical_error!(result)?;
202                    #ret_val_tokens
203                    <_ as ::scale::Encode>::using_encoded(&result, |w| context.env.write(w, true, None))?;
204                },
205            })
206        })
207        .try_collect()?;
208
209    Ok(quote! {
210        impl #callable_impls ::obce::substrate::CallableChainExtension<#E, #T, #Env> for #extension
211            #callable_where
212        {
213            fn call(&mut self, mut env: #Env) -> ::core::result::Result<
214                #namespace RetVal,
215                ::obce::substrate::CriticalError
216            > {
217                let len = env.in_len();
218
219                match env.func_id() {
220                    #(#methods)*
221                    _ => ::core::result::Result::Err(::obce::substrate::CriticalError::Other(
222                        "InvalidFunctionId"
223                    ))?,
224                };
225
226                Ok(#namespace RetVal::Converging(0))
227            }
228        }
229
230        impl #main_impls #namespace ChainExtension<#T> for #extension #main_where {
231            fn call<#E>(&mut self, env: #namespace Environment<#E, #namespace InitState>)
232                -> ::core::result::Result<#namespace RetVal, ::obce::substrate::CriticalError>
233                #call_where
234            {
235                <#extension as ::obce::substrate::CallableChainExtension<#E, #T, _>>::call(
236                    self, env.buf_in_buf_out()
237                )
238            }
239        }
240
241        impl #main_impls #namespace RegisteredChainExtension<#T> for #extension #main_where {
242            const ID: ::core::primitive::u16 = <#dyn_trait as ::obce::codegen::ExtensionDescription>::ID;
243        }
244    })
245}
246
247struct ExtensionContext {
248    // Lifetime `'a`
249    lifetime1: GenericArgument,
250    // Generic `E`
251    env: GenericArgument,
252    // Generic `T`
253    substrate: GenericArgument,
254    // Generic `Env`
255    obce_env: GenericArgument,
256    // Generic `Extension`
257    extension: GenericArgument,
258}
259
260impl TryFrom<&ItemImpl> for ExtensionContext {
261    type Error = Error;
262
263    fn try_from(impl_item: &ItemImpl) -> Result<Self, Self::Error> {
264        let Type::Path(path) = impl_item.self_ty.as_ref() else {
265            return Err(format_err_spanned!(
266                impl_item,
267                "the type should be `ExtensionContext`"
268            ));
269        };
270
271        let Some(extension) = path.path.segments.last() else {
272            return Err(format_err_spanned!(
273                path,
274                "the type should be `ExtensionContext`"
275            ));
276        };
277
278        let PathArguments::AngleBracketed(generic_args) = &extension.arguments else {
279            return Err(format_err_spanned!(
280                path,
281                "`ExtensionContext` should have 5 generics as `<'a, E, T, Env, Extension>`"
282            ));
283        };
284
285        let (lifetime1, env, substrate, obce_env, extension) =
286            generic_args.args.iter().cloned().tuples().exactly_one().map_err(|_| {
287                format_err_spanned!(
288                    generic_args,
289                    "`ExtensionContext` should have 5 generics as `<'a, E, T, Env, Extension>`"
290                )
291            })?;
292
293        Ok(ExtensionContext {
294            lifetime1,
295            env,
296            substrate,
297            obce_env,
298            extension,
299        })
300    }
301}
302
303fn filter_generics(mut generics: Generics, filter: &GenericArgument) -> Generics {
304    let filter: Vec<_> = filter
305        .to_token_stream()
306        .into_iter()
307        .map(|token| token.to_string())
308        .collect();
309    generics.params = generics
310        .params
311        .clone()
312        .into_iter()
313        .filter(|param| {
314            let param: Vec<_> = param
315                .to_token_stream()
316                .into_iter()
317                .map(|token| token.to_string())
318                .collect();
319            !is_subsequence(&param, &filter)
320        })
321        .collect();
322
323    if let Some(where_clause) = &mut generics.where_clause {
324        where_clause.predicates = where_clause
325            .predicates
326            .clone()
327            .into_iter()
328            .filter(|predicate| {
329                let predicate: Vec<_> = predicate
330                    .to_token_stream()
331                    .into_iter()
332                    .map(|token| token.to_string())
333                    .collect();
334                !is_subsequence(&predicate, &filter)
335            })
336            .collect();
337    }
338
339    generics
340}
341
/// Returns `true` when `search` occurs in `src` as a contiguous run of elements.
///
/// Despite its name, this is a substring (window) check rather than a general
/// subsequence check. For example, `[2, 3]` is found in `[1, 2, 3]` but `[1, 3]` is not.
/// An empty `search` matches every `src`, including an empty one.
fn is_subsequence<T: PartialEq>(src: &[T], search: &[T]) -> bool {
    // `slice::windows(0)` panics, so the empty needle is answered up front.
    search.is_empty() || src.windows(search.len()).any(|window| window == search)
}
354
355fn handle_ret_val_attribute<'a, I: IntoIterator<Item = &'a NestedMeta>>(iter: I) -> Option<TokenStream> {
356    let should_handle = iter.into_iter().any(|attr| {
357        if let NestedMeta::Meta(Meta::Path(path)) = attr {
358            if let Some(ident) = path.get_ident() {
359                return ident == "ret_val"
360            }
361        }
362
363        false
364    });
365
366    should_handle.then(|| {
367        quote! {
368            if let Err(error) = result {
369                if let Ok(ret_val) = error.try_into() {
370                    return Ok(ret_val)
371                }
372            }
373        }
374    })
375}
376
377fn handle_weight_attribute<'a, I: IntoIterator<Item = &'a NestedMeta>>(
378    input_bindings: &InputBindings,
379    iter: I,
380) -> Result<(Option<TokenStream>, bool), Error> {
381    let weight_params = iter.into_iter().find_map(|attr| {
382        let NestedMeta::Meta(Meta::List(list)) = attr else {
383            return None;
384        };
385
386        let Some(ident) = list.path.get_ident() else {
387            return None
388        };
389
390        (ident == "weight").then_some((&list.nested, ident))
391    });
392
393    if let Some((weight_params, weight_ident)) = weight_params {
394        match weight_params.iter().find_by_name("dispatch") {
395            Some((LitOrPath::Lit(Lit::Str(dispatch_path)), ident)) => {
396                let args = match weight_params.iter().find_by_name("args") {
397                    Some((LitOrPath::Lit(Lit::Str(args)), _)) => Some(args.value()),
398                    None => None,
399                    Some((_, ident)) => {
400                        return Err(format_err_spanned!(
401                            ident,
402                            "`args` attribute should contain a comma-separated expression list"
403                        ))
404                    }
405                };
406
407                return Ok((
408                    Some(handle_dispatch_weight(
409                        ident,
410                        input_bindings,
411                        &dispatch_path.value(),
412                        args.as_deref(),
413                    )?),
414                    false,
415                ))
416            }
417            Some((_, ident)) => {
418                return Err(format_err_spanned!(
419                    ident,
420                    "`dispatch` attribute should contain a pallet method path"
421                ))
422            }
423            None => {}
424        };
425
426        match weight_params.iter().find_by_name("expr") {
427            Some((LitOrPath::Lit(Lit::Str(expr)), _)) => {
428                let pre_charge = matches!(
429                    weight_params.iter().find_by_name("pre_charge"),
430                    Some((LitOrPath::Path, _))
431                );
432
433                return Ok((
434                    Some(handle_expr_weight(input_bindings, &expr.value(), pre_charge)?),
435                    pre_charge,
436                ))
437            }
438            Some((_, ident)) => {
439                return Err(format_err_spanned!(
440                    ident,
441                    "`expr` attribute should contain an expression that returns `Weight`"
442                ))
443            }
444            None => {}
445        }
446
447        Err(format_err_spanned!(
448            weight_ident,
449            r#"either "dispatch" or "expr" attributes are expected"#
450        ))
451    } else {
452        Ok((None, false))
453    }
454}
455
456fn handle_expr_weight(input_bindings: &InputBindings, expr: &str, pre_charge: bool) -> Result<TokenStream, Error> {
457    let expr = parse_str::<Expr>(expr)?;
458
459    let raw_map = if pre_charge {
460        quote! {}
461    } else {
462        input_bindings.raw_special_mapping()
463    };
464
465    Ok(quote! {{
466        #[allow(unused_variables)]
467        #raw_map
468        env.charge_weight(#expr)?
469    }})
470}
471
472fn handle_dispatch_weight(
473    ident: &Ident,
474    input_bindings: &InputBindings,
475    dispatch_path: &str,
476    args: Option<&str>,
477) -> Result<TokenStream, Error> {
478    let segments = parse_str::<Path>(dispatch_path)?.segments.into_iter();
479    let segments_len = segments.len();
480
481    if segments_len < 3 {
482        return Err(format_err_spanned!(
483            ident,
484            "dispatch path should contain at least three segments"
485        ))
486    }
487
488    let (pallet_ns, _, method_name) = segments
489        .enumerate()
490        .group_by(|(idx, _)| if *idx < segments_len - 2 { 0 } else { *idx })
491        .into_iter()
492        .map(|(_, group)| group.map(|(_, segment)| segment))
493        .next_tuple::<(_, _, _)>()
494        .unwrap()
495        .map(Punctuated::<_, Token![::]>::from_iter);
496
497    let dispatch_args = if let Some(args) = args {
498        let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
499        parser.parse_str(args)?.to_token_stream()
500    } else {
501        let raw_call_params = input_bindings.iter_raw_call_params();
502
503        // If no args were provided try to call the pallet method using default outer args.
504        quote! {
505            #(*#raw_call_params,)*
506        }
507    };
508
509    let call_variant_name = format_ident!("new_call_variant_{}", method_name.last().unwrap().ident);
510
511    let raw_map = input_bindings.raw_special_mapping();
512
513    Ok(quote! {{
514        #[allow(unused_variables)]
515        #raw_map
516        let __call_variant = &#pallet_ns ::Call::<T>::#call_variant_name(#dispatch_args);
517        let __dispatch_info = <#pallet_ns ::Call<T> as ::obce::substrate::frame_support::dispatch::GetDispatchInfo>::get_dispatch_info(__call_variant);
518        env.charge_weight(__dispatch_info.weight)?
519    }})
520}