nitka_proc/
lib.rs

1extern crate proc_macro;
2
3use std::str::FromStr;
4
5use proc_macro::TokenStream;
6use quote::{quote, ToTokens};
7use syn::{
8    __private::TokenStream2, parse_macro_input, parse_quote, parse_str, FnArg, Ident, ItemFn, ItemTrait, Pat,
9    ReturnType, TraitItem, TraitItemFn,
10};
11
12/// Create interface trait suitable for usage in integration tests
13#[proc_macro_attribute]
14pub fn make_integration_version(_args: TokenStream, stream: TokenStream) -> TokenStream {
15    let mut input = parse_macro_input!(stream as ItemTrait);
16
17    let crate_name = std::env::var("CARGO_PKG_NAME")
18        .unwrap_or_else(|_| panic!("let crate_name = std::env::var(\"CARGO_PKG_NAME\")."));
19
20    let Some(index) = crate_name.find("-model") else {
21        panic!("Some(index) = crate_name.find(");
22    };
23
24    let contract_name = to_camel_case(&crate_name[..index]);
25
26    let trait_name = &input.ident;
27
28    let contract = Ident::new(&format!("{contract_name}Contract"), trait_name.span());
29
30    let integration_trait_name = Ident::new(&format!("{trait_name}Integration"), trait_name.span());
31
32    let mut integration_trait_methods: Vec<TraitItemFn> = input
33        .items
34        .iter_mut()
35        .filter_map(|item| {
36            if let TraitItem::Fn(method) = item {
37                let async_method = convert_method_to_integration_trait(method);
38                Some(async_method)
39            } else {
40                None
41            }
42        })
43        .collect();
44
45    let implementation_methods: Vec<ItemFn> = integration_trait_methods
46        .iter_mut()
47        .map(convert_method_to_implementation)
48        .collect();
49
50    quote! {
51
52        #[cfg(not(feature = "integration-api"))]
53        #input
54
55        #[cfg(feature = "integration-api")]
56        pub trait #integration_trait_name {
57            #(#integration_trait_methods)*
58        }
59
60        #[cfg(feature = "integration-api")]
61        impl #integration_trait_name for #contract<'_> {
62            #(#implementation_methods)*
63        }
64    }
65    .into()
66}
67
68fn convert_method_to_implementation(trait_method: &mut TraitItemFn) -> ItemFn {
69    let fn_name = trait_method.sig.ident.clone();
70    let fn_args = trait_method.sig.inputs.clone();
71    let fn_ret = trait_method.sig.output.clone();
72
73    let fn_name_str = TokenStream2::from_str(&format!("\"{fn_name}\"")).expect("Failed to extract method name");
74
75    let call_args = if fn_args.len() > 1 {
76        let mut args_quote = quote!();
77
78        for arg in fn_args.iter().skip(1) {
79            let FnArg::Typed(arg) = arg else {
80                panic!("FnArg::Typed(arg) = arg");
81            };
82
83            let Pat::Ident(pat_ident) = &*arg.pat else {
84                panic!("Pat::Ident(ident) = &arg.pat");
85            };
86
87            let ident = &pat_ident.ident;
88
89            let string_ident = TokenStream2::from_str(&format!("\"{ident}\"")).expect("Failed to extract method name");
90
91            args_quote = quote! {
92                #args_quote
93                #string_ident : #ident,
94            }
95        }
96
97        quote! {
98            .args_json(near_sdk::serde_json::json!({
99                #args_quote
100            })).unwrap()
101        }
102    } else {
103        quote!()
104    };
105
106    let deposit = if let Some(attr) = trait_method.attrs.first() {
107        let attr = attr.path().to_token_stream().to_string();
108
109        match attr.as_str() {
110            "deposit_one_yocto" => quote! {
111                .deposit(near_workspaces::types::NearToken::from_yoctonear(1))
112            },
113            "deposit_yocto" => {
114                let mut attr = trait_method.attrs.first().unwrap().to_token_stream().to_string();
115
116                attr.pop().unwrap();
117
118                let index = attr.find("= ").unwrap() + 2;
119
120                let attr = &attr[index..];
121
122                let deposit_value = TokenStream2::from_str(attr).unwrap();
123
124                quote! {
125                    .deposit(near_workspaces::types::NearToken::from_yoctonear(#deposit_value))
126                }
127            }
128            _ => quote!(),
129        }
130    } else {
131        quote!()
132    };
133
134    trait_method.attrs = vec![];
135
136    let result: ItemFn = parse_quote!(
137        fn #fn_name(#fn_args) #fn_ret {
138            nitka::integration_contract::make_call(self.contract, #fn_name_str) #deposit #call_args
139        }
140    );
141
142    result.clone()
143}
144
145fn convert_method_to_integration_trait(trait_method: &mut TraitItemFn) -> TraitItemFn {
146    let mut method = trait_method.clone();
147
148    let mut ret = if matches!(method.sig.output, ReturnType::Default) {
149        "()".to_string()
150    } else {
151        let ret = method.sig.output.to_token_stream().to_string();
152        let ret = ret.strip_prefix("-> ").unwrap();
153        ret.to_string()
154    };
155
156    let self_arg: FnArg = parse_str("&self").unwrap();
157
158    if ret == "Self" {
159        method.sig.inputs.insert(0, self_arg);
160        ret = "()".to_string();
161    } else {
162        method.sig.inputs[0] = self_arg;
163    }
164    
165    if ret.starts_with(":: near_sdk :: PromiseOrValue <") {
166        let start = ret.find('<').unwrap();
167        let end = ret.find('>').unwrap();
168
169        ret = ret[start + 1..end].to_string();
170    }
171
172    let ret: Result<ReturnType, _> = parse_str(&format!("-> nitka::ContractCall<{ret}>"));
173
174    method.sig.output = ret.unwrap();
175
176    if let Some(attr) = method.attrs.first() {
177        let attr = attr.path().to_token_stream().to_string();
178        trait_method.attrs = vec![];
179        if attr.as_str() == "update" {
180            method.sig.inputs.push(parse_str("code: Vec<u8>").unwrap());
181        }
182    }
183
184    method
185}
186
/// Converts a dash-separated name into upper camel case.
///
/// The input is split on `-`; the first character of every non-empty segment is upper-cased and the
/// rest of the segment is appended unchanged. Empty segments contribute nothing.
fn to_camel_case(input: &str) -> String {
    let mut camel = String::new();

    for segment in input.split('-') {
        let mut rest = segment.chars();

        if let Some(head) = rest.next() {
            // Upper-casing a single character may yield more than one character.
            camel.extend(head.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }

    camel
}