behaviortree_rs_derive/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::ToTokens;
6use syn::{
7    parse::Parse, punctuated::Punctuated, token::Comma, visit_mut::{self, VisitMut}, AttrStyle, DeriveInput, FnArg, GenericParam, ImplItem, ImplItemFn, ItemImpl, ItemStruct, LitStr, Path, ReturnType, Type
8};
9
10#[macro_use]
11extern crate quote;
12#[macro_use]
13extern crate syn;
14
15extern crate proc_macro;
16
17trait ToMap<T, K, V> {
18    fn to_map(&self) -> syn::Result<std::collections::HashMap<K, V>>;
19}
20
21impl ToMap<Punctuated<syn::Meta, Comma>, syn::Ident, Option<proc_macro2::TokenStream>>
22    for Punctuated<syn::Meta, Comma>
23{
24    /// Convert a list of attribute arguments to a HashMap
25    fn to_map(
26        &self,
27    ) -> syn::Result<std::collections::HashMap<syn::Ident, Option<proc_macro2::TokenStream>>> {
28        self.iter()
29            .map(|m| {
30                match m {
31                    syn::Meta::NameValue(arg) => {
32                        // Convert Expr to one of the valid types:
33                        // Ident (variable name etc)
34                        // ExprCall (function call etc)
35                        // Lit (literal, for integer types etc)
36                        if let syn::Expr::Lit(lit) = &arg.value {
37                            if let syn::Lit::Str(arg_str) = &lit.lit {
38                                let value = if let Ok(call) = arg_str.parse::<syn::ExprCall>() {
39                                    quote! { #call }
40                                }
41                                else if let Ok(ident) = arg_str.parse::<syn::Ident>() {
42                                    quote! { #ident }
43                                }
44                                else if let Ok(lit) = arg_str.parse::<syn::Lit>() {
45                                    quote! { #lit }
46                                }
47                                else if let Ok(path) = arg_str.parse::<syn::Path>() {
48                                    quote! { #path }
49                                }
50                                else {
51                                    return Err(syn::Error::new_spanned(&arg.value, "argument value should be a:  variable, literal, path, function call"))
52                                };
53
54                                Ok((arg.path.get_ident().unwrap().clone(), Some(value)))
55                            }
56                            else {
57                                Err(syn::Error::new_spanned(&arg.value, "argument value should be a string literal"))
58                            }
59                        }
60                        else {
61                            Err(syn::Error::new_spanned(&arg.value, "argument value should be a string literal"))
62                        }
63                    }
64                    syn::Meta::Path(arg) => {
65                        Ok((arg.get_ident().unwrap().clone(), None))
66                    }
67                    _ => Err(syn::Error::new_spanned(m, "argument type should be Path or NameValue: `#[bt(default)]`, or `#[bt(default = \"String::new()\")]`"))
68                }
69            })
70            .collect()
71    }
72}
73
74trait ConcatTokenStream {
75    fn concat_list(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream;
76    fn concat_blocks(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream;
77}
78
79impl ConcatTokenStream for proc_macro2::TokenStream {
80    fn concat_list(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
81        if self.is_empty() {
82            if value.is_empty() {
83                // Both are empty
84                proc_macro2::TokenStream::new()
85            } else {
86                // self empty, value not empty
87                quote! {
88                    #value
89                }
90            }
91        } else if value.is_empty() {
92            // self not empty, value empty
93            quote! {
94                #self
95            }
96        } else {
97            // Both have value
98            quote! {
99                #self,
100                #value
101            }
102        }
103    }
104
105    fn concat_blocks(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
106        if self.is_empty() {
107            if value.is_empty() {
108                // Both are empty
109                proc_macro2::TokenStream::new()
110            } else {
111                // self empty, value not empty
112                quote! {
113                    #value
114                }
115            }
116        } else if value.is_empty() {
117            // self not empty, value empty
118            quote! {
119                #self
120            }
121        } else {
122            // Both have value
123            quote! {
124                #self
125                #value
126            }
127        }
128    }
129}
130
131struct NodeAttribute {
132    name: syn::Ident,
133    value: syn::Ident,
134}
135
136impl Parse for NodeAttribute {
137    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
138        let name = input.parse()?;
139        input.parse::<Token![=]>()?;
140        let value = input.parse()?;
141
142        Ok(Self {
143            name, value
144        })
145    }
146}
147
148struct NodeImplConfig {
149    node_type: syn::Ident,
150    tick_fn: syn::Ident,
151    on_start_fn: Option<syn::Ident>,
152    ports: Option<syn::Ident>,
153    halt: Option<syn::Ident>,
154}
155
156impl Parse for NodeImplConfig {
157    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
158        let node_type: Ident = input.parse()?;
159        let node_type_str = node_type.to_string();
160
161        if input.parse::<Token![,]>().is_ok() {
162            let mut attributes: HashMap<String, NodeAttribute> = input.parse_terminated(NodeAttribute::parse, Token![,])?
163                .into_iter()
164                .map(|val| (val.name.to_string(), val))
165                .collect();    
166    
167            let (tick_fn, on_start_fn) = if node_type_str == "StatefulActionNode" {
168                let tick_fn = attributes
169                    .remove("on_running")
170                    .map(|val| val.value)
171                    .unwrap_or_else(|| syn::parse2(quote! { on_running }).unwrap());
172                
173                let on_start_fn = attributes
174                    .remove("on_start")
175                    .map(|val| val.value)
176                    .unwrap_or_else(|| syn::parse2(quote! { on_start }).unwrap());
177    
178                (tick_fn, Some(on_start_fn))
179            } else {
180                let tick_fn = attributes
181                    .remove("tick")
182                    .map(|val| val.value)
183                    .unwrap_or_else(|| syn::parse2(quote! { tick }).unwrap());
184    
185                (tick_fn, None)
186            };
187    
188            let ports = attributes.remove("ports").map(|val| val.value);
189            let halt = attributes.remove("halt").map(|val| val.value);
190    
191            if let Some((_, invalid_field)) = attributes.into_iter().next() {
192                return Err(syn::Error::new(invalid_field.name.span(), "invalid field name"));
193            }
194
195            Ok(Self {
196                node_type,
197                tick_fn,
198                on_start_fn,
199                ports,
200                halt,
201            })
202        } else {
203            let (tick_fn, on_start_fn) = if node_type_str == "StatefulActionNode" {
204                (Ident::new("on_running", input.span()), Some(Ident::new("on_start", input.span())))
205            } else {
206                (Ident::new("tick", input.span()), None)
207            };
208
209            Ok(Self {
210                node_type,
211                tick_fn,
212                on_start_fn,
213                ports: None,
214                halt: None,
215            })
216        }
217    }
218}
219
220struct SelfVisitor;
221
222impl VisitMut for SelfVisitor {
223    fn visit_ident_mut(&mut self, i: &mut proc_macro2::Ident) {
224        if i == "self" {
225            let ctx = quote! { self_ };
226            let ctx = syn::parse2(ctx).unwrap();
227            
228            *i = ctx;
229        }
230
231        visit_mut::visit_ident_mut(self, i)
232    }
233}
234
235fn alter_node_fn(fn_item: &mut ImplItemFn, struct_type: &Type, is_async: bool) -> syn::Result<()> {
236    // Remove async
237    if is_async {
238        fn_item.sig.asyncness = None;
239    }
240    // Add lifetime to signature
241    let lifetime: GenericParam = syn::parse2(quote!{ 'a })?;
242    fn_item.sig.generics.params.push(lifetime);
243    // Rename parameters
244    for arg in fn_item.sig.inputs.iter_mut() {
245        if let FnArg::Receiver(_) = arg {
246            let new_arg = quote! { node_: &'a mut ::behaviortree_rs::nodes::TreeNodeData };
247            let new_arg = syn::parse2(new_arg)?;
248            *arg = new_arg;
249        }
250    }
251
252    let new_arg = syn::parse2(quote! { ctx: &'a mut ::std::boxed::Box<dyn ::core::any::Any + ::core::marker::Send + ::core::marker::Sync> })?;
253
254    fn_item.sig.inputs.push(new_arg);
255
256    let old_block = &mut fn_item.block;
257    // Rename occurrences of self
258    SelfVisitor.visit_block_mut(old_block);
259
260    let new_block = if is_async {
261        // Get old return without the -> token
262        let old_return = match &fn_item.sig.output {
263            ReturnType::Default => quote! { () },
264            ReturnType::Type(_, ret) => quote! { #ret }
265        };
266
267        // Wrap return type in BoxFuture
268        let new_return = quote! {
269            -> ::futures::future::BoxFuture<'a, #old_return>
270        };
271
272        let new_return = syn::parse2(new_return)?;
273        fn_item.sig.output = new_return;
274    
275        // Wrap function block in Box::pin and create ctx
276        quote! {
277            {
278                ::std::boxed::Box::pin(async move {
279                    let mut self_ = ctx.downcast_mut::<#struct_type>().unwrap();
280                    #old_block
281                })
282            }
283        }
284    } else {
285        // Wrap function block in Box::pin and create ctx
286        quote! {
287            {
288                let mut self_ = ctx.downcast_mut::<#struct_type>().unwrap();
289                #old_block
290            }
291        }
292    };
293
294    let new_block = syn::parse2(new_block)?;
295
296    fn_item.block = new_block;
297
298    Ok(())
299}
300
301fn bt_impl(
302    mut args: NodeImplConfig,
303    mut item: ItemImpl,
304) -> syn::Result<proc_macro2::TokenStream> {
305    let struct_type = &item.self_ty;
306    
307    for sub_item in item.items.iter_mut() {
308        if let ImplItem::Fn(fn_item) = sub_item {
309            let mut should_rewrite_def = false;
310            // Rename methods
311            let mut new_ident = None;
312            // Check if it's a tick
313            if fn_item.sig.ident == args.tick_fn {
314                new_ident = if args.node_type == "StatefulActionNode" {
315                    Some(syn::parse2(quote! { _on_running })?)
316                } else {
317                    Some(syn::parse2(quote! { _tick })?)
318                };
319
320                should_rewrite_def = true;
321            }
322            // Check if it's an on_start
323            if let Some(on_start) = args.on_start_fn.as_ref() {
324                if &fn_item.sig.ident == on_start {
325                    new_ident = Some(syn::parse2(quote! { _on_start })?);
326                    should_rewrite_def = true;
327                }
328            }
329            // Check if it's a halt
330            if let Some(halt) = args.halt.as_ref() {
331                if &fn_item.sig.ident == halt {
332                    new_ident = Some(syn::parse2(quote! { _halt })?);
333                    should_rewrite_def = true;
334                }
335            } else if &fn_item.sig.ident == "halt" {
336                args.halt = Some(fn_item.sig.ident.clone());
337                new_ident = Some(syn::parse2(quote! { _halt })?);
338                should_rewrite_def = true;
339            }
340            // Check if it's a ports
341            if let Some(ports) = args.ports.as_ref() {
342                if &fn_item.sig.ident == ports {
343                    new_ident = Some(syn::parse2(quote! { _ports })?);
344                }
345            } else if &fn_item.sig.ident == "ports" {
346                args.ports = Some(fn_item.sig.ident.clone());
347                new_ident = Some(syn::parse2(quote! { _ports })?);
348            }
349
350            if let Some(new_ident) = new_ident {
351                if should_rewrite_def {
352                    alter_node_fn(fn_item, struct_type, true)?;
353                }
354                
355                fn_item.sig.ident = new_ident;
356            }
357        }
358    }
359
360    let mut extra_impls = Vec::new();
361
362    if args.halt.is_none() {
363        extra_impls.push(syn::parse2(quote! {
364            fn _halt<'a>(node_: &'a mut ::behaviortree_rs::nodes::TreeNodeData, ctx: &'a mut ::std::boxed::Box<dyn ::core::any::Any + ::core::marker::Send + ::core::marker::Sync>) -> ::futures::future::BoxFuture<'a, ()> { ::std::boxed::Box::pin(async move {}) }
365        })?)
366    }
367
368    if args.ports.is_none() {
369        extra_impls.push(syn::parse2(quote! {
370            fn _ports() -> ::behaviortree_rs::basic_types::PortsList { ::behaviortree_rs::basic_types::PortsList::new() }
371        })?)
372    }
373
374    item.items.extend(extra_impls);
375
376    Ok(quote! { #item })
377}
378
379fn bt_struct(
380    type_ident: Path,
381    mut item: ItemStruct,
382) -> syn::Result<proc_macro2::TokenStream> {
383    let mut derives =
384        vec![quote! { ::std::fmt::Debug }];
385
386    let type_ident = type_ident.require_ident()?;
387    let type_ident_str = type_ident.to_string();
388
389    let item_ident = &item.ident;
390
391    let mut default_fields = proc_macro2::TokenStream::new();
392    let mut manual_fields = proc_macro2::TokenStream::new();
393    let mut manual_fields_with_types = proc_macro2::TokenStream::new();
394    let mut extra_impls = proc_macro2::TokenStream::new();
395
396    match &mut item.fields {
397        syn::Fields::Named(fields) => {
398            for f in fields.named.iter_mut() {
399                let name = f.ident.as_ref().unwrap();
400                let ty = &f.ty;
401
402                let mut used_default = false;
403                for a in f.attrs.iter() {
404                    if a.path().is_ident("bt") {
405                        let args: Punctuated<syn::Meta, Comma> =
406                            a.parse_args_with(Punctuated::parse_terminated)?;
407                        let args_map = args.to_map()?;
408
409                        // If the default argument was included
410                        if let Some(value) = args_map.get(&syn::parse_str("default")?) {
411                            used_default = true;
412                            // Use the provided default, if provided by user
413                            let default_value = if let Some(default_value) = value {
414                                quote! { #default_value }
415                            }
416                            // Otherwise, use Default
417                            else {
418                                quote! { <#ty>::default() }
419                            };
420
421                            default_fields =
422                                default_fields.concat_list(quote! { #name: #default_value });
423                        }
424                    }
425                }
426
427                // Mark field as manually specified if
428                if !used_default {
429                    manual_fields = manual_fields.concat_list(quote! { #name });
430                    manual_fields_with_types =
431                        manual_fields_with_types.concat_list(quote! { #name: #ty });
432                }
433
434                // Remove the bt attribute, keep all others
435                f.attrs = f
436                    .attrs
437                    .clone()
438                    .into_iter()
439                    .filter(|a| !a.path().is_ident("bt"))
440                    .collect();
441            }
442        }
443        _ => {
444            return Err(syn::Error::new_spanned(
445                item,
446                "expected a struct with named fields",
447            ))
448        }
449    };
450
451    let vis = &item.vis;
452    let struct_fields = &item.fields;
453
454    let mut user_attrs = Vec::new();
455
456    for attr in item.attrs.iter() {
457        if attr.path().is_ident("derive") {
458            derives.push(attr.parse_args()?);
459        } else if let AttrStyle::Outer = attr.style {
460            user_attrs.push(attr);
461        }
462    }
463
464    let user_attrs = user_attrs
465        .into_iter()
466        .fold(proc_macro2::TokenStream::new(), |acc, a| {
467            // Only want to transfer outer attributes
468            if let AttrStyle::Outer = a.style {
469                if acc.is_empty() {
470                    quote! {
471                        #a
472                    }
473                } else {
474                    quote! {
475                        #acc
476                        #a
477                    }
478                }
479            } else {
480                acc
481            }
482        });
483
484    // Convert Vec of derive Paths into one TokenStream
485    let derives = derives
486        .into_iter()
487        .fold(proc_macro2::TokenStream::new(), |acc, d| {
488            if acc.is_empty() {
489                quote! {
490                    #d
491                }
492            } else {
493                quote! {
494                    #acc, #d
495                }
496            }
497        });
498
499    let extra_fields = proc_macro2::TokenStream::new()
500        .concat_list(default_fields)
501        .concat_list(manual_fields);
502
503    // let node_type = match type_ident_str.as_str() {
504    //     ""
505    // }
506
507    // let node_type = if type_ident == "StatefulActionNode" || type_ident == "SyncActionNode" {
508    //     syn::parse2::<Ident>(quote! { ActionNode })?
509    // } else {
510    //     type_ident.clone()
511    // };
512
513    let node_category = match type_ident_str.as_str() {
514        "StatefulActionNode" | "SyncActionNode" => syn::parse2::<Path>(quote! { Action })?,
515        "ControlNode" => syn::parse2::<Path>(quote! { Control })?,
516        "DecoratorNode" => syn::parse2::<Path>(quote! { Decorator })?,
517        _ => return Err(syn::Error::new_spanned(type_ident, "Invalid node type"))
518    };
519
520    let node_type = match type_ident_str.as_str() {
521        "StatefulActionNode" => syn::parse2::<Path>(quote! { StatefulAction })?,
522        "SyncActionNode" => syn::parse2::<Path>(quote! { SyncAction })?,
523        "ControlNode" => syn::parse2::<Path>(quote! { Control })?,
524        "DecoratorNode" => syn::parse2::<Path>(quote! { Decorator })?,
525        _ => return Err(syn::Error::new_spanned(type_ident, "Invalid node type"))
526    };
527
528    let node_specific_tokens = node_fields(&type_ident_str);
529
530    let struct_name = LitStr::new(&item_ident.to_token_stream().to_string(), item_ident.span());
531
532    let output = quote! {
533        #user_attrs
534        #[derive(#derives)]
535        #vis struct #item_ident #struct_fields
536
537        impl #item_ident {
538            pub fn create_node(name: impl AsRef<str>, config: ::behaviortree_rs::nodes::NodeConfig, #manual_fields_with_types) -> ::behaviortree_rs::nodes::TreeNode {
539                let ctx = #item_ident {
540                    #extra_fields
541                };
542
543                let node_data = ::behaviortree_rs::nodes::TreeNodeData {
544                    name: name.as_ref().to_string(),
545                    type_str: String::from(#struct_name),
546                    node_type: ::behaviortree_rs::nodes::NodeType::#node_type,
547                    node_category: ::behaviortree_rs::basic_types::NodeCategory::#node_category,
548                    config,
549                    status: ::behaviortree_rs::basic_types::NodeStatus::Idle,
550                    children: ::std::vec::Vec::new(),
551                    ports_fn: Self::_ports,
552                };
553                
554                ::behaviortree_rs::nodes::TreeNode {
555                    data: node_data,
556                    context: ::std::boxed::Box::new(ctx),
557                    halt_fn: Self::_halt,
558                    #node_specific_tokens
559                }
560            }
561        }
562
563        #extra_impls
564    };
565
566    Ok(output)
567}
568
569fn node_fields(type_ident_str: &str) -> proc_macro2::TokenStream {
570    match type_ident_str {
571        "StatefulActionNode" => {
572            quote! {
573                tick_fn: Self::_on_running,
574                start_fn: Self::_on_start,
575            }
576        }
577        // Don't need to check others, it has already been checked before now
578        _ => {
579            quote! {
580                tick_fn: Self::_tick,
581                start_fn: Self::_tick,
582            }
583        }
584    }
585}
586
587/// Macro used to automatically generate the default boilerplate needed for all `TreeNode`s.
588///
589/// # Basic Usage
590///
591/// To use the macro, you need to add `#[bt_node(...)]` above your struct. As an argument
592/// to the attribute, specify the NodeType that you would like to implement.
593///
594/// Supported options:
595/// - `SyncActionNode`
596/// - `StatefulActionNode`
597/// - `ControlNode`
598/// - `DecoratorNode`
599///
600/// By default, the tick method implementation is `async`. To specify this explicitly (or
601/// make it synchronous), add `Async` or `Sync` after the node type.
602///
603/// ===
604///
605/// ```rust
606/// use behaviortree_rs::{bt_node, basic_types::NodeStatus, nodes::{AsyncTick, NodeResult, AsyncHalt, NodePorts}, sync::BoxFuture};
607///
608/// // Here we are specifying a `SyncActionNode` as the node type.
609/// #[bt_node(SyncActionNode)]
610/// // Defaults to #[bt_node(SyncActionNode, Async)]
611/// struct MyActionNode {} // No additional fields
612///
613/// // Now I need to `impl TreeNode`
614/// impl AsyncTick for MyActionNode {
615///     fn tick(&mut self) -> BoxFuture<NodeResult> {
616///         Box::pin(async move {
617///             // Do something here
618///             // ...
619///
620///             Ok(NodeStatus::Success)
621///         })
622///     }
623/// }
624///
625/// impl NodePorts for MyActionNode {}
626///
627/// // Also need to `impl NodeHalt`
628/// // However, we'll just use the default implementation
629/// impl AsyncHalt for MyActionNode {}
630/// ```
631///
632/// ===
633///
634/// The above code will add fields to `MyActionNode` and create a `new()` associated method:
635///
636/// ```ignore
637/// impl DummyActionNode {
638///     pub fn new(name: impl AsRef<str>, config: NodeConfig) -> DummyActionNode {
639///         Self {
640///             name: name.as_ref().to_string(),
641///             config,
642///             status: NodeStatus::Idle
643///         }
644///     }
645/// }
646/// ```
647///
648/// # Adding Fields
649///
650/// When you add your own fields into the struct, be default they will be added
651/// to the `new()` definition as arguments. To specify default values, use
652/// the `#[bt(default)]` attribute above the fields.
653///
654/// `#[bt(default)]` will use the type's implementation of the `Default` trait. If
655/// the trait isn't implemented on the type, or if you want to manually specify
656/// a value, use `#[bt(default = "...")]`, where `...` is the value.
657///
658/// Valid argument types within the `"..."` are:
659///
660/// ```ignore
661/// // Function calls
662/// #[bt(default = "String::from(10)")]
663///
664/// // Variables
665/// #[bt(default = "foo")]
666///
667/// // Paths (like enums)
668/// #[bt(default = "NodeStatus::Idle")]
669///
670/// // Literals
671/// #[bt(default = "10")]
672/// ```
673///
674/// ## Example
675///
676/// ```rust
677/// use behaviortree_rs::{bt_node, basic_types::NodeStatus, nodes::{AsyncTick, NodePorts, NodeResult, AsyncHalt}, sync::BoxFuture};
678///
679/// #[bt_node(SyncActionNode)]
680/// struct MyActionNode {
681///     #[bt(default = "NodeStatus::Success")]
682///     foo: NodeStatus,
683///     #[bt(default)] // defaults to empty String
684///     bar: String
685/// }
686///
687/// // Now I need to `impl TreeNode`
688/// impl AsyncTick for MyActionNode {
689///     fn tick(&mut self) -> BoxFuture<NodeResult> {
690///         Box::pin(async move {
691///             Ok(NodeStatus::Success)
692///         })
693///     }
694/// }
695///
696/// impl NodePorts for MyActionNode {}
697///
698/// impl AsyncHalt for MyActionNode {}
699/// ```
700#[proc_macro_attribute]
701pub fn bt_node(args: TokenStream, input: TokenStream) -> TokenStream {
702    if let Ok(struct_) = syn::parse::<ItemStruct>(input.clone()) {
703        let args = parse_macro_input!(args as Path);
704        // let args = parse_macro_input!(args as NodeStructConfig);
705        bt_struct(args, struct_).unwrap_or_else(syn::Error::into_compile_error).into()
706    } else if let Ok(impl_) = syn::parse::<ItemImpl>(input) {
707        let args = parse_macro_input!(args as NodeImplConfig);
708        bt_impl(args, impl_).unwrap_or_else(syn::Error::into_compile_error).into()
709    } else {
710        syn::Error::new(Span::call_site(), "The `bt_node` macro must be used on either a `struct` or `impl` block.").into_compile_error().into()
711    }
712
713    // let args_parsed = parse_macro_input!(args as NodeStructConfig);
714    // let item = parse_macro_input!(input as ItemStruct);
715
716    // bt_struct(args_parsed, item)
717    //     .unwrap_or_else(syn::Error::into_compile_error)
718    //     .into()
719}
720
721#[proc_macro_derive(FromString)]
722pub fn derive_from_string(input: TokenStream) -> TokenStream {
723    let input = parse_macro_input!(input as DeriveInput);
724
725    let ident = input.ident;
726
727    let expanded = quote! {
728        impl ::behaviortree_rs::basic_types::FromString for #ident {
729            type Err = <#ident as ::core::str::FromStr>::Err;
730
731            fn from_string(value: impl AsRef<str>) -> Result<#ident, Self::Err> {
732                value.as_ref().parse()
733            }
734        }
735    };
736
737    TokenStream::from(expanded)
738}
739
740struct NodeRegistration {
741    factory: syn::Ident,
742    name: proc_macro2::TokenStream,
743    node_type: syn::Type,
744    params: Punctuated<syn::Expr, Comma>,
745}
746
747impl Parse for NodeRegistration {
748    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
749        let factory = input.parse()?;
750        input.parse::<Token![,]>()?;
751        
752        let node_name = input.parse::<syn::Expr>()?.to_token_stream();
753        
754        input.parse::<Token![,]>()?;
755        let node_type = input.parse()?;
756        // If there are extra parameters, try to parse a comma. Otherwise skip
757        if !input.is_empty() {
758            input.parse::<Token![,]>()?;
759        }
760    
761        let params = input.parse_terminated(syn::Expr::parse, Token![,])?;
762
763        Ok(Self {
764            factory,
765            name: node_name,
766            node_type,
767            params,
768        })
769    }
770}
771
772fn build_node(node: &NodeRegistration) -> proc_macro2::TokenStream {
773    let NodeRegistration {
774        factory: _,
775        name,
776        node_type,
777        params
778    } = node;
779
780    let cloned_names = (0..params.len())
781        .fold(quote!{}, |acc, i| {
782            let arg_name = Ident::new(&format!("arg{i}"), Span::call_site());
783            quote!{ #acc, #arg_name.clone() }
784        });
785
786    quote! {
787        {
788            let mut node = #node_type::create_node(#name, config #cloned_names);
789            let manifest = ::behaviortree_rs::basic_types::TreeNodeManifest {
790                node_type: node.node_category(),
791                registration_id: #name.into(),
792                ports: node.provided_ports(),
793                description: ::std::string::String::new(),
794            };
795            node.config_mut().set_manifest(::std::sync::Arc::new(manifest));
796            node
797        }
798    }
799}
800
801fn register_node(input: TokenStream, node_type_token: proc_macro2::TokenStream, node_type: NodeTypeInternal) -> TokenStream {
802    let node_registration = parse_macro_input!(input as NodeRegistration);
803
804    let factory = &node_registration.factory;
805    let name = &node_registration.name;
806    let params = &node_registration.params;
807
808    // Create expression that clones all parameters
809    let param_clone_expr = params
810        .iter()
811        .enumerate()
812        .fold(quote!{}, |acc, (i, item)| {
813            let arg_name = Ident::new(&format!("arg{i}"), Span::call_site());
814            quote! {
815                #acc
816                let #arg_name = #item.clone();
817            }
818        });
819
820    let node = build_node(&node_registration);
821
822    let extra_steps = match node_type {
823        NodeTypeInternal::Control => quote! {
824            node.data.children = children;
825        },
826        NodeTypeInternal::Decorator => quote! { 
827            node.data.children = children;
828        },
829        _ => quote!{ }
830    };
831
832    let expanded = quote! {
833        {
834            let blackboard = #factory.blackboard().clone();
835
836            #param_clone_expr
837
838            let node_fn = move |
839                config: ::behaviortree_rs::nodes::NodeConfig,
840                mut children: ::std::vec::Vec<::behaviortree_rs::nodes::TreeNode>
841            | -> ::behaviortree_rs::nodes::TreeNode
842            {
843                let mut node = #node;
844                
845                #extra_steps
846
847                node
848            };
849
850            #factory.register_node(#name, node_fn, #node_type_token);
851        }
852    };
853
854    TokenStream::from(expanded)
855}
856
/// Node kind passed to `register_node`; decides whether created nodes receive the
/// children handed to the registration closure.
enum NodeTypeInternal {
    /// Leaf node: children passed at creation are not attached.
    Action,
    /// Children passed at creation are attached to the node.
    Control,
    /// Children passed at creation are attached to the node.
    Decorator,
}
862
863/// Registers an Action type node with the factory.
864/// 
865/// **NOTE:** During tree creation, a new node is created using the parameters
866/// given after the node type field. You specified these fields in your node struct
867/// definition. Each time a node is created, the parameters are cloned using `Clone::clone`.
868/// Thus, your parameters must implement `Clone`.
869/// 
870/// # Usage
871/// 
872/// ```ignore
873/// let mut factory = Factory::new();
874/// let arg1 = String::from("hello world");
875/// let arg2 = 10u32;
876/// 
877/// register_action_node!(factory, "TestNode", TestNode, arg1, arg2);
878/// ```
879#[proc_macro]
880pub fn register_action_node(input: TokenStream) -> TokenStream {
881    register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Action }, NodeTypeInternal::Action)
882}
883
884/// Registers an Control type node with the factory.
885/// 
886/// **NOTE:** During tree creation, a new node is created using the parameters
887/// given after the node type field. You specified these fields in your node struct
888/// definition. Each time a node is created, the parameters are cloned using `Clone::clone`.
889/// Thus, your parameters must implement `Clone`.
890/// 
891/// # Usage
892/// 
893/// ```ignore
894/// let mut factory = Factory::new();
895/// let arg1 = String::from("hello world");
896/// let arg2 = 10u32;
897/// 
898/// register_control_node!(factory, "TestNode", TestNode, arg1, arg2);
899/// ```
900#[proc_macro]
901pub fn register_control_node(input: TokenStream) -> TokenStream {
902    register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Control }, NodeTypeInternal::Control)
903}
904
905/// Registers an Decorator type node with the factory.
906/// 
907/// **NOTE:** During tree creation, a new node is created using the parameters
908/// given after the node type field. You specified these fields in your node struct
909/// definition. Each time a node is created, the parameters are cloned using `Clone::clone`.
910/// Thus, your parameters must implement `Clone`.
911/// 
912/// # Usage
913/// 
914/// ```ignore
915/// let mut factory = Factory::new();
916/// let arg1 = String::from("hello world");
917/// let arg2 = 10u32;
918/// 
919/// register_decorator_node!(factory, "TestNode", TestNode, arg1, arg2);
920/// ```
921#[proc_macro]
922pub fn register_decorator_node(input: TokenStream) -> TokenStream {
923    register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Decorator }, NodeTypeInternal::Decorator)
924}