// fuels_code_gen/program_bindings/abigen/bindings/contract.rs

1use fuel_abi_types::abi::full_program::{FullABIFunction, FullProgramABI};
2use itertools::Itertools;
3use proc_macro2::{Ident, TokenStream};
4use quote::{TokenStreamExt, quote};
5
6use crate::{
7    error::Result,
8    program_bindings::{
9        abigen::{
10            bindings::function_generator::FunctionGenerator,
11            configurables::generate_code_for_configurable_constants,
12            logs::{generate_id_error_codes_pairs, log_formatters_instantiation_code},
13        },
14        generated_code::GeneratedCode,
15    },
16    utils::{TypePath, ident},
17};
18
19pub(crate) fn contract_bindings(
20    name: &Ident,
21    abi: FullProgramABI,
22    no_std: bool,
23) -> Result<GeneratedCode> {
24    if no_std {
25        return Ok(GeneratedCode::default());
26    }
27
28    let log_formatters =
29        log_formatters_instantiation_code(quote! {contract_id.clone().into()}, &abi.logged_types);
30
31    let error_codes = generate_id_error_codes_pairs(abi.error_codes);
32    let error_codes = quote! {::std::collections::HashMap::from([#(#error_codes),*])};
33
34    let methods_name = ident(&format!("{name}Methods"));
35    let contract_methods_name = ident(&format!("{name}MethodVariants"));
36
37    let contract_functions = expand_functions(&abi.functions)?;
38    let constant_methods_code =
39        generate_constant_methods_pattern(&abi.functions, &contract_methods_name)?;
40
41    let configuration_struct_name = ident(&format!("{name}Configurables"));
42    let constant_configuration_code =
43        generate_code_for_configurable_constants(&configuration_struct_name, &abi.configurables)?;
44
45    let code = quote! {
46        #[derive(Debug, Clone)]
47        pub struct #name<A = ()> {
48            contract_id: ::fuels::types::ContractId,
49            account: A,
50            log_decoder: ::fuels::core::codec::LogDecoder,
51            encoder_config: ::fuels::core::codec::EncoderConfig,
52        }
53
54        impl #name {
55            pub const METHODS: #contract_methods_name = #contract_methods_name;
56        }
57
58        impl<A> #name<A>
59        {
60            pub fn new(
61                contract_id: ::fuels::types::ContractId,
62                account: A,
63            ) -> Self {
64                let log_decoder = ::fuels::core::codec::LogDecoder::new(#log_formatters, #error_codes);
65                let encoder_config = ::fuels::core::codec::EncoderConfig::default();
66                Self { contract_id, account, log_decoder, encoder_config }
67            }
68
69            pub fn contract_id(&self) -> ::fuels::types::ContractId {
70                self.contract_id
71            }
72
73            pub fn account(&self) -> &A {
74                &self.account
75            }
76
77            pub fn with_account<U: ::fuels::accounts::Account>(self, account: U)
78            -> #name<U> {
79                #name {
80                        contract_id: self.contract_id,
81                        account,
82                        log_decoder: self.log_decoder,
83                        encoder_config: self.encoder_config
84                }
85            }
86
87            pub fn with_encoder_config(mut self, encoder_config: ::fuels::core::codec::EncoderConfig)
88            -> #name::<A> {
89                self.encoder_config = encoder_config;
90
91                self
92            }
93
94            pub async fn get_balances(&self) -> ::fuels::types::errors::Result<::std::collections::HashMap<::fuels::types::AssetId, u64>> where A: ::fuels::accounts::ViewOnlyAccount {
95                ::fuels::accounts::ViewOnlyAccount::try_provider(&self.account)?
96                                  .get_contract_balances(&self.contract_id)
97                                  .await
98                                  .map_err(::std::convert::Into::into)
99            }
100
101            pub fn methods(&self) -> #methods_name<A> where A: Clone {
102                #methods_name {
103                    contract_id: self.contract_id.clone(),
104                    account: self.account.clone(),
105                    log_decoder: self.log_decoder.clone(),
106                    encoder_config: self.encoder_config.clone(),
107                }
108            }
109        }
110
111        // Implement struct that holds the contract methods
112        pub struct #methods_name<A> {
113            contract_id: ::fuels::types::ContractId,
114            account: A,
115            log_decoder: ::fuels::core::codec::LogDecoder,
116            encoder_config: ::fuels::core::codec::EncoderConfig,
117        }
118
119        impl<A: ::fuels::accounts::Account + Clone> #methods_name<A> {
120            #contract_functions
121        }
122
123        impl<A>
124            ::fuels::programs::calls::ContractDependency for #name<A>
125        {
126            fn id(&self) -> ::fuels::types::ContractId {
127                self.contract_id
128            }
129
130            fn log_decoder(&self) -> ::fuels::core::codec::LogDecoder {
131                self.log_decoder.clone()
132            }
133        }
134
135        #constant_configuration_code
136
137        #constant_methods_code
138    };
139
140    // All publicly available types generated above should be listed here.
141    let type_paths = [
142        name,
143        &methods_name,
144        &configuration_struct_name,
145        &contract_methods_name,
146    ]
147    .map(|type_name| TypePath::new(type_name).expect("We know the given types are not empty"))
148    .into_iter()
149    .collect();
150
151    Ok(GeneratedCode::new(code, type_paths, no_std))
152}
153
154fn expand_functions(functions: &[FullABIFunction]) -> Result<TokenStream> {
155    functions
156        .iter()
157        .map(expand_fn)
158        .fold_ok(TokenStream::default(), |mut all_code, code| {
159            all_code.append_all(code);
160            all_code
161        })
162}
163
164/// Transforms a function defined in [`FullABIFunction`] into a [`TokenStream`]
165/// that represents that same function signature as a Rust-native function
166/// declaration.
167pub(crate) fn expand_fn(abi_fun: &FullABIFunction) -> Result<TokenStream> {
168    let mut generator = FunctionGenerator::new(abi_fun)?;
169
170    generator.set_docs(abi_fun.doc_strings()?);
171
172    let original_output = generator.output_type();
173    generator.set_output_type(
174        quote! {::fuels::programs::calls::CallHandler<A, ::fuels::programs::calls::ContractCall, #original_output> },
175    );
176
177    let fn_selector = generator.fn_selector();
178    let arg_tokens = generator.tokenized_args();
179    let is_payable = abi_fun.is_payable();
180    let body = quote! {
181            ::fuels::programs::calls::CallHandler::new_contract_call(
182                self.contract_id.clone(),
183                self.account.clone(),
184                #fn_selector,
185                &#arg_tokens,
186                self.log_decoder.clone(),
187                #is_payable,
188                self.encoder_config.clone(),
189            )
190    };
191    generator.set_body(body);
192
193    Ok(generator.generate())
194}
195
196fn generate_constant_methods_pattern(
197    functions: &[FullABIFunction],
198    contract_methods_name: &Ident,
199) -> Result<TokenStream> {
200    let method_descriptors = functions.iter().map(|func| {
201        let method_name = ident(func.name());
202        let fn_name = func.name();
203        let fn_selector =
204            proc_macro2::Literal::byte_string(&crate::utils::encode_fn_selector(fn_name));
205
206        quote! {
207            pub const fn #method_name(&self) -> ::fuels::types::MethodDescriptor {
208                ::fuels::types::MethodDescriptor {
209                    name: #fn_name,
210                    fn_selector: #fn_selector,
211                }
212            }
213        }
214    });
215
216    let all_methods = functions.iter().map(|func| {
217        let method_name = ident(func.name());
218        quote! { Self.#method_name() }
219    });
220
221    let method_count = functions.len();
222
223    let code = quote! {
224        #[derive(Debug, Clone, Copy)]
225        pub struct #contract_methods_name;
226
227        impl #contract_methods_name {
228            #(#method_descriptors)*
229
230            pub const fn iter(&self) -> [::fuels::types::MethodDescriptor; #method_count] {
231                [#(#all_methods),*]
232            }
233        }
234    };
235
236    Ok(code)
237}
238
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use fuel_abi_types::abi::{
        full_program::FullABIFunction,
        program::Attribute,
        unified_program::{UnifiedABIFunction, UnifiedTypeApplication, UnifiedTypeDeclaration},
    };
    use pretty_assertions::assert_eq;
    use quote::quote;

    use crate::{
        error::Result,
        program_bindings::abigen::bindings::contract::{
            expand_fn, generate_constant_methods_pattern,
        },
        utils::ident,
    };

    #[test]
    fn expand_contract_method_simple() -> Result<()> {
        // given
        let the_function = UnifiedABIFunction {
            inputs: vec![UnifiedTypeApplication {
                name: String::from("bimbam"),
                type_id: 1,
                ..Default::default()
            }],
            name: "hello_world".to_string(),
            attributes: Some(vec![Attribute {
                name: "doc-comment".to_string(),
                arguments: vec!["This is a doc string".to_string()],
            }]),
            ..Default::default()
        };
        let types = [
            (
                0,
                UnifiedTypeDeclaration {
                    type_id: 0,
                    type_field: String::from("()"),
                    ..Default::default()
                },
            ),
            (
                1,
                UnifiedTypeDeclaration {
                    type_id: 1,
                    type_field: String::from("bool"),
                    ..Default::default()
                },
            ),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        // when
        let result = expand_fn(&FullABIFunction::from_counterpart(&the_function, &types)?);

        // then
        let expected = quote! {
            #[doc = "This is a doc string"]
            pub fn hello_world(&self, bimbam: ::core::primitive::bool) -> ::fuels::programs::calls::CallHandler<A, ::fuels::programs::calls::ContractCall, ()> {
                ::fuels::programs::calls::CallHandler::new_contract_call(
                    self.contract_id.clone(),
                    self.account.clone(),
                    ::fuels::core::codec::encode_fn_selector("hello_world"),
                    &[::fuels::core::traits::Tokenizable::into_token(bimbam)],
                    self.log_decoder.clone(),
                    false,
                    self.encoder_config.clone(),
                )
            }
        };

        assert_eq!(result?.to_string(), expected.to_string());

        Ok(())
    }

    #[test]
    fn expand_contract_method_complex() -> Result<()> {
        // given
        let the_function = UnifiedABIFunction {
            inputs: vec![UnifiedTypeApplication {
                name: String::from("the_only_allowed_input"),
                type_id: 4,
                ..Default::default()
            }],
            name: "hello_world".to_string(),
            output: UnifiedTypeApplication {
                name: String::from("stillnotused"),
                type_id: 1,
                ..Default::default()
            },
            attributes: Some(vec![
                Attribute {
                    name: "doc-comment".to_string(),
                    arguments: vec!["This is a doc string".to_string()],
                },
                Attribute {
                    name: "doc-comment".to_string(),
                    arguments: vec!["This is another doc string".to_string()],
                },
            ]),
        };
        let types = [
            (
                1,
                UnifiedTypeDeclaration {
                    type_id: 1,
                    type_field: String::from("enum EntropyCirclesEnum"),
                    components: Some(vec![
                        UnifiedTypeApplication {
                            name: String::from("Postcard"),
                            type_id: 2,
                            ..Default::default()
                        },
                        UnifiedTypeApplication {
                            name: String::from("Teacup"),
                            type_id: 3,
                            ..Default::default()
                        },
                    ]),
                    ..Default::default()
                },
            ),
            (
                2,
                UnifiedTypeDeclaration {
                    type_id: 2,
                    type_field: String::from("bool"),
                    ..Default::default()
                },
            ),
            (
                3,
                UnifiedTypeDeclaration {
                    type_id: 3,
                    type_field: String::from("u64"),
                    ..Default::default()
                },
            ),
            (
                4,
                UnifiedTypeDeclaration {
                    type_id: 4,
                    type_field: String::from("struct SomeWeirdFrenchCuisine"),
                    components: Some(vec![
                        UnifiedTypeApplication {
                            name: String::from("Beef"),
                            type_id: 2,
                            ..Default::default()
                        },
                        UnifiedTypeApplication {
                            name: String::from("BurgundyWine"),
                            type_id: 3,
                            ..Default::default()
                        },
                    ]),
                    ..Default::default()
                },
            ),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        // when
        let result = expand_fn(&FullABIFunction::from_counterpart(&the_function, &types)?);

        // then

        // Some more editing was required because it is not rustfmt-compatible (adding/removing parentheses or commas)
        let expected = quote! {
            #[doc = "This is a doc string"]
            #[doc = "This is another doc string"]
            pub fn hello_world(
                &self,
                the_only_allowed_input: self::SomeWeirdFrenchCuisine
            ) -> ::fuels::programs::calls::CallHandler<A, ::fuels::programs::calls::ContractCall, self::EntropyCirclesEnum> {
                ::fuels::programs::calls::CallHandler::new_contract_call(
                    self.contract_id.clone(),
                    self.account.clone(),
                    ::fuels::core::codec::encode_fn_selector( "hello_world"),
                    &[::fuels::core::traits::Tokenizable::into_token(
                        the_only_allowed_input
                    )],
                    self.log_decoder.clone(),
                    false,
                    self.encoder_config.clone(),
                )
            }
        };

        assert_eq!(result?.to_string(), expected.to_string());

        Ok(())
    }

    /// Checks the `<Name>MethodVariants` struct: one `const fn` descriptor per
    /// ABI function plus an `iter` method returning all descriptors.
    #[test]
    fn generates_constant_method_descriptors() -> Result<()> {
        // given
        let the_function = UnifiedABIFunction {
            name: "hello_world".to_string(),
            ..Default::default()
        };
        let types = [(
            0,
            UnifiedTypeDeclaration {
                type_id: 0,
                type_field: String::from("()"),
                ..Default::default()
            },
        )]
        .into_iter()
        .collect::<HashMap<_, _>>();
        let functions = [FullABIFunction::from_counterpart(&the_function, &types)?];
        let methods_name = ident("MyContractMethodVariants");

        // when
        let result = generate_constant_methods_pattern(&functions, &methods_name);

        // then
        // The selector literal is derived the same way as in the generator; this
        // test pins the surrounding structure, not the selector encoding itself.
        let fn_selector = proc_macro2::Literal::byte_string(&crate::utils::encode_fn_selector(
            "hello_world",
        ));
        let method_count: usize = 1;
        let expected = quote! {
            #[derive(Debug, Clone, Copy)]
            pub struct MyContractMethodVariants;

            impl MyContractMethodVariants {
                pub const fn hello_world(&self) -> ::fuels::types::MethodDescriptor {
                    ::fuels::types::MethodDescriptor {
                        name: "hello_world",
                        fn_selector: #fn_selector,
                    }
                }

                pub const fn iter(&self) -> [::fuels::types::MethodDescriptor; #method_count] {
                    [Self.hello_world()]
                }
            }
        };

        assert_eq!(result?.to_string(), expected.to_string());

        Ok(())
    }
}