1use proc_macro::TokenStream;
48use proc_macro2::TokenStream as TokenStream2;
49use quote::{format_ident, quote};
50use syn::{
51 FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, TraitItemFn, Type,
52 parse::{Parse, ParseStream},
53 parse_macro_input,
54 spanned::Spanned,
55};
56
/// Encoding used for a contract method's arguments and result.
///
/// Selected for a whole trait with `#[contract(json)]` / `#[contract(borsh)]`
/// and overridable on a single method with `#[json]` / `#[borsh]`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
enum SerializationFormat {
    /// JSON encoding; the fallback when `#[contract]` has no argument.
    #[default]
    Json,
    /// Borsh binary encoding.
    Borsh,
}
64
65#[derive(Debug, Default)]
67struct ContractArgs {
68 format: SerializationFormat,
69}
70
71impl Parse for ContractArgs {
72 fn parse(input: ParseStream) -> syn::Result<Self> {
73 if input.is_empty() {
74 return Ok(Self::default());
75 }
76
77 let ident: Ident = input.parse()?;
78 let format = match ident.to_string().as_str() {
79 "json" => SerializationFormat::Json,
80 "borsh" => SerializationFormat::Borsh,
81 other => {
82 return Err(syn::Error::new(
83 ident.span(),
84 format!("unknown format '{}', expected 'json' or 'borsh'", other),
85 ));
86 }
87 };
88
89 Ok(Self { format })
90 }
91}
92
/// Parsed contents of a per-method `#[call(...)]` attribute.
#[derive(Debug, Default)]
struct CallArgs {
    /// `true` when written as `#[call(payable)]`.
    payable: bool,
}
98
99impl Parse for CallArgs {
100 fn parse(input: ParseStream) -> syn::Result<Self> {
101 if input.is_empty() {
102 return Ok(Self::default());
103 }
104
105 let ident: Ident = input.parse()?;
106 if ident != "payable" {
107 return Err(syn::Error::new(
108 ident.span(),
109 format!("unknown call option '{}', expected 'payable'", ident),
110 ));
111 }
112
113 Ok(Self { payable: true })
114 }
115}
116
117#[derive(Debug)]
119struct MethodInfo {
120 name: Ident,
121 is_view: bool,
122 #[allow(dead_code)] is_call: bool,
124 #[allow(dead_code)] is_payable: bool,
126 format_override: Option<SerializationFormat>,
128 arg_name: Option<Ident>,
129 arg_type: Option<Type>,
130 return_type: Option<Type>,
131}
132
133fn parse_method(method: &TraitItemFn) -> syn::Result<MethodInfo> {
135 let name = method.sig.ident.clone();
136
137 let receiver = method.sig.receiver();
139 let (is_view, is_mut) = match receiver {
140 Some(recv) => {
141 if recv.reference.is_some() {
142 (recv.mutability.is_none(), recv.mutability.is_some())
143 } else {
144 return Err(syn::Error::new(
145 recv.span(),
146 "contract methods must take &self or &mut self",
147 ));
148 }
149 }
150 None => {
151 return Err(syn::Error::new(
152 method.sig.span(),
153 "contract methods must have a receiver (&self or &mut self)",
154 ));
155 }
156 };
157
158 let call_attr = method
160 .attrs
161 .iter()
162 .find(|attr| attr.path().is_ident("call"));
163
164 let (is_call, is_payable) = match call_attr {
165 Some(attr) => {
166 let args: CallArgs = if attr.meta.require_path_only().is_ok() {
167 CallArgs::default()
168 } else {
169 attr.parse_args()?
170 };
171 (true, args.payable)
172 }
173 None => (false, false),
174 };
175
176 let format_override = if method.attrs.iter().any(|attr| attr.path().is_ident("json")) {
178 Some(SerializationFormat::Json)
179 } else if method
180 .attrs
181 .iter()
182 .any(|attr| attr.path().is_ident("borsh"))
183 {
184 Some(SerializationFormat::Borsh)
185 } else {
186 None
187 };
188
189 if is_view && is_call {
191 return Err(syn::Error::new(
192 method.sig.span(),
193 "view methods (&self) should not have #[call] attribute",
194 ));
195 }
196
197 if is_mut && !is_call {
199 return Err(syn::Error::new(
200 method.sig.span(),
201 "call methods (&mut self) must have #[call] attribute",
202 ));
203 }
204
205 let mut arg_name = None;
207 let mut arg_type = None;
208 let mut arg_count = 0;
209
210 for arg in &method.sig.inputs {
211 if let FnArg::Typed(pat_type) = arg {
212 arg_count += 1;
213 if arg_count > 1 {
214 return Err(syn::Error::new(
215 pat_type.span(),
216 "contract methods can have at most one argument (use a struct for multiple parameters)",
217 ));
218 }
219
220 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
222 arg_name = Some(pat_ident.ident.clone());
223 }
224 arg_type = Some((*pat_type.ty).clone());
225 }
226 }
227
228 let return_type = match &method.sig.output {
230 ReturnType::Default => None,
231 ReturnType::Type(_, ty) => Some((**ty).clone()),
232 };
233
234 Ok(MethodInfo {
235 name,
236 is_view,
237 is_call,
238 is_payable,
239 format_override,
240 arg_name,
241 arg_type,
242 return_type,
243 })
244}
245
246fn generate_view_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
248 let method_name = &method.name;
249 let method_name_str = method_name.to_string();
250
251 let format = method.format_override.unwrap_or(contract_format);
253
254 let return_type = method
255 .return_type
256 .as_ref()
257 .map(|t| quote! { #t })
258 .unwrap_or_else(|| quote! { () });
259
260 let borsh_suffix = match format {
262 SerializationFormat::Json => quote! {},
263 SerializationFormat::Borsh => quote! { .borsh() },
264 };
265
266 let view_return_type = match format {
268 SerializationFormat::Json => quote! { near_kit::ViewCall<#return_type> },
269 SerializationFormat::Borsh => quote! { near_kit::ViewCallBorsh<#return_type> },
270 };
271
272 if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
273 let args_method = match format {
275 SerializationFormat::Json => quote! { .args(#arg_name) },
276 SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
277 };
278
279 quote! {
280 pub fn #method_name(&self, #arg_name: #arg_type) -> #view_return_type {
281 self.near.view::<#return_type>(&self.contract_id, #method_name_str)
282 #args_method
283 #borsh_suffix
284 }
285 }
286 } else {
287 match format {
289 SerializationFormat::Json => {
290 quote! {
291 pub fn #method_name(&self) -> #view_return_type {
292 self.near.view::<#return_type>(&self.contract_id, #method_name_str)
293 .args(serde_json::json!({}))
294 }
295 }
296 }
297 SerializationFormat::Borsh => {
298 quote! {
299 pub fn #method_name(&self) -> #view_return_type {
300 self.near.view::<#return_type>(&self.contract_id, #method_name_str)
301 .borsh()
302 }
303 }
304 }
305 }
306 }
307}
308
309fn generate_call_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
311 let method_name = &method.name;
312 let method_name_str = method_name.to_string();
313
314 let format = method.format_override.unwrap_or(contract_format);
316
317 if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
318 let args_method = match format {
320 SerializationFormat::Json => quote! { .args(#arg_name) },
321 SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
322 };
323
324 quote! {
325 pub fn #method_name(&self, #arg_name: #arg_type) -> near_kit::CallBuilder {
326 self.near.call(&self.contract_id, #method_name_str)
327 #args_method
328 }
329 }
330 } else {
331 match format {
333 SerializationFormat::Json => {
334 quote! {
335 pub fn #method_name(&self) -> near_kit::CallBuilder {
336 self.near.call(&self.contract_id, #method_name_str)
337 .args(serde_json::json!({}))
338 }
339 }
340 }
341 SerializationFormat::Borsh => {
342 quote! {
343 pub fn #method_name(&self) -> near_kit::CallBuilder {
344 self.near.call(&self.contract_id, #method_name_str)
345 }
346 }
347 }
348 }
349 }
350}
351
352fn strip_internal_attrs(method: &TraitItemFn) -> TraitItemFn {
354 let mut method = method.clone();
355 method.attrs.retain(|attr| {
356 !attr.path().is_ident("call")
357 && !attr.path().is_ident("json")
358 && !attr.path().is_ident("borsh")
359 });
360 method
361}
362
363#[proc_macro_attribute]
365pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream {
366 let args = parse_macro_input!(attr as ContractArgs);
367 let input = parse_macro_input!(item as ItemTrait);
368
369 match contract_impl(args, input) {
370 Ok(tokens) => tokens.into(),
371 Err(err) => err.to_compile_error().into(),
372 }
373}
374
375fn contract_impl(args: ContractArgs, input: ItemTrait) -> syn::Result<TokenStream2> {
376 let trait_name = &input.ident;
377 let client_name = format_ident!("{}Client", trait_name);
378 let vis = &input.vis;
379
380 let mut methods = Vec::new();
382 for item in &input.items {
383 if let TraitItem::Fn(method) = item {
384 methods.push(parse_method(method)?);
385 }
386 }
387
388 let client_methods: Vec<TokenStream2> = methods
390 .iter()
391 .map(|m| {
392 if m.is_view {
393 generate_view_method(m, args.format)
394 } else {
395 generate_call_method(m, args.format)
396 }
397 })
398 .collect();
399
400 let cleaned_items: Vec<TraitItem> = input
402 .items
403 .iter()
404 .map(|item| {
405 if let TraitItem::Fn(method) = item {
406 TraitItem::Fn(strip_internal_attrs(method))
407 } else {
408 item.clone()
409 }
410 })
411 .collect();
412
413 let trait_attrs = &input.attrs;
414 let trait_supertraits = &input.supertraits;
415 let trait_generics = &input.generics;
416
417 let expanded = quote! {
419 #[allow(dead_code)]
423 #(#trait_attrs)*
424 #vis trait #trait_name #trait_generics : #trait_supertraits {
425 #(#cleaned_items)*
426 }
427
428 #vis struct #client_name<'a> {
430 near: &'a near_kit::Near,
431 contract_id: near_kit::AccountId,
432 }
433
434 impl<'a> #client_name<'a> {
435 pub fn new(near: &'a near_kit::Near, contract_id: near_kit::AccountId) -> Self {
437 Self { near, contract_id }
438 }
439
440 pub fn contract_id(&self) -> &near_kit::AccountId {
442 &self.contract_id
443 }
444
445 #(#client_methods)*
446 }
447
448 impl<'a> near_kit::contract::ContractClient<'a> for #client_name<'a> {
450 fn new(near: &'a near_kit::Near, contract_id: near_kit::AccountId) -> Self {
451 Self { near, contract_id }
452 }
453 }
454
455 impl near_kit::Contract for dyn #trait_name {
457 type Client<'a> = #client_name<'a>;
458 }
459 };
460
461 Ok(expanded)
462}
463
/// Marks a `&mut self` method of a `#[contract]` trait as a call method.
///
/// Written as `#[call]` or `#[call(payable)]`. `#[contract]` reads this attribute
/// and strips it from the re-emitted trait. When expanded on its own, this macro
/// returns the annotated item unchanged and ignores its arguments.
#[proc_macro_attribute]
pub fn call(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Intentionally a pass-through: all interpretation happens in `#[contract]`.
    item
}
482
/// Selects JSON encoding for one method of a `#[contract]` trait, overriding the
/// contract-wide format.
///
/// `#[contract]` reads this attribute and strips it from the re-emitted trait.
/// When expanded on its own, this macro returns the annotated item unchanged.
#[proc_macro_attribute]
pub fn json(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Intentionally a pass-through: all interpretation happens in `#[contract]`.
    item
}
501
/// Selects Borsh encoding for one method of a `#[contract]` trait, overriding the
/// contract-wide format.
///
/// `#[contract]` reads this attribute and strips it from the re-emitted trait.
/// When expanded on its own, this macro returns the annotated item unchanged.
#[proc_macro_attribute]
pub fn borsh(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Intentionally a pass-through: all interpretation happens in `#[contract]`.
    item
}