1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::ToTokens;
6use syn::{
7 parse::Parse, punctuated::Punctuated, token::Comma, visit_mut::{self, VisitMut}, AttrStyle, DeriveInput, FnArg, GenericParam, ImplItem, ImplItemFn, ItemImpl, ItemStruct, LitStr, Path, ReturnType, Type
8};
9
10#[macro_use]
11extern crate quote;
12#[macro_use]
13extern crate syn;
14
15extern crate proc_macro;
16
17trait ToMap<T, K, V> {
18 fn to_map(&self) -> syn::Result<std::collections::HashMap<K, V>>;
19}
20
21impl ToMap<Punctuated<syn::Meta, Comma>, syn::Ident, Option<proc_macro2::TokenStream>>
22 for Punctuated<syn::Meta, Comma>
23{
24 fn to_map(
26 &self,
27 ) -> syn::Result<std::collections::HashMap<syn::Ident, Option<proc_macro2::TokenStream>>> {
28 self.iter()
29 .map(|m| {
30 match m {
31 syn::Meta::NameValue(arg) => {
32 if let syn::Expr::Lit(lit) = &arg.value {
37 if let syn::Lit::Str(arg_str) = &lit.lit {
38 let value = if let Ok(call) = arg_str.parse::<syn::ExprCall>() {
39 quote! { #call }
40 }
41 else if let Ok(ident) = arg_str.parse::<syn::Ident>() {
42 quote! { #ident }
43 }
44 else if let Ok(lit) = arg_str.parse::<syn::Lit>() {
45 quote! { #lit }
46 }
47 else if let Ok(path) = arg_str.parse::<syn::Path>() {
48 quote! { #path }
49 }
50 else {
51 return Err(syn::Error::new_spanned(&arg.value, "argument value should be a: variable, literal, path, function call"))
52 };
53
54 Ok((arg.path.get_ident().unwrap().clone(), Some(value)))
55 }
56 else {
57 Err(syn::Error::new_spanned(&arg.value, "argument value should be a string literal"))
58 }
59 }
60 else {
61 Err(syn::Error::new_spanned(&arg.value, "argument value should be a string literal"))
62 }
63 }
64 syn::Meta::Path(arg) => {
65 Ok((arg.get_ident().unwrap().clone(), None))
66 }
67 _ => Err(syn::Error::new_spanned(m, "argument type should be Path or NameValue: `#[bt(default)]`, or `#[bt(default = \"String::new()\")]`"))
68 }
69 })
70 .collect()
71 }
72}
73
/// Helpers for appending one token stream to another.
trait ConcatTokenStream {
    /// Joins `self` and `value` as list items (the implementation in this file
    /// puts a comma between them when both are non-empty).
    fn concat_list(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream;
    /// Joins `self` and `value` back to back (the implementation in this file
    /// adds no separator).
    fn concat_blocks(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream;
}
78
79impl ConcatTokenStream for proc_macro2::TokenStream {
80 fn concat_list(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
81 if self.is_empty() {
82 if value.is_empty() {
83 proc_macro2::TokenStream::new()
85 } else {
86 quote! {
88 #value
89 }
90 }
91 } else if value.is_empty() {
92 quote! {
94 #self
95 }
96 } else {
97 quote! {
99 #self,
100 #value
101 }
102 }
103 }
104
105 fn concat_blocks(&self, value: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
106 if self.is_empty() {
107 if value.is_empty() {
108 proc_macro2::TokenStream::new()
110 } else {
111 quote! {
113 #value
114 }
115 }
116 } else if value.is_empty() {
117 quote! {
119 #self
120 }
121 } else {
122 quote! {
124 #self
125 #value
126 }
127 }
128 }
129}
130
/// A single `name = value` pair, where both sides are plain identifiers.
struct NodeAttribute {
    /// Identifier on the left of the `=`.
    name: syn::Ident,
    /// Identifier on the right of the `=`.
    value: syn::Ident,
}
135
136impl Parse for NodeAttribute {
137 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
138 let name = input.parse()?;
139 input.parse::<Token![=]>()?;
140 let value = input.parse()?;
141
142 Ok(Self {
143 name, value
144 })
145 }
146}
147
/// Arguments accepted by `#[bt_node(...)]` when it is applied to an `impl` block.
struct NodeImplConfig {
    /// First macro argument: the node kind (e.g. `StatefulActionNode`).
    node_type: syn::Ident,
    /// Name of the user's tick method (`on_running` for `StatefulActionNode`,
    /// `tick` otherwise, unless overridden).
    tick_fn: syn::Ident,
    /// Name of the user's start method; only set for `StatefulActionNode`.
    on_start_fn: Option<syn::Ident>,
    /// Name of the user's ports method, when given via `ports = ...`.
    ports: Option<syn::Ident>,
    /// Name of the user's halt method, when given via `halt = ...`.
    halt: Option<syn::Ident>,
}
155
156impl Parse for NodeImplConfig {
157 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
158 let node_type: Ident = input.parse()?;
159 let node_type_str = node_type.to_string();
160
161 if input.parse::<Token![,]>().is_ok() {
162 let mut attributes: HashMap<String, NodeAttribute> = input.parse_terminated(NodeAttribute::parse, Token![,])?
163 .into_iter()
164 .map(|val| (val.name.to_string(), val))
165 .collect();
166
167 let (tick_fn, on_start_fn) = if node_type_str == "StatefulActionNode" {
168 let tick_fn = attributes
169 .remove("on_running")
170 .map(|val| val.value)
171 .unwrap_or_else(|| syn::parse2(quote! { on_running }).unwrap());
172
173 let on_start_fn = attributes
174 .remove("on_start")
175 .map(|val| val.value)
176 .unwrap_or_else(|| syn::parse2(quote! { on_start }).unwrap());
177
178 (tick_fn, Some(on_start_fn))
179 } else {
180 let tick_fn = attributes
181 .remove("tick")
182 .map(|val| val.value)
183 .unwrap_or_else(|| syn::parse2(quote! { tick }).unwrap());
184
185 (tick_fn, None)
186 };
187
188 let ports = attributes.remove("ports").map(|val| val.value);
189 let halt = attributes.remove("halt").map(|val| val.value);
190
191 if let Some((_, invalid_field)) = attributes.into_iter().next() {
192 return Err(syn::Error::new(invalid_field.name.span(), "invalid field name"));
193 }
194
195 Ok(Self {
196 node_type,
197 tick_fn,
198 on_start_fn,
199 ports,
200 halt,
201 })
202 } else {
203 let (tick_fn, on_start_fn) = if node_type_str == "StatefulActionNode" {
204 (Ident::new("on_running", input.span()), Some(Ident::new("on_start", input.span())))
205 } else {
206 (Ident::new("tick", input.span()), None)
207 };
208
209 Ok(Self {
210 node_type,
211 tick_fn,
212 on_start_fn,
213 ports: None,
214 halt: None,
215 })
216 }
217 }
218}
219
220struct SelfVisitor;
221
222impl VisitMut for SelfVisitor {
223 fn visit_ident_mut(&mut self, i: &mut proc_macro2::Ident) {
224 if i == "self" {
225 let ctx = quote! { self_ };
226 let ctx = syn::parse2(ctx).unwrap();
227
228 *i = ctx;
229 }
230
231 visit_mut::visit_ident_mut(self, i)
232 }
233}
234
235fn alter_node_fn(fn_item: &mut ImplItemFn, struct_type: &Type, is_async: bool) -> syn::Result<()> {
236 if is_async {
238 fn_item.sig.asyncness = None;
239 }
240 let lifetime: GenericParam = syn::parse2(quote!{ 'a })?;
242 fn_item.sig.generics.params.push(lifetime);
243 for arg in fn_item.sig.inputs.iter_mut() {
245 if let FnArg::Receiver(_) = arg {
246 let new_arg = quote! { node_: &'a mut ::behaviortree_rs::nodes::TreeNodeData };
247 let new_arg = syn::parse2(new_arg)?;
248 *arg = new_arg;
249 }
250 }
251
252 let new_arg = syn::parse2(quote! { ctx: &'a mut ::std::boxed::Box<dyn ::core::any::Any + ::core::marker::Send + ::core::marker::Sync> })?;
253
254 fn_item.sig.inputs.push(new_arg);
255
256 let old_block = &mut fn_item.block;
257 SelfVisitor.visit_block_mut(old_block);
259
260 let new_block = if is_async {
261 let old_return = match &fn_item.sig.output {
263 ReturnType::Default => quote! { () },
264 ReturnType::Type(_, ret) => quote! { #ret }
265 };
266
267 let new_return = quote! {
269 -> ::futures::future::BoxFuture<'a, #old_return>
270 };
271
272 let new_return = syn::parse2(new_return)?;
273 fn_item.sig.output = new_return;
274
275 quote! {
277 {
278 ::std::boxed::Box::pin(async move {
279 let mut self_ = ctx.downcast_mut::<#struct_type>().unwrap();
280 #old_block
281 })
282 }
283 }
284 } else {
285 quote! {
287 {
288 let mut self_ = ctx.downcast_mut::<#struct_type>().unwrap();
289 #old_block
290 }
291 }
292 };
293
294 let new_block = syn::parse2(new_block)?;
295
296 fn_item.block = new_block;
297
298 Ok(())
299}
300
301fn bt_impl(
302 mut args: NodeImplConfig,
303 mut item: ItemImpl,
304) -> syn::Result<proc_macro2::TokenStream> {
305 let struct_type = &item.self_ty;
306
307 for sub_item in item.items.iter_mut() {
308 if let ImplItem::Fn(fn_item) = sub_item {
309 let mut should_rewrite_def = false;
310 let mut new_ident = None;
312 if fn_item.sig.ident == args.tick_fn {
314 new_ident = if args.node_type == "StatefulActionNode" {
315 Some(syn::parse2(quote! { _on_running })?)
316 } else {
317 Some(syn::parse2(quote! { _tick })?)
318 };
319
320 should_rewrite_def = true;
321 }
322 if let Some(on_start) = args.on_start_fn.as_ref() {
324 if &fn_item.sig.ident == on_start {
325 new_ident = Some(syn::parse2(quote! { _on_start })?);
326 should_rewrite_def = true;
327 }
328 }
329 if let Some(halt) = args.halt.as_ref() {
331 if &fn_item.sig.ident == halt {
332 new_ident = Some(syn::parse2(quote! { _halt })?);
333 should_rewrite_def = true;
334 }
335 } else if &fn_item.sig.ident == "halt" {
336 args.halt = Some(fn_item.sig.ident.clone());
337 new_ident = Some(syn::parse2(quote! { _halt })?);
338 should_rewrite_def = true;
339 }
340 if let Some(ports) = args.ports.as_ref() {
342 if &fn_item.sig.ident == ports {
343 new_ident = Some(syn::parse2(quote! { _ports })?);
344 }
345 } else if &fn_item.sig.ident == "ports" {
346 args.ports = Some(fn_item.sig.ident.clone());
347 new_ident = Some(syn::parse2(quote! { _ports })?);
348 }
349
350 if let Some(new_ident) = new_ident {
351 if should_rewrite_def {
352 alter_node_fn(fn_item, struct_type, true)?;
353 }
354
355 fn_item.sig.ident = new_ident;
356 }
357 }
358 }
359
360 let mut extra_impls = Vec::new();
361
362 if args.halt.is_none() {
363 extra_impls.push(syn::parse2(quote! {
364 fn _halt<'a>(node_: &'a mut ::behaviortree_rs::nodes::TreeNodeData, ctx: &'a mut ::std::boxed::Box<dyn ::core::any::Any + ::core::marker::Send + ::core::marker::Sync>) -> ::futures::future::BoxFuture<'a, ()> { ::std::boxed::Box::pin(async move {}) }
365 })?)
366 }
367
368 if args.ports.is_none() {
369 extra_impls.push(syn::parse2(quote! {
370 fn _ports() -> ::behaviortree_rs::basic_types::PortsList { ::behaviortree_rs::basic_types::PortsList::new() }
371 })?)
372 }
373
374 item.items.extend(extra_impls);
375
376 Ok(quote! { #item })
377}
378
379fn bt_struct(
380 type_ident: Path,
381 mut item: ItemStruct,
382) -> syn::Result<proc_macro2::TokenStream> {
383 let mut derives =
384 vec![quote! { ::std::fmt::Debug }];
385
386 let type_ident = type_ident.require_ident()?;
387 let type_ident_str = type_ident.to_string();
388
389 let item_ident = &item.ident;
390
391 let mut default_fields = proc_macro2::TokenStream::new();
392 let mut manual_fields = proc_macro2::TokenStream::new();
393 let mut manual_fields_with_types = proc_macro2::TokenStream::new();
394 let mut extra_impls = proc_macro2::TokenStream::new();
395
396 match &mut item.fields {
397 syn::Fields::Named(fields) => {
398 for f in fields.named.iter_mut() {
399 let name = f.ident.as_ref().unwrap();
400 let ty = &f.ty;
401
402 let mut used_default = false;
403 for a in f.attrs.iter() {
404 if a.path().is_ident("bt") {
405 let args: Punctuated<syn::Meta, Comma> =
406 a.parse_args_with(Punctuated::parse_terminated)?;
407 let args_map = args.to_map()?;
408
409 if let Some(value) = args_map.get(&syn::parse_str("default")?) {
411 used_default = true;
412 let default_value = if let Some(default_value) = value {
414 quote! { #default_value }
415 }
416 else {
418 quote! { <#ty>::default() }
419 };
420
421 default_fields =
422 default_fields.concat_list(quote! { #name: #default_value });
423 }
424 }
425 }
426
427 if !used_default {
429 manual_fields = manual_fields.concat_list(quote! { #name });
430 manual_fields_with_types =
431 manual_fields_with_types.concat_list(quote! { #name: #ty });
432 }
433
434 f.attrs = f
436 .attrs
437 .clone()
438 .into_iter()
439 .filter(|a| !a.path().is_ident("bt"))
440 .collect();
441 }
442 }
443 _ => {
444 return Err(syn::Error::new_spanned(
445 item,
446 "expected a struct with named fields",
447 ))
448 }
449 };
450
451 let vis = &item.vis;
452 let struct_fields = &item.fields;
453
454 let mut user_attrs = Vec::new();
455
456 for attr in item.attrs.iter() {
457 if attr.path().is_ident("derive") {
458 derives.push(attr.parse_args()?);
459 } else if let AttrStyle::Outer = attr.style {
460 user_attrs.push(attr);
461 }
462 }
463
464 let user_attrs = user_attrs
465 .into_iter()
466 .fold(proc_macro2::TokenStream::new(), |acc, a| {
467 if let AttrStyle::Outer = a.style {
469 if acc.is_empty() {
470 quote! {
471 #a
472 }
473 } else {
474 quote! {
475 #acc
476 #a
477 }
478 }
479 } else {
480 acc
481 }
482 });
483
484 let derives = derives
486 .into_iter()
487 .fold(proc_macro2::TokenStream::new(), |acc, d| {
488 if acc.is_empty() {
489 quote! {
490 #d
491 }
492 } else {
493 quote! {
494 #acc, #d
495 }
496 }
497 });
498
499 let extra_fields = proc_macro2::TokenStream::new()
500 .concat_list(default_fields)
501 .concat_list(manual_fields);
502
503 let node_category = match type_ident_str.as_str() {
514 "StatefulActionNode" | "SyncActionNode" => syn::parse2::<Path>(quote! { Action })?,
515 "ControlNode" => syn::parse2::<Path>(quote! { Control })?,
516 "DecoratorNode" => syn::parse2::<Path>(quote! { Decorator })?,
517 _ => return Err(syn::Error::new_spanned(type_ident, "Invalid node type"))
518 };
519
520 let node_type = match type_ident_str.as_str() {
521 "StatefulActionNode" => syn::parse2::<Path>(quote! { StatefulAction })?,
522 "SyncActionNode" => syn::parse2::<Path>(quote! { SyncAction })?,
523 "ControlNode" => syn::parse2::<Path>(quote! { Control })?,
524 "DecoratorNode" => syn::parse2::<Path>(quote! { Decorator })?,
525 _ => return Err(syn::Error::new_spanned(type_ident, "Invalid node type"))
526 };
527
528 let node_specific_tokens = node_fields(&type_ident_str);
529
530 let struct_name = LitStr::new(&item_ident.to_token_stream().to_string(), item_ident.span());
531
532 let output = quote! {
533 #user_attrs
534 #[derive(#derives)]
535 #vis struct #item_ident #struct_fields
536
537 impl #item_ident {
538 pub fn create_node(name: impl AsRef<str>, config: ::behaviortree_rs::nodes::NodeConfig, #manual_fields_with_types) -> ::behaviortree_rs::nodes::TreeNode {
539 let ctx = #item_ident {
540 #extra_fields
541 };
542
543 let node_data = ::behaviortree_rs::nodes::TreeNodeData {
544 name: name.as_ref().to_string(),
545 type_str: String::from(#struct_name),
546 node_type: ::behaviortree_rs::nodes::NodeType::#node_type,
547 node_category: ::behaviortree_rs::basic_types::NodeCategory::#node_category,
548 config,
549 status: ::behaviortree_rs::basic_types::NodeStatus::Idle,
550 children: ::std::vec::Vec::new(),
551 ports_fn: Self::_ports,
552 };
553
554 ::behaviortree_rs::nodes::TreeNode {
555 data: node_data,
556 context: ::std::boxed::Box::new(ctx),
557 halt_fn: Self::_halt,
558 #node_specific_tokens
559 }
560 }
561 }
562
563 #extra_impls
564 };
565
566 Ok(output)
567}
568
569fn node_fields(type_ident_str: &str) -> proc_macro2::TokenStream {
570 match type_ident_str {
571 "StatefulActionNode" => {
572 quote! {
573 tick_fn: Self::_on_running,
574 start_fn: Self::_on_start,
575 }
576 }
577 _ => {
579 quote! {
580 tick_fn: Self::_tick,
581 start_fn: Self::_tick,
582 }
583 }
584 }
585}
586
587#[proc_macro_attribute]
701pub fn bt_node(args: TokenStream, input: TokenStream) -> TokenStream {
702 if let Ok(struct_) = syn::parse::<ItemStruct>(input.clone()) {
703 let args = parse_macro_input!(args as Path);
704 bt_struct(args, struct_).unwrap_or_else(syn::Error::into_compile_error).into()
706 } else if let Ok(impl_) = syn::parse::<ItemImpl>(input) {
707 let args = parse_macro_input!(args as NodeImplConfig);
708 bt_impl(args, impl_).unwrap_or_else(syn::Error::into_compile_error).into()
709 } else {
710 syn::Error::new(Span::call_site(), "The `bt_node` macro must be used on either a `struct` or `impl` block.").into_compile_error().into()
711 }
712
713 }
720
721#[proc_macro_derive(FromString)]
722pub fn derive_from_string(input: TokenStream) -> TokenStream {
723 let input = parse_macro_input!(input as DeriveInput);
724
725 let ident = input.ident;
726
727 let expanded = quote! {
728 impl ::behaviortree_rs::basic_types::FromString for #ident {
729 type Err = <#ident as ::core::str::FromStr>::Err;
730
731 fn from_string(value: impl AsRef<str>) -> Result<#ident, Self::Err> {
732 value.as_ref().parse()
733 }
734 }
735 };
736
737 TokenStream::from(expanded)
738}
739
/// Parsed arguments of the `register_*_node!` macros:
/// `factory, name, NodeType [, param, ...]`.
struct NodeRegistration {
    /// Identifier of the factory value the node is registered on.
    factory: syn::Ident,
    /// Tokens of the expression used as the node's registration name.
    name: proc_macro2::TokenStream,
    /// Type whose `create_node` function builds the node.
    node_type: syn::Type,
    /// Extra expressions passed on (cloned) to `create_node`.
    params: Punctuated<syn::Expr, Comma>,
}
746
747impl Parse for NodeRegistration {
748 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
749 let factory = input.parse()?;
750 input.parse::<Token![,]>()?;
751
752 let node_name = input.parse::<syn::Expr>()?.to_token_stream();
753
754 input.parse::<Token![,]>()?;
755 let node_type = input.parse()?;
756 if !input.is_empty() {
758 input.parse::<Token![,]>()?;
759 }
760
761 let params = input.parse_terminated(syn::Expr::parse, Token![,])?;
762
763 Ok(Self {
764 factory,
765 name: node_name,
766 node_type,
767 params,
768 })
769 }
770}
771
772fn build_node(node: &NodeRegistration) -> proc_macro2::TokenStream {
773 let NodeRegistration {
774 factory: _,
775 name,
776 node_type,
777 params
778 } = node;
779
780 let cloned_names = (0..params.len())
781 .fold(quote!{}, |acc, i| {
782 let arg_name = Ident::new(&format!("arg{i}"), Span::call_site());
783 quote!{ #acc, #arg_name.clone() }
784 });
785
786 quote! {
787 {
788 let mut node = #node_type::create_node(#name, config #cloned_names);
789 let manifest = ::behaviortree_rs::basic_types::TreeNodeManifest {
790 node_type: node.node_category(),
791 registration_id: #name.into(),
792 ports: node.provided_ports(),
793 description: ::std::string::String::new(),
794 };
795 node.config_mut().set_manifest(::std::sync::Arc::new(manifest));
796 node
797 }
798 }
799}
800
801fn register_node(input: TokenStream, node_type_token: proc_macro2::TokenStream, node_type: NodeTypeInternal) -> TokenStream {
802 let node_registration = parse_macro_input!(input as NodeRegistration);
803
804 let factory = &node_registration.factory;
805 let name = &node_registration.name;
806 let params = &node_registration.params;
807
808 let param_clone_expr = params
810 .iter()
811 .enumerate()
812 .fold(quote!{}, |acc, (i, item)| {
813 let arg_name = Ident::new(&format!("arg{i}"), Span::call_site());
814 quote! {
815 #acc
816 let #arg_name = #item.clone();
817 }
818 });
819
820 let node = build_node(&node_registration);
821
822 let extra_steps = match node_type {
823 NodeTypeInternal::Control => quote! {
824 node.data.children = children;
825 },
826 NodeTypeInternal::Decorator => quote! {
827 node.data.children = children;
828 },
829 _ => quote!{ }
830 };
831
832 let expanded = quote! {
833 {
834 let blackboard = #factory.blackboard().clone();
835
836 #param_clone_expr
837
838 let node_fn = move |
839 config: ::behaviortree_rs::nodes::NodeConfig,
840 mut children: ::std::vec::Vec<::behaviortree_rs::nodes::TreeNode>
841 | -> ::behaviortree_rs::nodes::TreeNode
842 {
843 let mut node = #node;
844
845 #extra_steps
846
847 node
848 };
849
850 #factory.register_node(#name, node_fn, #node_type_token);
851 }
852 };
853
854 TokenStream::from(expanded)
855}
856
/// Node kind passed to `register_node`; it decides whether the generated
/// closure attaches `children` to the node.
enum NodeTypeInternal {
    /// No children are attached.
    Action,
    /// `children` are stored on the node.
    Control,
    /// `children` are stored on the node.
    Decorator,
}
862
863#[proc_macro]
880pub fn register_action_node(input: TokenStream) -> TokenStream {
881 register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Action }, NodeTypeInternal::Action)
882}
883
884#[proc_macro]
901pub fn register_control_node(input: TokenStream) -> TokenStream {
902 register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Control }, NodeTypeInternal::Control)
903}
904
905#[proc_macro]
922pub fn register_decorator_node(input: TokenStream) -> TokenStream {
923 register_node(input, quote! { ::behaviortree_rs::basic_types::NodeCategory::Decorator }, NodeTypeInternal::Decorator)
924}