pink_macro/
chain_extension.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::quote;
3use syn::{Result, Type};
4use unzip3::Unzip3 as _;
5
6use ink_ir::ChainExtension;
7
8pub(crate) fn patch(input: TokenStream2) -> TokenStream2 {
9    match patch_chain_extension_or_err(input) {
10        Ok(tokens) => tokens,
11        Err(err) => err.to_compile_error(),
12    }
13}
14
15fn patch_chain_extension_or_err(input: TokenStream2) -> Result<TokenStream2> {
16    use proc_macro2::{Ident, Literal, Span};
17
18    let backend_trait = {
19        let mut item_trait: syn::ItemTrait = syn::parse2(input.clone())?;
20
21        item_trait.ident = syn::Ident::new(
22            &format!("{}Backend", item_trait.ident),
23            item_trait.ident.span(),
24        );
25
26        item_trait
27            .items
28            .retain(|i| !matches!(i, &syn::TraitItem::Type(_)));
29
30        item_trait.items.push(syn::parse_quote! {
31            type Error;
32        });
33
34        for item in item_trait.items.iter_mut() {
35            if let syn::TraitItem::Fn(item_method) = item {
36                item_method
37                    .attrs
38                    .retain(|attr| !attr.path().is_ident("ink"));
39
40                // Turn &[u8] into Cow<[u8]>
41                for input in item_method.sig.inputs.iter_mut() {
42                    match input {
43                        syn::FnArg::Receiver(_) => (),
44                        syn::FnArg::Typed(arg) => {
45                            if let Type::Reference(tp) = *arg.ty.clone() {
46                                let inner_type = tp.elem.clone();
47                                arg.ty = syn::parse_quote! {
48                                    Cow<#inner_type>
49                                };
50                            }
51                        }
52                    }
53                }
54
55                item_method.sig.inputs.insert(
56                    0,
57                    syn::parse_quote! {
58                        &self
59                    },
60                );
61                item_method.sig.output = match item_method.sig.output.clone() {
62                    syn::ReturnType::Type(_, tp) => {
63                        syn::parse_quote! {
64                            -> Result<#tp, Self::Error>
65                        }
66                    }
67                    syn::ReturnType::Default => {
68                        syn::parse_quote! {
69                            -> Result<(), Self::Error>
70                        }
71                    }
72                };
73            }
74        }
75
76        item_trait
77    };
78
79    let extension = ChainExtension::new(Default::default(), input.clone())?;
80    let id_pairs: Vec<_> = {
81        extension
82            .iter_methods()
83            .map(|m| {
84                let name = m.ident().to_string();
85                let id = m.id().into_u32();
86                let args: Vec<_> = m
87                    .inputs()
88                    .enumerate()
89                    .map(|(i, _)| Ident::new(&format!("arg_{i}"), Span::call_site()))
90                    .collect();
91                (name, id, args)
92            })
93            .collect()
94    };
95
96    // Extract all function ids to a sub module
97    let func_ids = {
98        let mut mod_item: syn::ItemMod = syn::parse_quote! {
99            pub mod func_ids {}
100        };
101        for (name, id, _) in id_pairs.iter() {
102            let name = name.to_uppercase();
103            let name = Ident::new(&name, Span::call_site());
104            let id = Literal::u32_unsuffixed(*id);
105            mod_item
106                .content
107                .as_mut()
108                .unwrap()
109                .1
110                .push(syn::parse_quote! {
111                    pub const #name: u32 = #id;
112                });
113        }
114        mod_item
115    };
116
117    // Generate the dispatcher
118    let dispatcher: syn::ItemMacro = {
119        let (names, ids, args): (Vec<_>, Vec<_>, Vec<_>) = id_pairs
120            .into_iter()
121            .map(|(name, id, args)| {
122                let name = Ident::new(&name, Span::call_site());
123                let id = Literal::u32_unsuffixed(id);
124                (name, id, args)
125            })
126            .unzip3();
127        syn::parse_quote! {
128            #[macro_export]
129            macro_rules! dispatch_ext_call {
130                ($func_id: expr, $handler: expr, $env: expr) => {
131                    match $func_id {
132                        #(
133                            #ids => {
134                                use $crate::chain_extension::EncodeOutputFallback as _;
135                                let (#(#args),*) = $env.read_as_unbounded($env.in_len())?;
136                                let output = $handler.#names(#(#args),*)?;
137                                Some($crate::chain_extension::EncodeOutput(output).encode())
138                            }
139                        )*
140                        _ => None,
141                    }
142                };
143            }
144        }
145    };
146
147    // Mock helper functions
148    let mock_helpers = {
149        let mut mod_item: syn::ItemMod = syn::parse_quote! {
150            pub mod mock {
151                use super::*;
152                use super::test::MockExtensionFn;
153            }
154        };
155        let mut reg_expressions: Vec<TokenStream2> = Default::default();
156        for m in extension.iter_methods() {
157            let name = m.ident().to_string();
158            let fname = "mock_".to_owned() + &name;
159            let fname = Ident::new(&fname, Span::call_site());
160            let origin_fname = Ident::new(&name, Span::call_site());
161            let id = Literal::u32_unsuffixed(m.id().into_u32());
162            let input_types: Vec<Type> = m.inputs().map(|arg| (*arg.ty).clone()).collect();
163            let input_types_cow: Vec<Type> = input_types
164                .iter()
165                .map(|arg| match arg.clone() {
166                    Type::Reference(tp) => {
167                        let inner = &tp.elem;
168                        syn::parse_quote! { Cow<#inner> }
169                    }
170                    tp => tp,
171                })
172                .collect();
173            let input_args: Vec<_> = input_types
174                .iter()
175                .enumerate()
176                .map(|(i, _)| Ident::new(&format!("arg_{i}"), Span::call_site()))
177                .collect();
178            let input_args_asref: Vec<TokenStream2> = input_types
179                .iter()
180                .enumerate()
181                .map(|(i, tp)| {
182                    let name = Ident::new(&format!("arg_{i}"), Span::call_site());
183                    match tp {
184                        Type::Reference(_) => {
185                            syn::parse_quote! {
186                                #name.as_ref()
187                            }
188                        }
189                        _ => syn::parse_quote! {
190                            #name
191                        },
192                    }
193                })
194                .collect();
195            let output = m.sig().output.clone();
196            mod_item
197                .content
198                .as_mut()
199                .unwrap()
200                .1
201                .push(syn::parse_quote! {
202                    pub fn #fname(mut call: impl FnMut(#(#input_types),*) #output + 'static) {
203                        ink::env::test::register_chain_extension(
204                            MockExtensionFn::<_, _, #id>::new(
205                                move |(#(#input_args),*): (#(#input_types_cow),*)| {
206                                    use crate::chain_extension::EncodeOutputFallback as _;
207                                    let output = call(#(#input_args_asref),*);
208                                    crate::chain_extension::EncodeOutput(output).encode()
209                                }
210                            ),
211                        );
212                    }
213                });
214            reg_expressions.push(syn::parse_quote! {
215                ink::env::test::register_chain_extension(
216                    MockExtensionFn::<_, _, #id>::new(
217                        move |(#(#input_args),*): (#(#input_types_cow),*)| {
218                            use crate::chain_extension::EncodeOutputFallback as _;
219                            let output = ext_impl.#origin_fname(#(#input_args),*).unwrap();
220                            crate::chain_extension::EncodeOutput(output).encode()
221                        }
222                    ),
223                );
224            });
225        }
226
227        let backend_trait_ident = &backend_trait.ident;
228        mod_item
229        .content
230        .as_mut()
231        .unwrap()
232        .1
233        .push(syn::parse_quote! {
234            pub fn mock_all_with<E: core::fmt::Debug, I: #backend_trait_ident<Error=E>>(ext_impl: &'static I) {
235                #(#reg_expressions)*
236            }
237        });
238        mod_item
239    };
240
241    let crate_ink_lang = crate::find_crate_name("ink")?;
242    Ok(quote! {
243        #[#crate_ink_lang::chain_extension]
244        #input
245
246        #backend_trait
247
248        #func_ids
249
250        #dispatcher
251
252        #[cfg(feature = "std")]
253        #mock_helpers
254    })
255}