bonfida_autobindings/
lib.rs

1use anchor_syn::idl::types::Idl;
2use cargo_toml::Manifest;
3use clap::{crate_name, crate_version, Arg, ArgMatches, Command};
4use convert_case::{Boundary, Case, Casing};
5use idl_generate::{idl_process_file, idl_process_state_file};
6use proc_macro2::TokenTree;
7use std::{
8    collections::HashMap,
9    fs::{self, File},
10    io::{Read, Write},
11    str::FromStr,
12    time::Instant,
13};
14
15use syn::{
16    punctuated::Punctuated, token::Comma, Attribute, Expr, ExprLit, Field, Fields, FieldsNamed,
17    Item, ItemEnum, ItemStruct, Lit, Path, Type, TypeArray, TypePath, TypeReference, Variant,
18};
19
20use crate::js_generate::js_process_file;
21use crate::py_generate::py_process_file;
22
23pub mod idl_generate;
24pub mod js_generate;
25pub mod py_generate;
26pub mod test;
27
/// Output format produced by [`generate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetLang {
    /// TypeScript instruction bindings (the CLI writes them to `../js/src/raw_instructions.ts`).
    Javascript,
    /// Python instruction bindings (the CLI writes them to `../python/src/raw_instructions.py`).
    Python,
    /// An Anchor IDL serialized as pretty-printed JSON (the CLI writes it to `idl.json`).
    AnchorIdl,
}
34
35pub fn command() -> Command<'static> {
36    Command::new(crate_name!())
37        .version(crate_version!())
38        .about("Autogenerate Rust and JS instruction bindings")
39        .author("Bonfida")
40        .arg(
41            Arg::with_name("instr-path")
42                .long("instructions-path")
43                .takes_value(true)
44                .default_value("src/processor"),
45        )
46        .arg(
47            Arg::with_name("toml-path")
48                .long("cargo-toml-path")
49                .takes_value(true)
50                .default_value("Cargo.toml"),
51        )
52        .arg(
53            Arg::with_name("instr-enum-path")
54                .long("instructions-enum-path")
55                .takes_value(true)
56                .default_value("src/instruction_auto.rs"),
57        )
58        .arg(
59            Arg::with_name("account-tag-enum-path")
60                .long("account-tag-enum-path")
61                .takes_value(true)
62                .default_value("src/state.rs"),
63        )
64        .arg(
65            Arg::with_name("state-folder")
66                .long("state-folder")
67                .takes_value(true)
68                .default_value("src/state"),
69        )
70        .arg(
71            Arg::with_name("target-lang")
72                .long("target-language")
73                .takes_value(true)
74                .default_value("js")
75                .help("Enter \"py\", \"js\" or \"idl\""),
76        )
77        .arg(
78            Arg::with_name("test")
79                .long("test")
80                .takes_value(true)
81                .default_value("false")
82                .help("Enter true or false"),
83        )
84        .arg(
85            Arg::with_name("skip-account-tag")
86                .long("skip-account-tag")
87                .takes_value(false),
88        )
89        .arg(
90            Arg::with_name("no-state")
91                .long("no-state")
92                .action(clap::ArgAction::SetTrue),
93        )
94}
95
96pub fn process(matches: &ArgMatches) {
97    let instructions_path = matches.value_of("instr-path").unwrap();
98    let instructions_enum_path = matches.value_of("instr-enum-path").unwrap();
99    let cargo_toml_path = matches.value_of("toml-path").unwrap();
100    let target_lang_str = matches.value_of("target-lang").unwrap();
101    let state_folder = matches.value_of("state-folder").unwrap();
102    let skip_account_tag = matches.contains_id("skip-account-tag");
103    let target_lang = match target_lang_str {
104        "js" | "javascript" => TargetLang::Javascript,
105        "py" | "python" => TargetLang::Python,
106        "idl" | "anchor-idl" => TargetLang::AnchorIdl,
107        _ => {
108            println!("Target language must be javascript or python");
109            panic!()
110        }
111    };
112    let test_mode = bool::from_str(matches.value_of("test").unwrap()).unwrap();
113    let no_state = matches.get_flag("no-state");
114    fs::create_dir_all("../js/src/").unwrap();
115    fs::create_dir_all("../python/src/").unwrap();
116
117    let now = Instant::now();
118
119    match test_mode {
120        true => {
121            test::test(instructions_path);
122        }
123        false => {
124            generate(
125                cargo_toml_path,
126                instructions_path,
127                instructions_enum_path,
128                state_folder,
129                target_lang,
130                match target_lang {
131                    TargetLang::Javascript => "../js/src/raw_instructions.ts",
132                    TargetLang::Python => "../python/src/raw_instructions.py",
133                    TargetLang::AnchorIdl => "idl.json",
134                },
135                skip_account_tag,
136                no_state,
137            );
138        }
139    }
140
141    let elapsed = now.elapsed();
142    println!("✨  Done in {:.2?}", elapsed);
143}
144
145#[allow(clippy::too_many_arguments)]
146pub fn generate(
147    cargo_toml_path: &str,
148    instructions_path: &str,
149    instructions_enum_path: &str,
150    state_folder_path: &str,
151    target_lang: TargetLang,
152    output_path: &str,
153    skip_account_tag: bool,
154    no_state: bool,
155) {
156    let path = std::path::Path::new(instructions_path);
157    let (instruction_tags, use_casting) = parse_instructions_enum(instructions_enum_path);
158    let directory = std::fs::read_dir(path).unwrap();
159    let cargo_toml_path = std::path::Path::new(&cargo_toml_path)
160        .canonicalize()
161        .unwrap();
162    let manifest = Manifest::from_path(cargo_toml_path).unwrap();
163    let mut output = get_header(target_lang);
164    let mut idl = Idl {
165        version: manifest.package.as_ref().unwrap().version.clone().unwrap(),
166        name: manifest.package.as_ref().unwrap().name.clone(),
167        constants: vec![],
168        instructions: vec![],
169        accounts: vec![],
170        types: vec![],
171        events: None,
172        errors: None,
173        metadata: None,
174        docs: None,
175    };
176    for d in directory {
177        let file = d.unwrap();
178        let module_name = std::path::Path::new(&file.file_name())
179            .file_stem()
180            .unwrap()
181            .to_str()
182            .unwrap()
183            .to_owned();
184        let instruction_tag = instruction_tags.get(&module_name).unwrap_or_else(|| {
185            panic!(
186                "Instruction not found for {} in {:#?}",
187                module_name, instruction_tags
188            )
189        });
190        match target_lang {
191            TargetLang::Javascript => {
192                let s = js_process_file(
193                    &module_name,
194                    *instruction_tag,
195                    file.path().to_str().unwrap(),
196                    use_casting,
197                );
198                output.push_str(&s)
199            }
200            TargetLang::Python => {
201                let s = py_process_file(
202                    &module_name,
203                    *instruction_tag,
204                    file.path().to_str().unwrap(),
205                    use_casting,
206                );
207                output.push_str(&s)
208            }
209            TargetLang::AnchorIdl => {
210                let i = idl_process_file(&module_name, file.path().to_str().unwrap());
211                if let Some(i) = i {
212                    idl.instructions.push(i)
213                }
214            }
215        };
216    }
217
218    if matches!(target_lang, TargetLang::AnchorIdl) {
219        if !no_state {
220            let state_directory =
221                std::fs::read_dir(std::path::Path::new(state_folder_path)).unwrap();
222            for d in state_directory {
223                let file = d.unwrap();
224                let account = idl_process_state_file(&file.path(), skip_account_tag);
225                idl.accounts.push(account);
226            }
227        }
228        output.push_str(&serde_json::to_string_pretty(&idl).unwrap())
229    }
230
231    let mut out_file = File::create(output_path).unwrap();
232    out_file.write_all(output.as_bytes()).unwrap();
233}
234
235pub fn parse_instructions_enum(instructions_enum_path: &str) -> (HashMap<String, usize>, bool) {
236    let mut f = File::open(instructions_enum_path)
237        .unwrap_or_else(|e| panic!("{e} {}", instructions_enum_path));
238    let mut result_map = HashMap::new();
239    let mut raw_string = String::new();
240    f.read_to_string(&mut raw_string).unwrap();
241    let use_casting = raw_string.contains("get_instruction_cast");
242    let ast: syn::File = syn::parse_str(&raw_string).unwrap();
243    let instructions_enum = find_enum(&ast, None);
244    let enum_variants = get_enum_variants(instructions_enum);
245    let mut instruction_tag = 0;
246    for Variant {
247        ident,
248        discriminant,
249        ..
250    } in enum_variants.into_iter()
251    {
252        let module_name = pascal_to_snake(&ident.to_string());
253        if let Some((_, discriminant)) = discriminant {
254            if let Expr::Lit(ExprLit {
255                lit: Lit::Int(i), ..
256            }) = discriminant
257            {
258                let parsed = i.base10_parse().unwrap();
259                instruction_tag = parsed;
260            } else {
261                panic!("Unsupported enum discriminant type!");
262            }
263        }
264        result_map.insert(module_name, instruction_tag);
265        instruction_tag += 1;
266    }
267    (result_map, use_casting)
268}
269
270pub fn parse_account_tag_enum(account_tag_enum_path: &str) -> HashMap<String, usize> {
271    let mut f = File::open(account_tag_enum_path).unwrap();
272    let mut result_map = HashMap::new();
273    let mut raw_string = String::new();
274    f.read_to_string(&mut raw_string).unwrap();
275    let ast: syn::File = syn::parse_str(&raw_string).unwrap();
276    let account_tag_enum = find_enum(&ast, Some("AccountTag"));
277    let enum_variants = get_enum_variants(account_tag_enum);
278    for (i, Variant { ident, .. }) in enum_variants.into_iter().enumerate() {
279        let module_name = pascal_to_snake(&ident.to_string());
280        result_map.insert(module_name, i);
281    }
282    result_map
283}
284
285pub fn get_header(target_lang: TargetLang) -> String {
286    match target_lang {
287        TargetLang::Javascript => include_str!("templates/template.ts").to_string(),
288        TargetLang::Python => include_str!("templates/template.py").to_string(),
289        TargetLang::AnchorIdl => String::new(),
290    }
291}
292
293#[allow(dead_code)]
294fn get_simple_type(ty: &Type) -> String {
295    match ty {
296        Type::Path(TypePath {
297            qself: _,
298            path: Path {
299                leading_colon: _,
300                segments,
301            },
302        }) => segments.iter().next().unwrap().ident.to_string(),
303        _ => unimplemented!(),
304    }
305}
306
307fn padding_len(ty: &Type) -> u8 {
308    match ty {
309        Type::Path(TypePath {
310            path: Path { segments, .. },
311            ..
312        }) => {
313            let simple_type = segments.iter().next().unwrap().ident.to_string();
314            match simple_type.as_ref() {
315                "u8" => 1,
316                "u16" => 2,
317                "u32" => 4,
318                "u64" => 8,
319                "u128" => 16,
320                _ => unimplemented!(), // padding should be of types given above
321            }
322        }
323        Type::Array(TypeArray {
324            elem,
325            len: Expr::Lit(ExprLit {
326                lit: Lit::Int(l), ..
327            }),
328            ..
329        }) => padding_len(elem) * l.base10_parse::<u8>().unwrap(),
330        _ => unimplemented!(),
331    }
332}
333
334fn snake_to_camel(s: &str) -> String {
335    s.from_case(Case::Snake).to_case(Case::Camel)
336}
337fn snake_to_pascal(s: &str) -> String {
338    s.from_case(Case::Snake).to_case(Case::Pascal)
339}
340fn pascal_to_snake(s: &str) -> String {
341    s.from_case(Case::Pascal)
342        .without_boundaries(&[Boundary::UpperDigit, Boundary::DigitLower])
343        .to_case(Case::Snake)
344}
345fn lower_to_upper(s: &str) -> String {
346    s.from_case(Case::Lower).to_case(Case::Upper)
347}
348
349fn find_struct(file_ast: &syn::File, ident_str: Option<&str>) -> Item {
350    file_ast
351        .items
352        .iter()
353        .find(|a| {
354            if let Item::Struct(ItemStruct { ident, .. }) = a {
355                ident_str.map(|s| *ident == s).unwrap_or(true)
356            } else {
357                false
358            }
359        })
360        .unwrap()
361        .clone()
362}
363
364fn find_enum(file_ast: &syn::File, ident_name: Option<&str>) -> Item {
365    file_ast
366        .items
367        .iter()
368        .find(|a| {
369            if let Item::Enum(i) = a {
370                ident_name.map(|s| i.ident == s).unwrap_or(true)
371            } else {
372                false
373            }
374        })
375        .unwrap()
376        .clone()
377}
378
379fn get_enum_variants(s: Item) -> Punctuated<Variant, Comma> {
380    if let Item::Enum(ItemEnum { variants, .. }) = s {
381        variants
382    } else {
383        unreachable!()
384    }
385}
386
387fn get_struct_fields(s: Item) -> Punctuated<Field, Comma> {
388    if let Item::Struct(ItemStruct {
389        fields: Fields::Named(FieldsNamed { named, .. }),
390        ..
391    }) = s
392    {
393        named
394    } else {
395        unreachable!()
396    }
397}
398
399fn get_constraints(attrs: &[Attribute]) -> (bool, bool) {
400    let mut writable = false;
401    let mut signer = false;
402    for a in attrs {
403        if a.path.is_ident("cons") {
404            let t = if let TokenTree::Group(g) = a.tokens.clone().into_iter().next().unwrap() {
405                g.stream()
406            } else {
407                panic!()
408            };
409
410            for constraint in t.into_iter() {
411                match constraint {
412                    TokenTree::Ident(i) => {
413                        if &i.to_string() == "writable" {
414                            writable = true;
415                        }
416                        if &i.to_string() == "signer" {
417                            signer = true;
418                        }
419                    }
420                    TokenTree::Punct(p) if p.as_char() == ',' => {}
421                    _ => {}
422                }
423            }
424            break;
425        }
426    }
427    (writable, signer)
428}
429
430fn is_slice(ty: &Type) -> bool {
431    if let Type::Reference(TypeReference { elem, .. }) = ty {
432        let ty = *elem.clone();
433        if let Type::Slice(_) = ty {
434            return true;
435        }
436    }
437    false
438}
439
440// fn is_vec(ty: &Type) -> bool {
441//     if let Type::Path(TypePath { path, .. }) = ty {
442//         let seg = path.segments.iter().next().unwrap();
443//         return seg.ident == "Vec";
444//     }
445//     false
446// }
447
448fn is_option(ty: &Type) -> bool {
449    if let Type::Path(TypePath { path, .. }) = ty {
450        let seg = path.segments.iter().next().unwrap();
451        if seg.ident != "Option" {
452            unimplemented!()
453        }
454        return true;
455    }
456    false
457}