Skip to main content

near_kit_macros/
lib.rs

1//! Proc macros for near-kit typed contract interfaces.
2//!
3//! This crate provides the `#[near_kit::contract]` attribute macro for defining
4//! type-safe contract interfaces.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use near_kit::*;
10//! use serde::Serialize;
11//!
12//! #[near_kit::contract]
13//! pub trait Counter {
14//!     fn get_count(&self) -> u64;
15//!     
16//!     #[call]
17//!     fn increment(&mut self);
18//!     
19//!     #[call]
20//!     fn add(&mut self, args: AddArgs);
21//! }
22//!
23//! #[derive(Serialize)]
24//! pub struct AddArgs {
25//!     pub value: u64,
26//! }
27//! ```
28//!
29//! # Per-Method Format Override
30//!
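//! The contract-wide default format is JSON; write `#[near_kit::contract(borsh)]`
//! to make Borsh the default instead. Individual methods can override the
//! contract-wide format with `#[json]` or `#[borsh]`: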
32//!
33//! ```ignore
34//! #[near_kit::contract]  // Default: JSON
35//! pub trait MixedContract {
36//!     fn get_json_data(&self) -> JsonData;  // Uses JSON (default)
37//!     
38//!     #[borsh]  // Override: use Borsh for this method
39//!     fn get_binary_state(&self) -> BinaryState;
40//!     
41//!     #[call]
42//!     #[borsh]  // Override: use Borsh for this call
43//!     fn set_binary_state(&mut self, args: BinaryArgs);
44//! }
45//! ```
46
47use proc_macro::TokenStream;
48use proc_macro2::TokenStream as TokenStream2;
49use quote::{format_ident, quote};
50use syn::{
51    FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, TraitItemFn, Type,
52    parse::{Parse, ParseStream},
53    parse_macro_input,
54    spanned::Spanned,
55};
56
/// Wire format used to encode a contract method's arguments and decode its result.
///
/// `Json` is the default, which is what `#[near_kit::contract]` uses when it is
/// given no format argument.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
enum SerializationFormat {
    /// JSON encoding (`.args(..)` / `.args_raw(b"{}")` in generated code).
    #[default]
    Json,
    /// Borsh binary encoding (`.args_borsh(..)` / `.borsh()` in generated code).
    Borsh,
}
64
65/// Arguments to the `#[contract]` attribute.
66#[derive(Debug, Default)]
67struct ContractArgs {
68    format: SerializationFormat,
69}
70
71impl Parse for ContractArgs {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        if input.is_empty() {
74            return Ok(Self::default());
75        }
76
77        let ident: Ident = input.parse()?;
78        let format = match ident.to_string().as_str() {
79            "json" => SerializationFormat::Json,
80            "borsh" => SerializationFormat::Borsh,
81            other => {
82                return Err(syn::Error::new(
83                    ident.span(),
84                    format!("unknown format '{}', expected 'json' or 'borsh'", other),
85                ));
86            }
87        };
88
89        Ok(Self { format })
90    }
91}
92
93/// Arguments to the `#[call]` attribute.
94#[derive(Debug, Default)]
95struct CallArgs {
96    payable: bool,
97}
98
99impl Parse for CallArgs {
100    fn parse(input: ParseStream) -> syn::Result<Self> {
101        if input.is_empty() {
102            return Ok(Self::default());
103        }
104
105        let ident: Ident = input.parse()?;
106        if ident != "payable" {
107            return Err(syn::Error::new(
108                ident.span(),
109                format!("unknown call option '{}', expected 'payable'", ident),
110            ));
111        }
112
113        Ok(Self { payable: true })
114    }
115}
116
117/// Information about a parsed method.
118#[derive(Debug)]
119struct MethodInfo {
120    name: Ident,
121    is_view: bool,
122    #[allow(dead_code)] // Reserved for future validation
123    is_call: bool,
124    #[allow(dead_code)] // Reserved for payable method handling
125    is_payable: bool,
126    /// Per-method format override (if specified via #[json] or #[borsh])
127    format_override: Option<SerializationFormat>,
128    arg_name: Option<Ident>,
129    arg_type: Option<Type>,
130    return_type: Option<Type>,
131}
132
133/// Parse a method from a trait item.
134fn parse_method(method: &TraitItemFn) -> syn::Result<MethodInfo> {
135    let name = method.sig.ident.clone();
136
137    // Check receiver type
138    let receiver = method.sig.receiver();
139    let (is_view, is_mut) = match receiver {
140        Some(recv) => {
141            if recv.reference.is_some() {
142                (recv.mutability.is_none(), recv.mutability.is_some())
143            } else {
144                return Err(syn::Error::new(
145                    recv.span(),
146                    "contract methods must take &self or &mut self",
147                ));
148            }
149        }
150        None => {
151            return Err(syn::Error::new(
152                method.sig.span(),
153                "contract methods must have a receiver (&self or &mut self)",
154            ));
155        }
156    };
157
158    // Check for #[call] attribute
159    let call_attr = method
160        .attrs
161        .iter()
162        .find(|attr| attr.path().is_ident("call"));
163
164    let (is_call, is_payable) = match call_attr {
165        Some(attr) => {
166            let args: CallArgs = if attr.meta.require_path_only().is_ok() {
167                CallArgs::default()
168            } else {
169                attr.parse_args()?
170            };
171            (true, args.payable)
172        }
173        None => (false, false),
174    };
175
176    // Check for #[json] or #[borsh] format override
177    let format_override = if method.attrs.iter().any(|attr| attr.path().is_ident("json")) {
178        Some(SerializationFormat::Json)
179    } else if method
180        .attrs
181        .iter()
182        .any(|attr| attr.path().is_ident("borsh"))
183    {
184        Some(SerializationFormat::Borsh)
185    } else {
186        None
187    };
188
189    // Validate: view methods should not have #[call]
190    if is_view && is_call {
191        return Err(syn::Error::new(
192            method.sig.span(),
193            "view methods (&self) should not have #[call] attribute",
194        ));
195    }
196
197    // Validate: call methods must have #[call]
198    if is_mut && !is_call {
199        return Err(syn::Error::new(
200            method.sig.span(),
201            "call methods (&mut self) must have #[call] attribute",
202        ));
203    }
204
205    // Parse arguments (excluding self)
206    let mut arg_name = None;
207    let mut arg_type = None;
208    let mut arg_count = 0;
209
210    for arg in &method.sig.inputs {
211        if let FnArg::Typed(pat_type) = arg {
212            arg_count += 1;
213            if arg_count > 1 {
214                return Err(syn::Error::new(
215                    pat_type.span(),
216                    "contract methods can have at most one argument (use a struct for multiple parameters)",
217                ));
218            }
219
220            // Extract argument name
221            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
222                arg_name = Some(pat_ident.ident.clone());
223            }
224            arg_type = Some((*pat_type.ty).clone());
225        }
226    }
227
228    // Parse return type
229    let return_type = match &method.sig.output {
230        ReturnType::Default => None,
231        ReturnType::Type(_, ty) => Some((**ty).clone()),
232    };
233
234    Ok(MethodInfo {
235        name,
236        is_view,
237        is_call,
238        is_payable,
239        format_override,
240        arg_name,
241        arg_type,
242        return_type,
243    })
244}
245
246/// Generate client method for a view function.
247fn generate_view_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
248    let method_name = &method.name;
249    let method_name_str = method_name.to_string();
250
251    // Use method override if present, otherwise contract default
252    let format = method.format_override.unwrap_or(contract_format);
253
254    let return_type = method
255        .return_type
256        .as_ref()
257        .map(|t| quote! { #t })
258        .unwrap_or_else(|| quote! { () });
259
260    // For Borsh format, chain .borsh() to switch to Borsh deserialization
261    let borsh_suffix = match format {
262        SerializationFormat::Json => quote! {},
263        SerializationFormat::Borsh => quote! { .borsh() },
264    };
265
266    // Return type differs based on format
267    let view_return_type = match format {
268        SerializationFormat::Json => quote! { near_kit::ViewCall<#return_type> },
269        SerializationFormat::Borsh => quote! { near_kit::ViewCallBorsh<#return_type> },
270    };
271
272    if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
273        // View with args
274        let args_method = match format {
275            SerializationFormat::Json => quote! { .args(#arg_name) },
276            SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
277        };
278
279        quote! {
280            pub fn #method_name(&self, #arg_name: #arg_type) -> #view_return_type {
281                self.near.view::<#return_type>(&self.contract_id, #method_name_str)
282                    #args_method
283                    #borsh_suffix
284            }
285        }
286    } else {
287        // View without args - for JSON, pass empty object; for Borsh, no args
288        match format {
289            SerializationFormat::Json => {
290                quote! {
291                    pub fn #method_name(&self) -> #view_return_type {
292                        self.near.view::<#return_type>(&self.contract_id, #method_name_str)
293                            .args_raw(b"{}".to_vec())
294                    }
295                }
296            }
297            SerializationFormat::Borsh => {
298                quote! {
299                    pub fn #method_name(&self) -> #view_return_type {
300                        self.near.view::<#return_type>(&self.contract_id, #method_name_str)
301                            .borsh()
302                    }
303                }
304            }
305        }
306    }
307}
308
309/// Generate client method for a call function.
310fn generate_call_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
311    let method_name = &method.name;
312    let method_name_str = method_name.to_string();
313
314    // Use method override if present, otherwise contract default
315    let format = method.format_override.unwrap_or(contract_format);
316
317    if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
318        // Call with args
319        let args_method = match format {
320            SerializationFormat::Json => quote! { .args(#arg_name) },
321            SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
322        };
323
324        quote! {
325            pub fn #method_name(&self, #arg_name: #arg_type) -> near_kit::CallBuilder {
326                self.near.call(&self.contract_id, #method_name_str)
327                    #args_method
328            }
329        }
330    } else {
331        // Call without args - for JSON, pass empty object; for Borsh, no args
332        match format {
333            SerializationFormat::Json => {
334                quote! {
335                    pub fn #method_name(&self) -> near_kit::CallBuilder {
336                        self.near.call(&self.contract_id, #method_name_str)
337                            .args_raw(b"{}".to_vec())
338                    }
339                }
340            }
341            SerializationFormat::Borsh => {
342                quote! {
343                    pub fn #method_name(&self) -> near_kit::CallBuilder {
344                        self.near.call(&self.contract_id, #method_name_str)
345                    }
346                }
347            }
348        }
349    }
350}
351
352/// Generate a static associated function on the contract struct that returns `FunctionCall`.
353///
354/// Only generated for call methods (not view methods). This enables composable
355/// transactions where typed contract calls can be mixed with other actions.
356fn generate_function_call_method(
357    method: &MethodInfo,
358    contract_format: SerializationFormat,
359) -> TokenStream2 {
360    let method_name = &method.name;
361    let method_name_str = method_name.to_string();
362
363    let format = method.format_override.unwrap_or(contract_format);
364
365    if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
366        let args_method = match format {
367            SerializationFormat::Json => quote! { .args(#arg_name) },
368            SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
369        };
370
371        quote! {
372            pub fn #method_name(#arg_name: #arg_type) -> near_kit::FunctionCall {
373                near_kit::FunctionCall::new(#method_name_str)
374                    #args_method
375            }
376        }
377    } else {
378        match format {
379            SerializationFormat::Json => {
380                // Use args_raw to avoid depending on serde_json in expanded code
381                quote! {
382                    pub fn #method_name() -> near_kit::FunctionCall {
383                        near_kit::FunctionCall::new(#method_name_str)
384                            .args_raw(b"{}".to_vec())
385                    }
386                }
387            }
388            SerializationFormat::Borsh => {
389                quote! {
390                    pub fn #method_name() -> near_kit::FunctionCall {
391                        near_kit::FunctionCall::new(#method_name_str)
392                    }
393                }
394            }
395        }
396    }
397}
398
399/// The main contract macro implementation.
400#[proc_macro_attribute]
401pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream {
402    let args = parse_macro_input!(attr as ContractArgs);
403    let input = parse_macro_input!(item as ItemTrait);
404
405    match contract_impl(args, input) {
406        Ok(tokens) => tokens.into(),
407        Err(err) => err.to_compile_error().into(),
408    }
409}
410
411fn contract_impl(args: ContractArgs, input: ItemTrait) -> syn::Result<TokenStream2> {
412    let trait_name = &input.ident;
413    let client_name = format_ident!("{}Client", trait_name);
414    let vis = &input.vis;
415
416    // Reject unsupported trait features
417    if let Some(unsafety) = input.unsafety {
418        return Err(syn::Error::new(
419            unsafety.span(),
420            "#[near_kit::contract] does not support unsafe traits",
421        ));
422    }
423    if let Some(auto_token) = input.auto_token {
424        return Err(syn::Error::new(
425            auto_token.span(),
426            "#[near_kit::contract] does not support auto traits",
427        ));
428    }
429    if !input.generics.params.is_empty() {
430        return Err(syn::Error::new(
431            input.generics.span(),
432            "#[near_kit::contract] does not support generic parameters",
433        ));
434    }
435    if let Some(where_clause) = &input.generics.where_clause {
436        return Err(syn::Error::new(
437            where_clause.span(),
438            "#[near_kit::contract] does not support where clauses",
439        ));
440    }
441    if !input.supertraits.is_empty() {
442        return Err(syn::Error::new(
443            input.supertraits.span(),
444            "#[near_kit::contract] does not support supertraits",
445        ));
446    }
447
448    // Parse all methods, reject non-method items
449    let mut methods = Vec::new();
450    for item in &input.items {
451        match item {
452            TraitItem::Fn(method) => {
453                if method.default.is_some() {
454                    return Err(syn::Error::new(
455                        method.sig.span(),
456                        "#[near_kit::contract] does not support default method implementations",
457                    ));
458                }
459                methods.push(parse_method(method)?);
460            }
461            other => {
462                return Err(syn::Error::new(
463                    other.span(),
464                    "#[near_kit::contract] only supports methods in traits",
465                ));
466            }
467        }
468    }
469
470    // Generate client methods (view → ViewCall, call → CallBuilder)
471    let client_methods: Vec<TokenStream2> = methods
472        .iter()
473        .map(|m| {
474            if m.is_view {
475                generate_view_method(m, args.format)
476            } else {
477                generate_call_method(m, args.format)
478            }
479        })
480        .collect();
481
482    // Generate FunctionCall constructors for call methods only
483    let function_call_methods: Vec<TokenStream2> = methods
484        .iter()
485        .filter(|m| !m.is_view)
486        .map(|m| generate_function_call_method(m, args.format))
487        .collect();
488
489    // Propagate trait-level attributes (doc comments, #[cfg], etc.) to the struct
490    let trait_attrs = &input.attrs;
491
492    // Build the output
493    let expanded = quote! {
494        #(#trait_attrs)*
495        #vis struct #trait_name;
496
497        impl #trait_name {
498            #(#function_call_methods)*
499        }
500
501        // Generated client struct for the simple (non-composed) case.
502        #vis struct #client_name {
503            near: near_kit::Near,
504            contract_id: near_kit::AccountId,
505        }
506
507        impl #client_name {
508            /// Create a new contract client.
509            pub fn new(near: near_kit::Near, contract_id: near_kit::AccountId) -> Self {
510                Self { near, contract_id }
511            }
512
513            /// Get the contract account ID.
514            pub fn contract_id(&self) -> &near_kit::AccountId {
515                &self.contract_id
516            }
517
518            /// Return a new client that uses the given signer for transactions.
519            pub fn with_signer(&self, signer: impl near_kit::Signer + 'static) -> Self {
520                Self {
521                    near: self.near.with_signer(signer),
522                    contract_id: self.contract_id.clone(),
523                }
524            }
525
526            #(#client_methods)*
527        }
528
529        // Implement ContractClient trait for construction via near.contract::<T>()
530        impl near_kit::contract::ContractClient for #client_name {
531            fn new(near: near_kit::Near, contract_id: near_kit::AccountId) -> Self {
532                Self::new(near, contract_id)
533            }
534        }
535
536        // Implement Contract marker trait
537        impl near_kit::Contract for #trait_name {
538            type Client = #client_name;
539        }
540    };
541
542    Ok(expanded)
543}
544
545/// Attribute macro for marking call methods.
546///
547/// This is used internally by `#[near_kit::contract]` traits.
548///
549/// # Examples
550///
551/// ```ignore
552/// #[call]
553/// fn increment(&mut self);
554///
555/// #[call(payable)]
556/// fn donate(&mut self);
557/// ```
558#[proc_macro_attribute]
559pub fn call(_attr: TokenStream, item: TokenStream) -> TokenStream {
560    // This is just a marker attribute - the actual work is done by #[contract]
561    item
562}
563
564/// Attribute macro for specifying JSON serialization format.
565///
566/// Use this to override the contract-level serialization format for a specific method.
567///
568/// # Examples
569///
570/// ```ignore
571/// #[near_kit::contract(borsh)]  // Contract default: Borsh
572/// pub trait MyContract {
573///     #[json]  // Override: this method uses JSON
574///     fn get_json_data(&self) -> JsonData;
575/// }
576/// ```
577#[proc_macro_attribute]
578pub fn json(_attr: TokenStream, item: TokenStream) -> TokenStream {
579    // This is just a marker attribute - the actual work is done by #[contract]
580    item
581}
582
583/// Attribute macro for specifying Borsh serialization format.
584///
585/// Use this to override the contract-level serialization format for a specific method.
586///
587/// # Examples
588///
589/// ```ignore
590/// #[near_kit::contract]  // Contract default: JSON
591/// pub trait MyContract {
592///     #[borsh]  // Override: this method uses Borsh
593///     fn get_binary_state(&self) -> BinaryState;
594///     
595///     #[call]
596///     #[borsh]  // Override: this call uses Borsh
597///     fn set_binary_state(&mut self, args: BinaryArgs);
598/// }
599/// ```
600#[proc_macro_attribute]
601pub fn borsh(_attr: TokenStream, item: TokenStream) -> TokenStream {
602    // This is just a marker attribute - the actual work is done by #[contract]
603    item
604}