1use proc_macro2::TokenStream;
2use quote::{
3 format_ident,
4 quote,
5};
6use syn::{
7 parse2,
8 parse_quote,
9 Error,
10 ImplItem,
11 ItemImpl,
12 ItemTrait,
13 TraitItem,
14 TraitItemMethod,
15};
16
17use crate::{
18 format_err_spanned,
19 utils::{
20 into_u32,
21 InputBindings,
22 },
23};
24
25pub fn generate(_: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
26 let mut impl_item: ItemImpl = parse2(input)?;
27
28 let Some((_, trait_name, _)) = impl_item.trait_ else {
29 return Err(format_err_spanned!(
30 impl_item,
31 "impl marked as mocked should have a trait present"
32 ));
33 };
34 let item = impl_item.self_ty;
35
36 let (impls, types, where_clause) = impl_item.generics.split_for_impl();
37
38 let methods = impl_item
40 .items
41 .iter_mut()
42 .filter_map(|item| {
43 if let ImplItem::Method(method_item) = item {
44 Some(method_item)
45 } else {
46 None
47 }
48 })
49 .collect::<Vec<_>>();
50
51 let mut mock_trait: ItemTrait = parse_quote! {
52 trait MockTrait {}
53 };
54
55 mock_trait.generics = impl_item.generics.clone();
56 mock_trait.items = methods
57 .iter()
58 .map(|method| (**method).clone())
59 .map(|val| {
60 TraitItem::Method(TraitItemMethod {
61 attrs: val.attrs,
62 sig: val.sig,
63 default: None,
64 semi_token: None,
65 })
66 })
67 .collect();
68
69 let mut mock_impl: ItemImpl = parse_quote! {
70 impl MockTrait #types for #item {}
71 };
72
73 mock_impl.generics = impl_item.generics.clone();
74 mock_impl.items = methods
75 .iter()
76 .map(|method| (**method).clone())
77 .map(ImplItem::Method)
78 .collect();
79
80 let proxies = methods.iter()
81 .map(|method| {
82 let hash = into_u32(&method.sig.ident);
83
84 let method_name = &method.sig.ident;
85 let proxy_name = format_ident!("ProxyFor{}", hash);
86 let proxy_where_clause = if let Some(mut where_clause) = where_clause.cloned() {
87 where_clause.predicates.push(parse_quote! {
88 dyn #trait_name: ::obce::codegen::ExtensionDescription,
89 });
90 where_clause.predicates.push(parse_quote! {
91 <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output: ::scale::Encode,
92 });
93 where_clause.predicates.push(parse_quote! {
94 <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input: ::scale::Decode
95 });
96 where_clause
97 } else {
98 parse_quote! {
99 where
100 dyn #trait_name: ::obce::codegen::ExtensionDescription,
101 <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output: ::scale::Encode,
102 <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input: ::scale::Decode
103 }
104 };
105
106 let input_bindings = InputBindings::from_iter(&method.sig.inputs);
107 let lhs_pat = input_bindings.lhs_pat(Some(parse_quote! {
108 <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Input
109 }));
110 let call_params = input_bindings.iter_call_params();
111
112 quote! {
113 struct #proxy_name #types (::std::rc::Rc<::std::cell::RefCell<#item>>);
114
115 impl #impls ::obce::ink_lang::env::test::ChainExtension for #proxy_name #types #proxy_where_clause {
116 fn func_id(&self) -> u32 {
117 let trait_id = <dyn #trait_name as ::obce::codegen::ExtensionDescription>::ID;
118 let func_id = <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::ID;
119 (trait_id as u32) << 16 | (func_id as u32)
120 }
121
122 fn call(&mut self, mut input: &[u8], output: &mut Vec<u8>) -> u32 {
123 let context = &mut *self.0.borrow_mut();
124
125 let bytes: Vec<u8> = ::scale::Decode::decode(&mut &input[..])
126 .unwrap();
127
128 let #lhs_pat = ::scale::Decode::decode(&mut &bytes[..])
129 .unwrap();
130
131 #[allow(clippy::unnecessary_mut_passed)]
132 let call_output: <dyn #trait_name as ::obce::codegen::MethodDescription<#hash>>::Output = <#item as MockTrait #types>::#method_name(
133 context
134 #(, #call_params)*
135 );
136
137 ::scale::Encode::encode_to(&call_output, output);
138
139 0
140 }
141 }
142
143 ::obce::ink_lang::env::test::register_chain_extension(#proxy_name(wrapped_context.clone()));
144 }
145 });
146
147 Ok(quote! {
148 pub fn register_chain_extensions #types (ctx: #item) {
149 #[allow(unused_variables)]
150 let wrapped_context = ::std::rc::Rc::new(::std::cell::RefCell::new(ctx));
151
152 #mock_trait
153
154 #mock_impl
155
156 #(#proxies)*
157 }
158 })
159}