1use addr::Addresses;
18use essential_types::{contract::Contract, predicate::Program, PredicateAddress};
19use pint_abi_types::{ContractABI, PredicateABI, StorageVarABI, TupleField, TypeABI, UnionVariant};
20use pint_abi_visit::Nesting;
21use proc_macro::TokenStream;
22use proc_macro2::Span;
23use quote::ToTokens;
24use std::collections::BTreeSet;
25use syn::parse_macro_input;
26
27mod addr;
28mod args;
29mod array;
30mod keys;
31mod macro_args;
32mod map;
33mod mutations;
34mod tuple;
35mod unions;
36mod utils;
37
/// Name under which the root (unnamed) predicate appears in the ABI.
///
/// A predicate with this name is emitted as a module called `root` instead of
/// under its own name, and is excluded from the named-predicate modules.
const ROOT_MOD_NAME: &str = "";
40
/// A restricted subset of [`TypeABI`]: it has no tuple, array or map variants.
///
/// NOTE(review): presumably these are the types allowed as a single key
/// element; the enum is not used elsewhere in this file — confirm against its
/// users.
enum SingleKeyTy {
    /// Represented as Rust `bool`.
    Bool,
    /// Represented as Rust `i64`.
    Int,
    /// Represented as Rust `f64`.
    Real,
    /// Represented as Rust `String`.
    String,
    /// Represented as Rust `[i64; 4]`.
    B256,
    /// An optional value, represented as `Option<T>` of the inner type.
    Optional(Box<TypeABI>),
    /// A union type, identified by its name.
    Union(String /* name */),
}
51
52impl SingleKeyTy {
53    fn syn_ty(&self, mod_level: usize) -> syn::Type {
55        match self {
56            SingleKeyTy::Bool => syn::parse_quote!(bool),
57            SingleKeyTy::Int => syn::parse_quote!(i64),
58            SingleKeyTy::Real => syn::parse_quote!(f64),
59            SingleKeyTy::String => syn::parse_quote!(String),
60            SingleKeyTy::B256 => syn::parse_quote!([i64; 4]),
61            SingleKeyTy::Optional(ty) => ty_from_optional(ty, mod_level),
62            SingleKeyTy::Union(name) => unions::ty_from_union(name, mod_level),
63        }
64    }
65}
66
67fn fields_from_tuple_fields(fields: &[TupleField], mod_level: usize) -> Vec<syn::Field> {
69    fields
70        .iter()
71        .map(|TupleField { name, ty }| {
72            let _name = name;
74            let ty = ty_from_pint_ty(ty, mod_level);
75            syn::parse_quote!(#ty)
76        })
77        .collect()
78}
79
80fn ty_from_tuple(tuple: &[TupleField], mod_level: usize) -> syn::Type {
82    let fields = fields_from_tuple_fields(tuple, mod_level);
83    syn::parse_quote! {
84        ( #( #fields ),* )
85    }
86}
87
88fn ty_from_array(ty: &TypeABI, size: i64, mod_level: usize) -> syn::Type {
90    let syn_ty = ty_from_pint_ty(ty, mod_level);
91    let len = usize::try_from(size).expect("array size out of range of `usize`");
92    syn::parse_quote! {
93        [#syn_ty; #len]
94    }
95}
96
97fn ty_from_optional(ty: &TypeABI, mod_level: usize) -> syn::Type {
99    let syn_ty = ty_from_pint_ty(ty, mod_level);
100    syn::parse_quote! {
101        Option<#syn_ty>
102    }
103}
104
105fn ty_from_pint_ty(ty: &TypeABI, mod_level: usize) -> syn::Type {
107    match ty {
108        TypeABI::Bool => syn::parse_quote!(bool),
109        TypeABI::Int => syn::parse_quote!(i64),
110        TypeABI::Real => syn::parse_quote!(f64),
111        TypeABI::String => syn::parse_quote!(String),
112        TypeABI::B256 => syn::parse_quote!([i64; 4]),
113        TypeABI::Optional(ty) => ty_from_optional(ty, mod_level),
114        TypeABI::Union { name, .. } => unions::ty_from_union(name, mod_level),
115        TypeABI::Tuple(tuple) => ty_from_tuple(tuple, mod_level),
116        TypeABI::Array { ty, size } => ty_from_array(ty, *size, mod_level),
117        TypeABI::Map { .. } => unreachable!("Maps are not allowed as non-storage types"),
118    }
119}
120
/// Strips every leading `::` separator from a pint path, e.g. `::foo::bar`
/// becomes `foo::bar` and `::::foo` becomes `foo`.
fn strip_colons_prefix(name: &str) -> &str {
    let mut rest = name;
    while let Some(stripped) = rest.strip_prefix("::") {
        rest = stripped;
    }
    rest
}
126
127fn field_name_from_var_name(name: &str) -> String {
131    strip_colons_prefix(name)
132        .replace(['.', '@'], "_")
133        .replace("::", "_")
134}
135
136fn items_from_predicate(
138    predicate: &PredicateABI,
139    addr: Option<&PredicateAddress>,
140) -> Vec<syn::Item> {
141    let mut items = vec![];
142    if let Some(addr) = addr {
143        items.push(addr::predicate_const(&addr.contract, &addr.predicate).into());
144    }
145    if !predicate.params.is_empty() {
146        items.extend(args::items(&predicate.params));
147    }
148    items
149}
150
151fn mod_from_predicate(
154    name: &str,
155    predicate: &PredicateABI,
156    addr: Option<PredicateAddress>,
157) -> syn::ItemMod {
158    let doc_str = format!("Items for the `{name}` predicate.");
159    let ident = syn::Ident::new(name, Span::call_site());
160    let items = items_from_predicate(predicate, addr.as_ref());
161    syn::parse_quote! {
162        #[allow(non_snake_case)]
163        #[doc = #doc_str]
164        pub mod #ident {
165            #(
166                #items
167            )*
168        }
169    }
170}
171
172fn is_predicate_empty(pred: &PredicateABI) -> bool {
174    pred.params.is_empty()
175}
176
177fn predicates_with_addrs<'a>(
179    predicates: &'a [PredicateABI],
180    addrs: Option<&'a Addresses>,
181) -> impl 'a + Iterator<Item = (&'a PredicateABI, Option<PredicateAddress>)> {
182    predicates.iter().enumerate().map(move |(ix, predicate)| {
183        let addr = addrs.map(|addrs| PredicateAddress {
184            contract: addrs.contract.clone(),
185            predicate: addrs.predicates[ix].clone(),
186        });
187        (predicate, addr)
188    })
189}
190
191fn mods_from_named_predicates(
193    predicates: &[PredicateABI],
194    addrs: Option<&Addresses>,
195) -> Vec<syn::ItemMod> {
196    predicates_with_addrs(predicates, addrs)
197        .filter(|(predicate, addr)| !is_predicate_empty(predicate) || addr.is_some())
198        .filter(|(predicate, _)| predicate.name != ROOT_MOD_NAME)
199        .map(|(predicate, addr)| {
200            let name = strip_colons_prefix(&predicate.name);
201            mod_from_predicate(name, predicate, addr)
202        })
203        .collect()
204}
205
206fn find_root_predicate<'a>(
208    predicates: &'a [PredicateABI],
209    addrs: Option<&'a Addresses>,
210) -> Option<(&'a PredicateABI, Option<PredicateAddress>)> {
211    predicates_with_addrs(predicates, addrs).find(|(predicate, _)| predicate.name == ROOT_MOD_NAME)
212}
213
214fn items_from_predicates(predicates: &[PredicateABI], addrs: Option<&Addresses>) -> Vec<syn::Item> {
218    let mut items = vec![];
219    if let Some((root_pred, addr)) = find_root_predicate(predicates, addrs) {
221        let name = "root";
223        items.push(mod_from_predicate(name, root_pred, addr).into());
224    }
225    items.extend(
227        mods_from_named_predicates(predicates, addrs)
228            .into_iter()
229            .map(syn::Item::from),
230    );
231    items
232}
233
234fn nesting_expr(nesting: &[Nesting]) -> syn::ExprArray {
236    let elems = nesting
237        .iter()
238        .map(|n| {
239            let expr: syn::Expr = match n {
240                Nesting::Var { ix } => {
241                    syn::parse_quote!(pint_abi::key::Nesting::Var { ix: #ix })
242                }
243                Nesting::TupleField { ix: _, flat_ix } => {
244                    syn::parse_quote!(pint_abi::key::Nesting::TupleField { flat_ix: #flat_ix })
245                }
246                Nesting::MapEntry { key_size: _ } => {
247                    syn::parse_quote!(pint_abi::key::Nesting::MapEntry)
248                }
249                Nesting::ArrayElem { elem_len, .. } => {
250                    syn::parse_quote!(pint_abi::key::Nesting::ArrayElem { elem_len: #elem_len })
251                }
252            };
253            expr
254        })
255        .collect();
256    syn::ExprArray {
257        attrs: vec![],
258        bracket_token: Default::default(),
259        elems,
260    }
261}
262
263fn construct_key_expr() -> syn::Expr {
266    syn::parse_quote! {
267        pint_abi::key::construct(&nesting[..], &self.key_elems[..])
268    }
269}
270
271fn nesting_ty_str<'a>(nesting: impl IntoIterator<Item = &'a Nesting>) -> String {
278    fn elem_str(nesting: &Nesting) -> String {
279        match nesting {
280            Nesting::Var { ix } => ix.to_string(),
281            Nesting::TupleField { ix, .. } => ix.to_string(),
282            Nesting::MapEntry { .. } => "MapEntry".to_string(),
283            Nesting::ArrayElem { .. } => "ArrayElem".to_string(),
284        }
285    }
286    let mut iter = nesting.into_iter();
287    let mut s = elem_str(iter.next().expect("nesting must contain at least one item"));
288    for n in iter {
289        use std::fmt::Write;
290        write!(&mut s, "_{}", elem_str(n)).expect("failed to fmt nesting ty str");
291    }
292    s
293}
294
295fn nesting_key_doc_str(nesting: &[Nesting]) -> String {
299    use core::fmt::Write;
300    let partial_key = pint_abi_visit::partial_key_from_nesting(nesting);
301    let mut s = "[".to_string();
302    let mut opts = partial_key.iter();
303    fn write_opt(s: &mut String, opt: &Option<i64>) {
304        match opt {
305            None => write!(s, "_"),
306            Some(u) => write!(s, "{u}"),
307        }
308        .expect("failed to write key element to string")
309    }
310    if let Some(opt) = opts.next() {
311        write_opt(&mut s, opt);
312        for opt in opts {
313            write!(&mut s, ", ").unwrap();
314            write_opt(&mut s, opt);
315        }
316    }
317    write!(&mut s, "]").unwrap();
318    s
319}
320
321fn items_from_keyed_vars(vars: &[StorageVarABI]) -> Vec<syn::Item> {
325    let mut items = vec![];
326
327    items.push(mutations::module(vars).into());
329    items.push(syn::parse_quote! {
330        #[doc(inline)]
331        pub use mutations::{mutations, Mutations};
332    });
333
334    items.push(keys::module(vars).into());
336    items.push(syn::parse_quote! {
337        #[doc(inline)]
338        pub use keys::{keys, Keys};
339    });
340
341    items
342}
343
344fn mod_from_keyed_vars(mod_name: &str, vars: &[StorageVarABI]) -> syn::ItemMod {
348    let items = items_from_keyed_vars(vars);
349    let mod_ident = syn::Ident::new(mod_name, Span::call_site());
350    syn::parse_quote! {
351        pub mod #mod_ident {
352            #(
371                #items
372            )*
373        }
374    }
375}
376
377fn items_from_abi_and_addrs(abi: &ContractABI, addrs: Option<&Addresses>) -> Vec<syn::Item> {
379    let mut items = vec![];
380
381    let mut unions: BTreeSet<(Vec<String>, Vec<UnionVariant>)> = BTreeSet::new();
383    abi.storage
384        .iter()
385        .map(|var| var.ty.clone())
386        .chain(
387            abi.predicates
388                .iter()
389                .flat_map(|predicate| predicate.params.iter().map(|param| param.ty.clone())),
390        )
391        .for_each(|ty| unions::collect_unions(&ty, &mut unions));
392
393    items.extend(unions::items_from_unions(&unions));
394
395    items.extend(items_from_predicates(&abi.predicates, addrs));
396    if let Some(addrs) = addrs {
397        items.push(addr::contract_const(&addrs.contract).into());
398    }
399    if !abi.storage.is_empty() {
400        items.push(mod_from_keyed_vars("storage", &abi.storage).into());
401    }
402    items
403}
404
405fn tokens_from_abi_and_addrs(abi: &ContractABI, addrs: Option<&Addresses>) -> TokenStream {
407    let items = items_from_abi_and_addrs(abi, addrs);
408    items
409        .into_iter()
410        .map(|item| TokenStream::from(item.into_token_stream()))
411        .collect()
412}
413
/// Resolves a relative `path` against the crate's manifest directory.
/// Non-relative paths are returned unchanged.
///
/// # Panics
///
/// Panics if `path` is relative and `CARGO_MANIFEST_DIR` is not set.
fn resolve_path(path: &std::path::Path) -> std::path::PathBuf {
    if !path.is_relative() {
        return path.to_path_buf();
    }
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
        .expect("`CARGO_MANIFEST_DIR` not set, but required for relative path expansion");
    std::path::PathBuf::from(manifest_dir).join(path)
}
426
427#[proc_macro]
429pub fn from_str(input: TokenStream) -> TokenStream {
430    let input_lit_str = parse_macro_input!(input as syn::LitStr);
431    let string = input_lit_str.value();
432    let abi: ContractABI = serde_json::from_str(&string)
433        .expect("failed to deserialize str from JSON to `ContractABI`");
434    tokens_from_abi_and_addrs(&abi, None)
435}
436
437fn read_from_json_file<T>(path: &std::path::Path) -> Result<T, serde_json::Error>
439where
440    T: for<'de> serde::Deserialize<'de>,
441{
442    let file = std::fs::File::open(path).unwrap_or_else(|err| {
443        panic!("failed to open {path:?}: {err}");
444    });
445    let reader = std::io::BufReader::new(file);
446    serde_json::from_reader(reader)
447}
448
449#[proc_macro]
477pub fn from_file(input: TokenStream) -> TokenStream {
478    let args = parse_macro_input!(input as macro_args::FromFile);
479
480    let abi_path = resolve_path(args.abi.value().as_ref());
482    let abi: ContractABI = read_from_json_file(&abi_path).unwrap_or_else(|err| {
483        panic!("failed to deserialize {abi_path:?} from JSON to `ContractABI`: {err}");
484    });
485
486    let contract: Option<Contract> = args
488        .contract
489        .map(|contract| {
490            let contract_path = resolve_path(contract.value().as_ref());
491            read_from_json_file::<(Contract, Vec<Program>)>(&contract_path).unwrap_or_else(|err| {
492                panic!("failed to deserialize {contract_path:?} from JSON to `Contract`: {err}");
493            })
494        })
495        .map(|(contract, _)| contract);
496
497    let addrs = contract.as_ref().map(Addresses::from);
499
500    tokens_from_abi_and_addrs(&abi, addrs.as_ref())
501}