iridis_api_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse::{Parse, ParseStream},
7    parse_macro_input,
8    punctuated::Punctuated,
9    DeriveInput, ImplItem, ItemImpl, ReturnType, Token,
10};
11
12#[proc_macro_derive(Node)]
13pub fn derive_node(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = input.ident;
16
17    let expanded = quote! {
18        #[cfg(feature = "cdylib")]
19        #[doc(hidden)]
20        #[unsafe(no_mangle)]
21        pub static iridis_NODE: iridis_api::prelude::DynamicallyLinkedNodeInstance = |inputs, outputs, queries, queryables, configuration| {
22            <#name>::new(inputs, outputs, queries, queryables, configuration)
23        };
24
25        static DEFAULT_TOKIO_RUNTIME: std::sync::LazyLock<iridis_api::prelude::thirdparty::tokio::runtime::Runtime> =
26            std::sync::LazyLock::new(|| iridis_api::prelude::thirdparty::tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime"));
27
28        fn default_runtime<T: Send + 'static>(
29            task: impl Future<Output = T> + Send + 'static,
30        ) -> iridis_api::prelude::thirdparty::tokio::task::JoinHandle<T> {
31            match iridis_api::prelude::thirdparty::tokio::runtime::Handle::try_current() {
32                Ok(handle) => handle.spawn(task),
33                Err(_) => DEFAULT_TOKIO_RUNTIME.spawn(task)
34            }
35        }
36    };
37
38    TokenStream::from(expanded)
39}
40
41struct MacroArgs {
42    runtime: String,
43}
44
45impl Parse for MacroArgs {
46    fn parse(input: ParseStream) -> syn::Result<Self> {
47        let mut runtime = String::new();
48
49        let vars = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
50
51        for var in vars {
52            if let syn::Meta::NameValue(name_value) = var {
53                let name = name_value.path.get_ident().unwrap().to_string();
54
55                if name == "runtime" {
56                    if let syn::Expr::Lit(lit) = &name_value.value {
57                        if let syn::Lit::Str(lit_str) = &lit.lit {
58                            runtime = lit_str.value();
59                        }
60                    }
61                }
62            }
63        }
64
65        Ok(MacroArgs { runtime })
66    }
67}
68
69#[proc_macro_attribute]
70pub fn node(attr: TokenStream, item: TokenStream) -> TokenStream {
71    let mut impl_block = parse_macro_input!(item as ItemImpl);
72
73    let args = parse_macro_input!(attr as MacroArgs);
74    let runtime_tokens = args.runtime.parse::<proc_macro2::TokenStream>().unwrap();
75
76    for item in &mut impl_block.items {
77        if let ImplItem::Fn(method) = item {
78            let was_async = method.sig.asyncness.is_some();
79            method.sig.asyncness = None;
80
81            let old_block = method.block.clone();
82
83            if was_async {
84                let old_return_type = match &method.sig.output {
85                    ReturnType::Default => quote! { () },
86                    ReturnType::Type(_, ty) => {
87                        if method.sig.ident == "new" {
88                            quote! { iridis_api::prelude::thirdparty::eyre::Result<Box<dyn Node>> }
89                        } else {
90                            quote! { #ty }
91                        }
92                    }
93                };
94
95                method.sig.output = syn::parse_quote! {
96                    -> tokio::task::JoinHandle<#old_return_type>
97                };
98
99                if method.sig.ident == "new" {
100                    method.block = syn::parse_quote! {
101                        {
102                            #runtime_tokens(async move {
103                                #old_block.map(|node| Box::new(node) as Box<dyn Node>)
104                            })
105                        }
106                    };
107                } else {
108                    method.block = syn::parse_quote! {
109                        {
110                            #runtime_tokens(async move {
111                                #old_block
112                            })
113                        }
114                    };
115                }
116            } else {
117                panic!("Function is not async");
118            }
119        }
120    }
121
122    quote! {
123        #impl_block
124    }
125    .into()
126}