use darling::{ast::NestedMeta, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, FnArg, ItemFn, Pat, ReturnType};
/// Options accepted inside `#[worker(...)]` / `#[worker_task(...)]`.
///
/// Every option is optional. The struct-level `#[darling(default)]` fills any
/// option that was not written from the derived `Default`, i.e. `None`; the
/// expansion code then substitutes its own fallbacks.
#[derive(Debug, FromMeta, Default)]
#[darling(default)]
struct WorkerArgs {
    /// Task name to register; when `None` the annotated function's name is used.
    name: Option<String>,
    /// Value passed to `.with_poll_interval_millis(..)`; when `None` the expansion uses 100.
    poll_interval: Option<u64>,
    /// Value passed to `.with_thread_count(..)`; when `None` the expansion uses 1.
    thread_count: Option<usize>,
    /// When set, forwarded to `.with_domain(..)`.
    domain: Option<String>,
    /// When set, forwarded to `.with_identity(..)`.
    identity: Option<String>,
}
#[proc_macro_attribute]
pub fn worker(args: TokenStream, input: TokenStream) -> TokenStream {
let attr_args = match NestedMeta::parse_meta_list(args.into()) {
Ok(v) => v,
Err(e) => {
return TokenStream::from(darling::Error::from(e).write_errors());
}
};
let args = match WorkerArgs::from_list(&attr_args) {
Ok(v) => v,
Err(e) => {
return TokenStream::from(e.write_errors());
}
};
let input_fn = parse_macro_input!(input as ItemFn);
match generate_worker(args, input_fn) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Expands an `async fn` annotated with `#[worker]`.
///
/// The output contains the annotated function itself, unchanged, plus a sibling
/// `<fn_name>_worker()` function with the same visibility. That function builds
/// a `::conductor::worker::FnWorker` whose closure receives the polled
/// `::conductor::models::Task`, binds every parameter, awaits the annotated
/// function, and converts its result into `Ok(WorkerOutput)`.
///
/// Parameters are bound by the last path segment of their type, looking through
/// references and parentheses:
/// - `TaskContext` — receives the value of `task.context()`;
/// - `Task` — receives `task.clone()`, or `&task` for a shared `&Task`;
/// - anything else — `task.get_input("<param name>").unwrap_or_default()`.
///
/// Return values are converted according to the declared return type:
/// - none → `completed_with_result(())`;
/// - `WorkerOutput` → passed through;
/// - `Result<WorkerOutput, E>` → `Ok` passed through, `Err(e)` → `failed(..)`;
/// - `Result<T, E>` → `Ok(v)` → `completed_with_result(v)`, `Err(e)` → `failed(..)`;
/// - any other `T` → `completed_with_result(value)`.
///
/// Builder defaults: the task name falls back to the function name, the poll
/// interval to 100 ms and the thread count to 1. Domain and identity are only
/// applied when given.
///
/// # Errors
/// Returns a spanned error when the function is not `async`, takes `self`,
/// or binds a parameter with a pattern other than a plain identifier.
fn generate_worker(args: WorkerArgs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
    if input_fn.sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(
            &input_fn.sig,
            "worker function must be async",
        ));
    }

    let fn_name = &input_fn.sig.ident;
    let fn_vis = &input_fn.vis;
    let worker_fn_name = format_ident!("{}_worker", fn_name);

    let task_name = args.name.unwrap_or_else(|| fn_name.to_string());
    let poll_interval = args.poll_interval.unwrap_or(100);
    let thread_count = args.thread_count.unwrap_or(1);
    // `Option<TokenStream2>` interpolates as nothing when `None`.
    let domain_config = args
        .domain
        .map(|domain| quote! { .with_domain(#domain) });
    let identity_config = args
        .identity
        .map(|identity| quote! { .with_identity(#identity) });

    let (extractions, call_args, needs_context) = worker_params(&input_fn.sig)?;
    let context_extraction = if needs_context {
        quote! { let __ctx = __task.context(); }
    } else {
        TokenStream2::new()
    };

    // Call the original function rather than inlining its body into an untyped
    // closure: the declared return type then drives inference for `?`/`Ok(..)`
    // inside the body.
    let call = quote! { #fn_name(#(#call_args),*).await };
    let finish = worker_result_handling(&input_fn.sig.output, &call);

    // Internal bindings are `__`-prefixed because `quote!` identifiers carry
    // call-site spans and could otherwise clash with user parameter names (e.g. `task`).
    Ok(quote! {
        #input_fn

        #fn_vis fn #worker_fn_name() -> ::conductor::worker::FnWorker {
            ::conductor::worker::FnWorker::new(#task_name, |__task: ::conductor::models::Task| async move {
                #context_extraction
                #(#extractions)*
                #finish
            })
            .with_poll_interval_millis(#poll_interval)
            .with_thread_count(#thread_count)
            #domain_config
            #identity_config
        }
    })
}

/// Decides how each parameter of `sig` is supplied.
///
/// Returns `(input extraction statements, call arguments in order,
/// whether `__ctx` must be bound)`.
fn worker_params(
    sig: &syn::Signature,
) -> syn::Result<(Vec<TokenStream2>, Vec<TokenStream2>, bool)> {
    let mut extractions = Vec::new();
    let mut call_args = Vec::with_capacity(sig.inputs.len());
    let mut needs_context = false;

    for arg in &sig.inputs {
        let pat_type = match arg {
            FnArg::Receiver(_) => {
                return Err(syn::Error::new_spanned(
                    arg,
                    "worker functions cannot have self parameter",
                ));
            }
            FnArg::Typed(pat_type) => pat_type,
        };
        let name = match &*pat_type.pat {
            Pat::Ident(pat_ident) => &pat_ident.ident,
            _ => {
                return Err(syn::Error::new_spanned(
                    &pat_type.pat,
                    "expected simple identifier pattern",
                ));
            }
        };
        let ty = &*pat_type.ty;

        // Exact segment match: a substring test would misclassify `TaskId`,
        // `Vec<Task>`, `MyTaskInput`, ... as the task itself.
        let type_name = last_path_segment(ty).map(|segment| segment.ident.to_string());
        match type_name.as_deref() {
            Some("TaskContext") => {
                needs_context = true;
                call_args.push(quote! { __ctx });
            }
            Some("Task") => {
                let is_shared_ref =
                    matches!(ty, syn::Type::Reference(reference) if reference.mutability.is_none());
                call_args.push(if is_shared_ref {
                    quote! { &__task }
                } else {
                    quote! { __task.clone() }
                });
            }
            _ => {
                let key = name.to_string();
                extractions.push(quote! {
                    let #name: #ty = __task.get_input(#key).unwrap_or_default();
                });
                call_args.push(quote! { #name });
            }
        }
    }

    Ok((extractions, call_args, needs_context))
}

/// Builds the worker closure's tail expression: awaits `call` and wraps its
/// value in `Ok(WorkerOutput)` according to the declared return type `output`.
fn worker_result_handling(output: &ReturnType, call: &TokenStream2) -> TokenStream2 {
    let ty = match output {
        ReturnType::Default => {
            return quote! {
                #call;
                Ok(::conductor::worker::WorkerOutput::completed_with_result(()))
            };
        }
        ReturnType::Type(_, ty) => ty,
    };

    let segment = last_path_segment(ty);
    match segment.map(|s| s.ident.to_string()).as_deref() {
        Some("WorkerOutput") => quote! { Ok(#call) },
        Some("Result") => {
            // Inspect the `Ok` type so `Result<WorkerOutput, E>` is passed
            // through instead of being wrapped a second time.
            let ok_type_name = segment
                .and_then(|seg| match &seg.arguments {
                    syn::PathArguments::AngleBracketed(generic) => generic.args.first(),
                    _ => None,
                })
                .and_then(|arg| match arg {
                    syn::GenericArgument::Type(ok_ty) => last_path_segment(ok_ty),
                    _ => None,
                })
                .map(|seg| seg.ident.to_string());
            let on_ok = if ok_type_name.as_deref() == Some("WorkerOutput") {
                quote! { __value }
            } else {
                quote! { ::conductor::worker::WorkerOutput::completed_with_result(__value) }
            };
            quote! {
                match #call {
                    Ok(__value) => Ok(#on_ok),
                    Err(__error) => Ok(::conductor::worker::WorkerOutput::failed(format!("{}", __error))),
                }
            }
        }
        _ => quote! {
            Ok(::conductor::worker::WorkerOutput::completed_with_result(#call))
        },
    }
}

/// Returns the final path segment of `ty`, looking through references,
/// parentheses and invisible macro groups; `None` for non-path types
/// (tuples, slices, `impl Trait`, ...).
fn last_path_segment(ty: &syn::Type) -> Option<&syn::PathSegment> {
    match ty {
        syn::Type::Path(type_path) => type_path.path.segments.last(),
        syn::Type::Reference(reference) => last_path_segment(&reference.elem),
        syn::Type::Paren(paren) => last_path_segment(&paren.elem),
        syn::Type::Group(group) => last_path_segment(&group.elem),
        _ => None,
    }
}
#[proc_macro_attribute]
pub fn worker_task(args: TokenStream, input: TokenStream) -> TokenStream {
worker(args, input)
}