cainome_rs/
lib.rs

1use anyhow::Result;
2use cainome_parser::tokens::StateMutability;
3use cainome_parser::{AbiParser, TokenizedAbi};
4use camino::Utf8PathBuf;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use std::collections::HashMap;
8use std::fmt;
9use std::fs;
10use std::io;
11
12mod execution_version;
13mod expand;
14pub use execution_version::{ExecutionVersion, ParseExecutionVersionError};
15
16use crate::expand::utils;
17use crate::expand::{CairoContract, CairoEnum, CairoEnumEvent, CairoFunction, CairoStruct};
18
19///Type-safe contract bindings generated by Abigen.
20#[derive(Clone)]
21pub struct ContractBindings {
22    /// Name of the contract.
23    pub name: String,
24    /// Tokenized ABI written to a `[TokenStream2]`.
25    pub tokens: TokenStream2,
26}
27
28impl ContractBindings {
29    /// Writes the bindings to the specified file.
30    ///
31    /// # Arguments
32    ///
33    /// * `file` - The path to the file to write the bindings to.
34    pub fn write_to_file(&self, file: &str) -> io::Result<()> {
35        let content = format!(
36            "// ****\n// Auto-generated by cainome do not edit.\n// ****\n\n#![allow(clippy::all)]\n#![allow(warnings)]\n\n{}",
37            self
38        );
39        fs::write(file, content)
40    }
41}
42
43impl fmt::Display for ContractBindings {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        let syntax_tree = syn::parse2::<syn::File>(self.tokens.clone()).unwrap();
46        let s = prettyplease::unparse(&syntax_tree);
47        f.write_str(&s)
48    }
49}
50
51impl fmt::Debug for ContractBindings {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("ContractBindings")
54            .field("name", &self.name)
55            .finish()
56    }
57}
58
59/// Programmatically generate type-safe Rust bindings for an Starknet smart contract from its ABI.
60///
61/// Currently only one contract at a time is supported.
62#[derive(Debug, Clone)]
63pub struct Abigen {
64    /// Name of the contract, used as the variable name in the generated code
65    /// to identify the contract.
66    pub contract_name: String,
67    /// The path to a sierra artifact or a JSON with ABI entries only.
68    pub abi_source: Utf8PathBuf,
69    /// Types aliases to avoid name conflicts, as for now the types are limited to the
70    /// latest segment of the fully qualified path.
71    pub types_aliases: HashMap<String, String>,
72    /// The version of transaction to be executed.
73    pub execution_version: ExecutionVersion,
74    /// Derives to be added to the generated types.
75    pub derives: Vec<String>,
76    /// Derives to be added to the generated contract.
77    pub contract_derives: Vec<String>,
78    /// Types to be skipped from the generated types, usually combined with the `types_aliases` to
79    /// let the user specify the implementation of the types. If a type is generic, the generic arguments
80    /// are not part of the compared name.
81    pub type_skips: Vec<String>,
82}
83
84impl Abigen {
85    /// Creates a new instance of `Abigen`.
86    ///
87    /// # Arguments
88    ///
89    /// * `contract_name` - Name of the contract, used as the variable name in the generated code
90    ///   to identify the contract.
91    /// * `abi_source` - The path to a sierra artifact or a JSON with ABI entries only.
92    pub fn new(contract_name: &str, abi_source: &str) -> Self {
93        Self {
94            contract_name: contract_name.to_string(),
95            abi_source: Utf8PathBuf::from(abi_source),
96            types_aliases: HashMap::new(),
97            execution_version: ExecutionVersion::V1,
98            derives: vec![],
99            contract_derives: vec![],
100            type_skips: vec![],
101        }
102    }
103
104    /// Sets the types aliases to avoid name conflicts.
105    ///
106    /// # Arguments
107    ///
108    /// * `types_aliases` - Types aliases to avoid name conflicts.
109    pub fn with_types_aliases(mut self, types_aliases: HashMap<String, String>) -> Self {
110        self.types_aliases = types_aliases;
111        self
112    }
113
114    /// Sets the execution version to be used.
115    ///
116    /// # Arguments
117    ///
118    /// * `execution_version` - The version of transaction to be executed.
119    pub fn with_execution_version(mut self, execution_version: ExecutionVersion) -> Self {
120        self.execution_version = execution_version;
121        self
122    }
123
124    /// Sets the derives to be added to the generated types.
125    ///
126    /// # Arguments
127    ///
128    /// * `derives` - Derives to be added to the generated types.
129    pub fn with_derives(mut self, derives: Vec<String>) -> Self {
130        self.derives = derives;
131        self
132    }
133
134    /// Sets the derives to be added to the generated contract.
135    ///
136    /// # Arguments
137    ///
138    /// * `derives` - Derives to be added to the generated contract.
139    pub fn with_contract_derives(mut self, derives: Vec<String>) -> Self {
140        self.contract_derives = derives;
141        self
142    }
143
144    /// Sets the types to be skipped from the generated types.
145    ///
146    /// # Arguments
147    ///
148    /// * `type_skips` - Types to be skipped from the generated types.
149    pub fn with_type_skips(mut self, type_skips: Vec<String>) -> Self {
150        self.type_skips = type_skips;
151        self
152    }
153    /// Generates the contract bindings.
154    pub fn generate(&self) -> Result<ContractBindings> {
155        let file_content = std::fs::read_to_string(&self.abi_source)?;
156
157        match AbiParser::tokens_from_abi_string(&file_content, &self.types_aliases) {
158            Ok(tokens) => {
159                let expanded = abi_to_tokenstream(
160                    &self.contract_name,
161                    &tokens,
162                    self.execution_version,
163                    &self.derives,
164                    &self.contract_derives,
165                    &self.type_skips,
166                );
167
168                Ok(ContractBindings {
169                    name: self.contract_name.clone(),
170                    tokens: expanded,
171                })
172            }
173            Err(e) => {
174                anyhow::bail!(
175                    "Abi source {} could not be parsed {:?}. ABI file should be a JSON with an array of abi entries or a Sierra artifact.",
176                    self.abi_source, e
177                )
178            }
179        }
180    }
181}
182
183/// Converts the given ABI (in it's tokenize form) into rust bindings.
184///
185/// # Arguments
186///
187/// * `contract_name` - Name of the contract.
188/// * `abi_tokens` - Tokenized ABI.
189/// * `execution_version` - The version of transaction to be executed.
190/// * `derives` - Derives to be added to the generated types.
191/// * `contract_derives` - Derives to be added to the generated contract.
192/// * `type_skips` - Types to be skipped from the generated types.
193pub fn abi_to_tokenstream(
194    contract_name: &str,
195    abi_tokens: &TokenizedAbi,
196    execution_version: ExecutionVersion,
197    derives: &[String],
198    contract_derives: &[String],
199    type_skips: &[String],
200) -> TokenStream2 {
201    let type_skips = type_skips
202        .iter()
203        .map(|s| s.replace(" ", ""))
204        .collect::<Vec<String>>();
205    let contract_name = utils::str_to_ident(contract_name);
206
207    let mut tokens: Vec<TokenStream2> = vec![];
208
209    tokens.push(CairoContract::expand(
210        contract_name.clone(),
211        contract_derives,
212    ));
213
214    let mut sorted_structs = abi_tokens.structs.clone();
215    sorted_structs.sort_by(|a, b| {
216        let a_name = a
217            .to_composite()
218            .expect("composite expected")
219            .type_name_or_alias();
220        let b_name = b
221            .to_composite()
222            .expect("composite expected")
223            .type_name_or_alias();
224        a_name.cmp(&b_name)
225    });
226
227    let mut sorted_enums = abi_tokens.enums.clone();
228    sorted_enums.sort_by(|a, b| {
229        let a_name = a
230            .to_composite()
231            .expect("composite expected")
232            .type_name_or_alias();
233        let b_name = b
234            .to_composite()
235            .expect("composite expected")
236            .type_name_or_alias();
237        a_name.cmp(&b_name)
238    });
239
240    for s in &sorted_structs {
241        let s_composite = s.to_composite().expect("composite expected");
242
243        if type_skips.contains(&s_composite.type_path_no_generic()) {
244            continue;
245        }
246
247        tokens.push(CairoStruct::expand_decl(s_composite, derives));
248        tokens.push(CairoStruct::expand_impl(s_composite));
249    }
250
251    for e in &sorted_enums {
252        let e_composite = e.to_composite().expect("composite expected");
253        tokens.push(CairoEnum::expand_decl(e_composite, derives));
254        tokens.push(CairoEnum::expand_impl(e_composite));
255
256        if type_skips.contains(&e_composite.type_path_no_generic()) {
257            continue;
258        }
259
260        tokens.push(CairoEnumEvent::expand(
261            e.to_composite().expect("composite expected"),
262            &abi_tokens.enums,
263            &abi_tokens.structs,
264        ));
265    }
266
267    let mut reader_views = vec![];
268    let mut views = vec![];
269    let mut externals = vec![];
270
271    // Interfaces are not yet reflected in the generated contract.
272    // Then, the standalone functions and functions from interfaces are put together.
273    let mut functions = abi_tokens.functions.clone();
274    for funcs in abi_tokens.interfaces.values() {
275        functions.extend(funcs.clone());
276    }
277
278    functions.sort_by(|a, b| {
279        let a_name = a.to_function().expect("function expected").name.to_string();
280        let b_name = b.to_function().expect("function expected").name.to_string();
281        a_name.cmp(&b_name)
282    });
283
284    for f in functions {
285        let f = f.to_function().expect("function expected");
286        match f.state_mutability {
287            StateMutability::View => {
288                reader_views.push(CairoFunction::expand(f, true, execution_version));
289                views.push(CairoFunction::expand(f, false, execution_version));
290            }
291            StateMutability::External => {
292                externals.push(CairoFunction::expand(f, false, execution_version))
293            }
294        }
295    }
296
297    let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str());
298
299    tokens.push(quote! {
300        impl<A: starknet::accounts::ConnectedAccount + Sync> #contract_name<A> {
301            #(#views)*
302            #(#externals)*
303        }
304
305        impl<P: starknet::providers::Provider + Sync> #reader<P> {
306            #(#reader_views)*
307        }
308    });
309
310    let expanded = quote! {
311        #(#tokens)*
312    };
313
314    expanded
315}