iridis_node_derive/
lib.rs1extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    DeriveInput, ImplItem, ItemImpl, ReturnType, Token,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    punctuated::Punctuated,
12};
13
14#[proc_macro_derive(Node)]
17pub fn derive_node(input: TokenStream) -> TokenStream {
18    let input = parse_macro_input!(input as DeriveInput);
19    let name = input.ident;
20
21    let expanded = quote! {
22        #[cfg(feature = "cdylib")]
23        #[doc(hidden)]
24        #[unsafe(no_mangle)]
25        pub static IRIDIS_NODE: iridis_node::prelude::DynamicallyLinkedNodeInstance = |inputs, outputs, queries, queryables, configuration| {
26            <#name>::new(inputs, outputs, queries, queryables, configuration)
27        };
28
29        static DEFAULT_TOKIO_RUNTIME: std::sync::LazyLock<iridis_node::prelude::thirdparty::tokio::runtime::Runtime> =
30            std::sync::LazyLock::new(|| iridis_node::prelude::thirdparty::tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime"));
31
32        fn default_runtime<T: Send + 'static>(
33            task: impl Future<Output = T> + Send + 'static,
34        ) -> iridis_node::prelude::thirdparty::tokio::task::JoinHandle<T> {
35            match iridis_node::prelude::thirdparty::tokio::runtime::Handle::try_current() {
36                Ok(handle) => handle.spawn(task),
37                Err(_) => DEFAULT_TOKIO_RUNTIME.spawn(task)
38            }
39        }
40    };
41
42    TokenStream::from(expanded)
43}
44
45struct MacroArgs {
46    runtime: String,
47}
48
49impl Parse for MacroArgs {
50    fn parse(input: ParseStream) -> syn::Result<Self> {
51        let mut runtime = String::new();
52
53        let vars = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
54
55        for var in vars {
56            if let syn::Meta::NameValue(name_value) = var {
57                let name = name_value.path.get_ident().unwrap().to_string();
58
59                if name == "runtime" {
60                    if let syn::Expr::Lit(lit) = &name_value.value {
61                        if let syn::Lit::Str(lit_str) = &lit.lit {
62                            runtime = lit_str.value();
63                        }
64                    }
65                }
66            }
67        }
68
69        Ok(MacroArgs { runtime })
70    }
71}
72
73#[proc_macro_attribute]
94pub fn node(attr: TokenStream, item: TokenStream) -> TokenStream {
95    let mut impl_block = parse_macro_input!(item as ItemImpl);
96
97    let args = parse_macro_input!(attr as MacroArgs);
98    let runtime_tokens = args.runtime.parse::<proc_macro2::TokenStream>().unwrap();
99
100    for item in &mut impl_block.items {
101        if let ImplItem::Fn(method) = item {
102            let was_async = method.sig.asyncness.is_some();
103            method.sig.asyncness = None;
104
105            let old_block = method.block.clone();
106
107            if was_async {
108                let old_return_type = match &method.sig.output {
109                    ReturnType::Default => quote! { () },
110                    ReturnType::Type(_, ty) => {
111                        if method.sig.ident == "new" {
112                            quote! { iridis_node::prelude::thirdparty::eyre::Result<Box<dyn iridis_node::prelude::Node>> }
113                        } else {
114                            quote! { #ty }
115                        }
116                    }
117                };
118
119                method.sig.output = syn::parse_quote! {
120                    -> tokio::task::JoinHandle<#old_return_type>
121                };
122
123                if method.sig.ident == "new" {
124                    method.block = syn::parse_quote! {
125                        {
126                            #runtime_tokens(async move {
127                                #old_block.map(|node| Box::new(node) as Box<dyn iridis_node::prelude::Node>)
128                            })
129                        }
130                    };
131                } else {
132                    method.block = syn::parse_quote! {
133                        {
134                            #runtime_tokens(async move {
135                                #old_block
136                            })
137                        }
138                    };
139                }
140            } else {
141                panic!("Function is not async");
142            }
143        }
144    }
145
146    quote! {
147        #impl_block
148    }
149    .into()
150}