flarrow_api_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    DeriveInput, ImplItem, ItemImpl, ReturnType, Token,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10};
11
12#[proc_macro_derive(Node)]
13pub fn derive_node(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = input.ident;
16
17    let expanded = quote! {
18        #[doc(hidden)]
19        #[unsafe(no_mangle)]
20        pub static FLARROW_NODE: DynamicallyLinkedNodeInstance = |inputs, outputs, configuration| {
21            <#name>::new(inputs, outputs, configuration)
22        };
23
24        static DEFAULT_TOKIO_RUNTIME: std::sync::LazyLock<tokio::runtime::Runtime> =
25            std::sync::LazyLock::new(|| tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime"));
26
27        fn default_runtime<T: Send + 'static>(
28            task: impl Future<Output = T> + Send + 'static,
29        ) -> tokio::task::JoinHandle<T> {
30            match tokio::runtime::Handle::try_current() {
31                Ok(handle) => handle.spawn(task),
32                Err(_) => DEFAULT_TOKIO_RUNTIME.spawn(task)
33            }
34        }
35    };
36
37    TokenStream::from(expanded)
38}
39
/// Arguments accepted by the `#[node(...)]` attribute.
struct MacroArgs {
    /// Source text of the callable used to spawn each rewritten method's
    /// future, as written in `runtime = "..."`.
    runtime: String,
}
43
44impl Parse for MacroArgs {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        let mut runtime = String::new();
47
48        let vars = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
49
50        for var in vars {
51            if let syn::Meta::NameValue(name_value) = var {
52                let name = name_value.path.get_ident().unwrap().to_string();
53
54                if name == "runtime" {
55                    if let syn::Expr::Lit(lit) = &name_value.value {
56                        if let syn::Lit::Str(lit_str) = &lit.lit {
57                            runtime = lit_str.value();
58                        }
59                    }
60                }
61            }
62        }
63
64        Ok(MacroArgs { runtime })
65    }
66}
67
68#[proc_macro_attribute]
69pub fn node(attr: TokenStream, item: TokenStream) -> TokenStream {
70    let mut impl_block = parse_macro_input!(item as ItemImpl);
71
72    let args = parse_macro_input!(attr as MacroArgs);
73    let runtime_tokens = args.runtime.parse::<proc_macro2::TokenStream>().unwrap();
74
75    for item in &mut impl_block.items {
76        if let ImplItem::Fn(method) = item {
77            let was_async = method.sig.asyncness.is_some();
78            method.sig.asyncness = None;
79
80            let old_block = method.block.clone();
81
82            if was_async {
83                let old_return_type = match &method.sig.output {
84                    ReturnType::Default => quote! { () },
85                    ReturnType::Type(_, ty) => quote! { #ty },
86                };
87
88                method.sig.output = syn::parse_quote! {
89                    -> tokio::task::JoinHandle<#old_return_type>
90                };
91
92                method.block = syn::parse_quote! {
93                    {
94                        #runtime_tokens(async move {
95                            #old_block
96                        })
97                    }
98                };
99            } else {
100                panic!("Function is not async");
101            }
102        }
103    }
104
105    quote! {
106        #impl_block
107    }
108    .into()
109}