obce_codegen/
definition.rs

1// Copyright (c) 2012-2022 Supercolony
2//
3// Permission is hereby granted, free of charge, to any person obtaining
4// a copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including
6// without limitation the rights to use, copy, modify, merge, publish,
7// distribute, sublicense, and/or sell copies of the Software, and to
8// permit persons to whom the Software is furnished to do so, subject to
9// the following conditions:
10//
11// The above copyright notice and this permission notice shall be
12// included in all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
22use itertools::Itertools;
23use proc_macro2::TokenStream;
24use quote::quote;
25use syn::{
26    parse2,
27    parse_quote,
28    Error,
29    FnArg,
30    ItemTrait,
31    Lit,
32    Meta,
33    NestedMeta,
34    ReturnType,
35    TraitItem,
36    TraitItemMethod,
37};
38
39use crate::{
40    format_err_spanned,
41    types::AttributeArgs,
42    utils::{
43        into_u16,
44        into_u32,
45        AttributeParser,
46    },
47};
48
49struct TraitAttrs {
50    id: u16,
51}
52
53impl TraitAttrs {
54    fn new<'a, I: IntoIterator<Item = &'a NestedMeta>>(trait_item: &ItemTrait, iter: I) -> Result<Self, Error> {
55        let id = find_id(iter)?.unwrap_or_else(|| into_u16(&trait_item.ident));
56
57        Ok(Self { id })
58    }
59}
60
61struct Method {
62    id: u16,
63    hash: u32,
64    input_tokens: TokenStream,
65    output_tokens: TokenStream,
66}
67
68impl Method {
69    fn new(method_item: &mut TraitItemMethod) -> Result<Self, Error> {
70        if let Some(default) = &method_item.default {
71            return Err(format_err_spanned!(
72                default,
73                "default implementation is not supported in chain extensions"
74            ))
75        }
76
77        let (obce_attrs, other_attrs) = method_item.attrs.iter().cloned().split_attrs()?;
78
79        method_item.attrs = other_attrs;
80
81        let id = find_id(obce_attrs.iter())
82            .transpose()
83            .unwrap_or_else(|| Ok(into_u16(&method_item.sig.ident)))?;
84
85        let hash = into_u32(&method_item.sig.ident);
86
87        let input_tys = method_item.sig.inputs.iter().filter_map(|input| {
88            if let FnArg::Typed(pat) = input {
89                Some(&*pat.ty)
90            } else {
91                None
92            }
93        });
94
95        let output_tokens = if let ReturnType::Type(_, ty) = &method_item.sig.output {
96            quote!(#ty)
97        } else {
98            quote!(())
99        };
100
101        Ok(Self {
102            id,
103            hash,
104            input_tokens: quote! {
105                (#(#input_tys),*)
106            },
107            output_tokens,
108        })
109    }
110
111    fn fill_with_ink_data(&self, trait_attrs: &TraitAttrs, method_item: &mut TraitItemMethod) {
112        let Method {
113            id,
114            input_tokens,
115            output_tokens,
116            ..
117        } = self;
118
119        let input_bound = parse_quote! {
120            #input_tokens: ::scale::Encode
121        };
122
123        let output_bound = parse_quote! {
124            #output_tokens: ::scale::Decode
125        };
126
127        if let Some(where_clause) = &mut method_item.sig.generics.where_clause {
128            where_clause.predicates.push(input_bound);
129            where_clause.predicates.push(output_bound);
130        } else {
131            method_item.sig.generics.where_clause = Some(parse_quote! {
132                where #input_bound, #output_bound
133            });
134        }
135
136        let input_bindings = method_item.sig.inputs.iter().filter_map(|input| {
137            if let FnArg::Typed(pat) = input {
138                Some(&*pat.pat)
139            } else {
140                None
141            }
142        });
143
144        // https://paritytech.github.io/substrate/master/pallet_contracts/chain_extension/trait.RegisteredChainExtension.html
145        let trait_id = (trait_attrs.id as u32) << 16;
146        let id_for_call = trait_id | (*id as u32);
147
148        method_item.default = Some(parse_quote! {{
149            ::obce::ink_lang::env::chain_extension::ChainExtensionMethod::build(#id_for_call)
150                .input::<#input_tokens>()
151                .output::<#output_tokens, false>()
152                .ignore_error_code()
153                .call(&(#(#input_bindings),*))
154        }});
155    }
156}
157
158pub fn generate(attrs: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
159    let mut trait_item: ItemTrait = parse2(input)?;
160
161    let trait_attrs = TraitAttrs::new(&trait_item, parse2::<AttributeArgs>(attrs)?.iter())?;
162
163    let trait_id = trait_attrs.id;
164    let trait_name = &trait_item.ident;
165
166    let (impls, types, where_clause) = trait_item.generics.split_for_impl();
167
168    let methods: Vec<_> = trait_item
169        .items
170        .iter_mut()
171        .map(|item| {
172            if let TraitItem::Method(method) = item {
173                Method::new(method)
174            } else {
175                Err(format_err_spanned!(
176                    item,
177                    "only methods are supported in trait definitions"
178                ))
179            }
180        })
181        .try_collect()?;
182
183    if let Some(id) = methods.iter().map(|Method { id, .. }| id).duplicates().next() {
184        return Err(format_err_spanned!(
185            trait_item,
186            "found duplicated method identifier: {id}",
187        ))
188    }
189
190    let method_descriptions = methods.iter().map(
191        |Method {
192             id,
193             hash,
194             input_tokens,
195             output_tokens,
196             ..
197         }| {
198            quote! {
199                impl #impls ::obce::codegen::MethodDescription<#hash> for dyn #trait_name #types #where_clause {
200                    const ID: ::core::primitive::u16 = #id;
201                    type Input = #input_tokens;
202                    type Output = #output_tokens;
203                }
204            }
205        },
206    );
207
208    let mut ink_trait_item = trait_item.clone();
209
210    ink_trait_item
211        .items
212        .iter_mut()
213        .zip(methods.iter())
214        .for_each(|(item, method)| {
215            if let TraitItem::Method(method_item) = item {
216                method.fill_with_ink_data(&trait_attrs, method_item);
217            } else {
218                // This branch is unreachable, because `ink_trait_item`
219                // is cloned from the `trait_item`, items of which are verified
220                // to be method above.
221                unreachable!("only methods are present here")
222            }
223        });
224
225    Ok(quote! {
226        impl #impls ::obce::codegen::ExtensionDescription for dyn #trait_name #types #where_clause {
227            const ID: ::core::primitive::u16 = #trait_id;
228        }
229
230        #(#method_descriptions)*
231
232        #[cfg(feature = "substrate")]
233        #trait_item
234
235        #[cfg(feature = "ink")]
236        #ink_trait_item
237    })
238}
239
240fn find_id<'a, I: IntoIterator<Item = &'a NestedMeta>>(iter: I) -> Result<Option<u16>, Error> {
241    iter.into_iter()
242        .find_map(|arg| {
243            match arg {
244                NestedMeta::Meta(Meta::NameValue(value)) if value.path.is_ident("id") => {
245                    Some(match &value.lit {
246                        Lit::Int(lit_int) => lit_int.base10_parse::<u16>(),
247                        Lit::Str(lit_str) => Ok(into_u16(lit_str.value())),
248                        _ => Err(format_err_spanned!(value, "id should be integer or string")),
249                    })
250                }
251                _ => None,
252            }
253        })
254        .transpose()
255}