// obce_codegen/mock.rs

1use proc_macro2::TokenStream;
2use quote::{
3    format_ident,
4    quote,
5};
6use syn::{
7    parse2,
8    parse_quote,
9    Error,
10    ImplItem,
11    ItemImpl,
12    ItemTrait,
13    TraitItem,
14    TraitItemMethod,
15};
16
17use crate::{
18    format_err_spanned,
19    utils::{
20        into_u32,
21        InputBindings,
22    },
23};
24
25pub fn generate(_: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
26    let mut impl_item: ItemImpl = parse2(input)?;
27
28    let Some((_, trait_name, _)) = impl_item.trait_ else {
29        return Err(format_err_spanned!(
30            impl_item,
31            "impl marked as mocked should have a trait present"
32        ));
33    };
34    let item = impl_item.self_ty;
35
36    let (impls, types, where_clause) = impl_item.generics.split_for_impl();
37
38    // We assume that every single item is a method.
39    let methods = impl_item
40        .items
41        .iter_mut()
42        .filter_map(|item| {
43            if let ImplItem::Method(method_item) = item {
44                Some(method_item)
45            } else {
46                None
47            }
48        })
49        .collect::<Vec<_>>();
50
51    let mut mock_trait: ItemTrait = parse_quote! {
52        trait MockTrait {}
53    };
54
55    mock_trait.generics = impl_item.generics.clone();
56    mock_trait.items = methods
57        .iter()
58        .map(|method| (**method).clone())
59        .map(|val| {
60            TraitItem::Method(TraitItemMethod {
61                attrs: val.attrs,
62                sig: val.sig,
63                default: None,
64                semi_token: None,
65            })
66        })
67        .collect();
68
69    let mut mock_impl: ItemImpl = parse_quote! {
70        impl MockTrait #types for #item {}
71    };
72
73    mock_impl.generics = impl_item.generics.clone();
74    mock_impl.items = methods
75        .iter()
76        .map(|method| (**method).clone())
77        .map(ImplItem::Method)
78        .collect();
79
80    let proxies = methods.iter()
81        .map(|method| {
82            let hash = into_u32(&method.sig.ident);
83
84            let method_name = &method.sig.ident;
85            let proxy_name = format_ident!("ProxyFor{}", hash);
86            let proxy_where_clause = if let Some(mut where_clause) = where_clause.cloned() {
87                where_clause.predicates.push(parse_quote! {
88                    dyn #trait_name: ::obce::codegen::ExtensionDescription,
89                });
90                where_clause.predicates.push(parse_quote! {
91                    <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output: ::scale::Encode,
92                });
93                where_clause.predicates.push(parse_quote! {
94                    <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input: ::scale::Decode
95                });
96                where_clause
97            } else {
98                parse_quote! {
99                    where
100                        dyn #trait_name: ::obce::codegen::ExtensionDescription,
101                        <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output: ::scale::Encode,
102                        <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input: ::scale::Decode
103                }
104            };
105
106            let input_bindings = InputBindings::from_iter(&method.sig.inputs);
107            let lhs_pat = input_bindings.lhs_pat(Some(parse_quote! {
108                <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input
109            }));
110            let call_params = input_bindings.iter_call_params();
111
112            quote! {
113                struct #proxy_name #types (::std::rc::Rc<::std::cell::RefCell<#item>>);
114
115                impl #impls ::obce::ink_lang::env::test::ChainExtension for #proxy_name #types #proxy_where_clause {
116                    fn func_id(&self) -> u32 {
117                        let trait_id = <dyn #trait_name as ::obce::codegen::ExtensionDescription>::ID;
118                        let func_id = <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::ID;
119                        (trait_id as u32) << 16 | (func_id as u32)
120                    }
121
122                    fn call(&mut self, mut input: &[u8], output: &mut Vec<u8>) -> u32 {
123                        let context = &mut *self.0.borrow_mut();
124
125                        let bytes: Vec<u8> = ::scale::Decode::decode(&mut &input[..])
126                            .unwrap();
127
128                        let #lhs_pat = ::scale::Decode::decode(&mut &bytes[..])
129                            .unwrap();
130
131                        #[allow(clippy::unnecessary_mut_passed)]
132                        let call_output: <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output = <#item as MockTrait #types>::#method_name(
133                            context
134                            #(, #call_params)*
135                        );
136
137                        ::scale::Encode::encode_to(&call_output, output);
138
139                        0
140                    }
141                }
142
143                ::obce::ink_lang::env::test::register_chain_extension(#proxy_name(wrapped_context.clone()));
144            }
145        });
146
147    Ok(quote! {
148        pub fn register_chain_extensions #types (ctx: #item) {
149            #[allow(unused_variables)]
150            let wrapped_context = ::std::rc::Rc::new(::std::cell::RefCell::new(ctx));
151
152            #mock_trait
153
154            #mock_impl
155
156            #(#proxies)*
157        }
158    })
159}