Skip to main content

near_kit_macros/
lib.rs

1//! Proc macros for near-kit typed contract interfaces.
2//!
3//! This crate provides the `#[near_kit::contract]` attribute macro for defining
4//! type-safe contract interfaces.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use near_kit::*;
10//! use serde::Serialize;
11//!
12//! #[near_kit::contract]
13//! pub trait Counter {
14//!     fn get_count(&self) -> u64;
15//!     
16//!     #[call]
17//!     fn increment(&mut self);
18//!     
19//!     #[call]
20//!     fn add(&mut self, args: AddArgs);
21//! }
22//!
23//! #[derive(Serialize)]
24//! pub struct AddArgs {
25//!     pub value: u64,
26//! }
27//! ```
28//!
29//! # Per-Method Format Override
30//!
31//! You can override the serialization format for individual methods:
32//!
33//! ```ignore
34//! #[near_kit::contract]  // Default: JSON
35//! pub trait MixedContract {
36//!     fn get_json_data(&self) -> JsonData;  // Uses JSON (default)
37//!     
38//!     #[borsh]  // Override: use Borsh for this method
39//!     fn get_binary_state(&self) -> BinaryState;
40//!     
41//!     #[call]
42//!     #[borsh]  // Override: use Borsh for this call
43//!     fn set_binary_state(&mut self, args: BinaryArgs);
44//! }
45//! ```
46
47use proc_macro::TokenStream;
48use proc_macro2::TokenStream as TokenStream2;
49use quote::{format_ident, quote};
50use syn::{
51    FnArg, Ident, ItemTrait, Pat, ReturnType, TraitItem, TraitItemFn, Type,
52    parse::{Parse, ParseStream},
53    parse_macro_input,
54    spanned::Spanned,
55};
56
/// Wire format used to encode a method's arguments (and, for views, to
/// decode the returned value).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum SerializationFormat {
    /// JSON encoding; used whenever neither the contract nor the method
    /// selects another format.
    #[default]
    Json,
    /// Borsh binary encoding.
    Borsh,
}
64
65/// Arguments to the `#[contract]` attribute.
66#[derive(Debug, Default)]
67struct ContractArgs {
68    format: SerializationFormat,
69}
70
71impl Parse for ContractArgs {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        if input.is_empty() {
74            return Ok(Self::default());
75        }
76
77        let ident: Ident = input.parse()?;
78        let format = match ident.to_string().as_str() {
79            "json" => SerializationFormat::Json,
80            "borsh" => SerializationFormat::Borsh,
81            other => {
82                return Err(syn::Error::new(
83                    ident.span(),
84                    format!("unknown format '{}', expected 'json' or 'borsh'", other),
85                ));
86            }
87        };
88
89        Ok(Self { format })
90    }
91}
92
93/// Arguments to the `#[call]` attribute.
94#[derive(Debug, Default)]
95struct CallArgs {
96    payable: bool,
97}
98
99impl Parse for CallArgs {
100    fn parse(input: ParseStream) -> syn::Result<Self> {
101        if input.is_empty() {
102            return Ok(Self::default());
103        }
104
105        let ident: Ident = input.parse()?;
106        if ident != "payable" {
107            return Err(syn::Error::new(
108                ident.span(),
109                format!("unknown call option '{}', expected 'payable'", ident),
110            ));
111        }
112
113        Ok(Self { payable: true })
114    }
115}
116
117/// Information about a parsed method.
118#[derive(Debug)]
119struct MethodInfo {
120    name: Ident,
121    is_view: bool,
122    #[allow(dead_code)] // Reserved for future validation
123    is_call: bool,
124    #[allow(dead_code)] // Reserved for payable method handling
125    is_payable: bool,
126    /// Per-method format override (if specified via #[json] or #[borsh])
127    format_override: Option<SerializationFormat>,
128    arg_name: Option<Ident>,
129    arg_type: Option<Type>,
130    return_type: Option<Type>,
131}
132
133/// Parse a method from a trait item.
134fn parse_method(method: &TraitItemFn) -> syn::Result<MethodInfo> {
135    let name = method.sig.ident.clone();
136
137    // Check receiver type
138    let receiver = method.sig.receiver();
139    let (is_view, is_mut) = match receiver {
140        Some(recv) => {
141            if recv.reference.is_some() {
142                (recv.mutability.is_none(), recv.mutability.is_some())
143            } else {
144                return Err(syn::Error::new(
145                    recv.span(),
146                    "contract methods must take &self or &mut self",
147                ));
148            }
149        }
150        None => {
151            return Err(syn::Error::new(
152                method.sig.span(),
153                "contract methods must have a receiver (&self or &mut self)",
154            ));
155        }
156    };
157
158    // Check for #[call] attribute
159    let call_attr = method
160        .attrs
161        .iter()
162        .find(|attr| attr.path().is_ident("call"));
163
164    let (is_call, is_payable) = match call_attr {
165        Some(attr) => {
166            let args: CallArgs = if attr.meta.require_path_only().is_ok() {
167                CallArgs::default()
168            } else {
169                attr.parse_args()?
170            };
171            (true, args.payable)
172        }
173        None => (false, false),
174    };
175
176    // Check for #[json] or #[borsh] format override
177    let format_override = if method.attrs.iter().any(|attr| attr.path().is_ident("json")) {
178        Some(SerializationFormat::Json)
179    } else if method
180        .attrs
181        .iter()
182        .any(|attr| attr.path().is_ident("borsh"))
183    {
184        Some(SerializationFormat::Borsh)
185    } else {
186        None
187    };
188
189    // Validate: view methods should not have #[call]
190    if is_view && is_call {
191        return Err(syn::Error::new(
192            method.sig.span(),
193            "view methods (&self) should not have #[call] attribute",
194        ));
195    }
196
197    // Validate: call methods must have #[call]
198    if is_mut && !is_call {
199        return Err(syn::Error::new(
200            method.sig.span(),
201            "call methods (&mut self) must have #[call] attribute",
202        ));
203    }
204
205    // Parse arguments (excluding self)
206    let mut arg_name = None;
207    let mut arg_type = None;
208    let mut arg_count = 0;
209
210    for arg in &method.sig.inputs {
211        if let FnArg::Typed(pat_type) = arg {
212            arg_count += 1;
213            if arg_count > 1 {
214                return Err(syn::Error::new(
215                    pat_type.span(),
216                    "contract methods can have at most one argument (use a struct for multiple parameters)",
217                ));
218            }
219
220            // Extract argument name
221            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
222                arg_name = Some(pat_ident.ident.clone());
223            }
224            arg_type = Some((*pat_type.ty).clone());
225        }
226    }
227
228    // Parse return type
229    let return_type = match &method.sig.output {
230        ReturnType::Default => None,
231        ReturnType::Type(_, ty) => Some((**ty).clone()),
232    };
233
234    Ok(MethodInfo {
235        name,
236        is_view,
237        is_call,
238        is_payable,
239        format_override,
240        arg_name,
241        arg_type,
242        return_type,
243    })
244}
245
246/// Generate client method for a view function.
247fn generate_view_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
248    let method_name = &method.name;
249    let method_name_str = method_name.to_string();
250
251    // Use method override if present, otherwise contract default
252    let format = method.format_override.unwrap_or(contract_format);
253
254    let return_type = method
255        .return_type
256        .as_ref()
257        .map(|t| quote! { #t })
258        .unwrap_or_else(|| quote! { () });
259
260    // For Borsh format, chain .borsh() to switch to Borsh deserialization
261    let borsh_suffix = match format {
262        SerializationFormat::Json => quote! {},
263        SerializationFormat::Borsh => quote! { .borsh() },
264    };
265
266    // Return type differs based on format
267    let view_return_type = match format {
268        SerializationFormat::Json => quote! { near_kit::ViewCall<#return_type> },
269        SerializationFormat::Borsh => quote! { near_kit::ViewCallBorsh<#return_type> },
270    };
271
272    if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
273        // View with args
274        let args_method = match format {
275            SerializationFormat::Json => quote! { .args(#arg_name) },
276            SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
277        };
278
279        quote! {
280            pub fn #method_name(&self, #arg_name: #arg_type) -> #view_return_type {
281                self.near.view::<#return_type>(&self.contract_id, #method_name_str)
282                    #args_method
283                    #borsh_suffix
284            }
285        }
286    } else {
287        // View without args - for JSON, pass empty object; for Borsh, no args
288        match format {
289            SerializationFormat::Json => {
290                quote! {
291                    pub fn #method_name(&self) -> #view_return_type {
292                        self.near.view::<#return_type>(&self.contract_id, #method_name_str)
293                            .args(serde_json::json!({}))
294                    }
295                }
296            }
297            SerializationFormat::Borsh => {
298                quote! {
299                    pub fn #method_name(&self) -> #view_return_type {
300                        self.near.view::<#return_type>(&self.contract_id, #method_name_str)
301                            .borsh()
302                    }
303                }
304            }
305        }
306    }
307}
308
309/// Generate client method for a call function.
310fn generate_call_method(method: &MethodInfo, contract_format: SerializationFormat) -> TokenStream2 {
311    let method_name = &method.name;
312    let method_name_str = method_name.to_string();
313
314    // Use method override if present, otherwise contract default
315    let format = method.format_override.unwrap_or(contract_format);
316
317    if let (Some(arg_name), Some(arg_type)) = (&method.arg_name, &method.arg_type) {
318        // Call with args
319        let args_method = match format {
320            SerializationFormat::Json => quote! { .args(#arg_name) },
321            SerializationFormat::Borsh => quote! { .args_borsh(#arg_name) },
322        };
323
324        quote! {
325            pub fn #method_name(&self, #arg_name: #arg_type) -> near_kit::CallBuilder {
326                self.near.call(&self.contract_id, #method_name_str)
327                    #args_method
328            }
329        }
330    } else {
331        // Call without args - for JSON, pass empty object; for Borsh, no args
332        match format {
333            SerializationFormat::Json => {
334                quote! {
335                    pub fn #method_name(&self) -> near_kit::CallBuilder {
336                        self.near.call(&self.contract_id, #method_name_str)
337                            .args(serde_json::json!({}))
338                    }
339                }
340            }
341            SerializationFormat::Borsh => {
342                quote! {
343                    pub fn #method_name(&self) -> near_kit::CallBuilder {
344                        self.near.call(&self.contract_id, #method_name_str)
345                    }
346                }
347            }
348        }
349    }
350}
351
352/// Strip internal attributes from a method for the output trait.
353fn strip_internal_attrs(method: &TraitItemFn) -> TraitItemFn {
354    let mut method = method.clone();
355    method.attrs.retain(|attr| {
356        !attr.path().is_ident("call")
357            && !attr.path().is_ident("json")
358            && !attr.path().is_ident("borsh")
359    });
360    method
361}
362
363/// The main contract macro implementation.
364#[proc_macro_attribute]
365pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream {
366    let args = parse_macro_input!(attr as ContractArgs);
367    let input = parse_macro_input!(item as ItemTrait);
368
369    match contract_impl(args, input) {
370        Ok(tokens) => tokens.into(),
371        Err(err) => err.to_compile_error().into(),
372    }
373}
374
375fn contract_impl(args: ContractArgs, input: ItemTrait) -> syn::Result<TokenStream2> {
376    let trait_name = &input.ident;
377    let client_name = format_ident!("{}Client", trait_name);
378    let vis = &input.vis;
379
380    // Parse all methods
381    let mut methods = Vec::new();
382    for item in &input.items {
383        if let TraitItem::Fn(method) = item {
384            methods.push(parse_method(method)?);
385        }
386    }
387
388    // Generate client methods
389    let client_methods: Vec<TokenStream2> = methods
390        .iter()
391        .map(|m| {
392            if m.is_view {
393                generate_view_method(m, args.format)
394            } else {
395                generate_call_method(m, args.format)
396            }
397        })
398        .collect();
399
400    // Generate the cleaned trait (without internal attributes)
401    let cleaned_items: Vec<TraitItem> = input
402        .items
403        .iter()
404        .map(|item| {
405            if let TraitItem::Fn(method) = item {
406                TraitItem::Fn(strip_internal_attrs(method))
407            } else {
408                item.clone()
409            }
410        })
411        .collect();
412
413    let trait_attrs = &input.attrs;
414    let trait_supertraits = &input.supertraits;
415    let trait_generics = &input.generics;
416
417    // Build the output
418    let expanded = quote! {
419        // Original trait (with internal attrs stripped for cleaner output)
420        // The trait is used for defining the interface, but the generated client
421        // struct is what's actually used - so suppress dead_code warnings.
422        #[allow(dead_code)]
423        #(#trait_attrs)*
424        #vis trait #trait_name #trait_generics : #trait_supertraits {
425            #(#cleaned_items)*
426        }
427
428        // Generated client struct
429        #vis struct #client_name<'a> {
430            near: &'a near_kit::Near,
431            contract_id: near_kit::AccountId,
432        }
433
434        impl<'a> #client_name<'a> {
435            /// Create a new contract client.
436            pub fn new(near: &'a near_kit::Near, contract_id: near_kit::AccountId) -> Self {
437                Self { near, contract_id }
438            }
439
440            /// Get the contract account ID.
441            pub fn contract_id(&self) -> &near_kit::AccountId {
442                &self.contract_id
443            }
444
445            #(#client_methods)*
446        }
447
448        // Implement ContractClient trait for construction via near.contract::<T>()
449        impl<'a> near_kit::contract::ContractClient<'a> for #client_name<'a> {
450            fn new(near: &'a near_kit::Near, contract_id: near_kit::AccountId) -> Self {
451                Self { near, contract_id }
452            }
453        }
454
455        // Implement Contract marker trait
456        impl near_kit::Contract for dyn #trait_name {
457            type Client<'a> = #client_name<'a>;
458        }
459    };
460
461    Ok(expanded)
462}
463
464/// Attribute macro for marking call methods.
465///
466/// This is used internally by `#[near_kit::contract]` traits.
467///
468/// # Examples
469///
470/// ```ignore
471/// #[call]
472/// fn increment(&mut self);
473///
474/// #[call(payable)]
475/// fn donate(&mut self);
476/// ```
477#[proc_macro_attribute]
478pub fn call(_attr: TokenStream, item: TokenStream) -> TokenStream {
479    // This is just a marker attribute - the actual work is done by #[contract]
480    item
481}
482
483/// Attribute macro for specifying JSON serialization format.
484///
485/// Use this to override the contract-level serialization format for a specific method.
486///
487/// # Examples
488///
489/// ```ignore
490/// #[near_kit::contract(borsh)]  // Contract default: Borsh
491/// pub trait MyContract {
492///     #[json]  // Override: this method uses JSON
493///     fn get_json_data(&self) -> JsonData;
494/// }
495/// ```
496#[proc_macro_attribute]
497pub fn json(_attr: TokenStream, item: TokenStream) -> TokenStream {
498    // This is just a marker attribute - the actual work is done by #[contract]
499    item
500}
501
502/// Attribute macro for specifying Borsh serialization format.
503///
504/// Use this to override the contract-level serialization format for a specific method.
505///
506/// # Examples
507///
508/// ```ignore
509/// #[near_kit::contract]  // Contract default: JSON
510/// pub trait MyContract {
511///     #[borsh]  // Override: this method uses Borsh
512///     fn get_binary_state(&self) -> BinaryState;
513///     
514///     #[call]
515///     #[borsh]  // Override: this call uses Borsh
516///     fn set_binary_state(&mut self, args: BinaryArgs);
517/// }
518/// ```
519#[proc_macro_attribute]
520pub fn borsh(_attr: TokenStream, item: TokenStream) -> TokenStream {
521    // This is just a marker attribute - the actual work is done by #[contract]
522    item
523}