use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Ident, ItemFn};
/// Which expansion `#[task]` generates for the annotated function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpawnKind {
    /// `async fn` body wrapped in `task.spawn(name, ...)` and awaited; the default.
    Async,
    /// Plain `fn` body wrapped in `task.spawn_sync(name, ...)`; selected by `#[task(sync)]`.
    Sync,
}
/// Arguments collected from a `#[task(...)]` attribute.
struct TaskAttr {
    /// Selects between the async (`spawn`) and synchronous (`spawn_sync`) expansion.
    kind: SpawnKind,
    /// Replacement for the function name as the task name, from `name = "..."`.
    name_override: Option<String>,
    /// Parameter names listed in `data(...)`.
    data_args: Vec<Ident>,
    /// Identifiers listed in `tags(...)`, appended to the task name.
    tag_args: Vec<Ident>,
}
impl TaskAttr {
fn parse(attr: TokenStream) -> syn::Result<Self> {
let mut kind = SpawnKind::Async;
let mut data_args = Vec::new();
let mut tag_args = Vec::new();
let mut name_override = None;
if attr.is_empty() {
return Ok(Self {
kind,
data_args,
tag_args,
name_override,
});
}
let parsed = syn::parse::Parser::parse(
|input: syn::parse::ParseStream| {
let items =
syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(
input,
)?;
Ok(items)
},
attr,
)?;
for meta in parsed {
match &meta {
syn::Meta::Path(path) => {
if path.is_ident("sync") {
kind = SpawnKind::Sync;
} else {
return Err(syn::Error::new_spanned(path, "expected `sync`"));
}
}
syn::Meta::List(list) if list.path.is_ident("data") => {
let idents = list.parse_args_with(
syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated,
)?;
data_args = idents.into_iter().collect();
}
syn::Meta::List(list) if list.path.is_ident("tags") => {
let idents = list.parse_args_with(
syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated,
)?;
tag_args = idents.into_iter().collect();
}
syn::Meta::NameValue(nv) if nv.path.is_ident("name") => {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(s),
..
}) = &nv.value
{
name_override = Some(s.value());
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"expected a string literal",
));
}
}
_ => {
return Err(syn::Error::new_spanned(
&meta,
"unexpected attribute, expected `sync`, \
`data(...)`, `tags(...)`, or `name = \"...\"`",
));
}
}
}
Ok(Self {
kind,
data_args,
tag_args,
name_override,
})
}
}
/// Returns `true` if `ty` names a type whose last path segment is `Task`.
///
/// Any number of references (`&Task`, `&&mut Task`), invisible groups (`Type::Group`, as
/// produced when a type comes from a `macro_rules!` fragment) and parentheses are looked
/// through first. The check is purely syntactic: any path ending in `Task` matches,
/// regardless of which crate or module it comes from.
fn is_task_type(ty: &syn::Type) -> bool {
    let mut current = ty;
    loop {
        current = match current {
            syn::Type::Reference(reference) => &*reference.elem,
            syn::Type::Group(group) => &*group.elem,
            syn::Type::Paren(paren) => &*paren.elem,
            syn::Type::Path(type_path) => {
                return matches!(
                    type_path.path.segments.last(),
                    Some(segment) if segment.ident == "Task"
                );
            }
            _ => return false,
        };
    }
}
/// Finds the identifier of the single parameter of `sig` whose type is `Task`, as judged by
/// `is_task_type` (references and grouping are looked through).
///
/// # Errors
///
/// - more than one `Task`-typed parameter;
/// - a `Task`-typed parameter bound with a non-identifier pattern (e.g. `_: &Task`), since
///   the generated code must refer to it by name;
/// - no `Task`-typed parameter at all.
fn find_task_param(sig: &syn::Signature) -> syn::Result<&Ident> {
    let mut found: Option<&Ident> = None;
    for param in &sig.inputs {
        let pat_type = match param {
            syn::FnArg::Typed(pat_type) if is_task_type(&pat_type.ty) => pat_type,
            _ => continue,
        };
        if found.is_some() {
            return Err(syn::Error::new_spanned(
                param,
                "multiple Task parameters found; #[task] requires exactly one",
            ));
        }
        // A non-identifier pattern used to be skipped silently, which surfaced later as a
        // misleading "requires a parameter whose type is `Task`" error.
        match &*pat_type.pat {
            syn::Pat::Ident(pat_ident) => found = Some(&pat_ident.ident),
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "the Task parameter must be a plain identifier (e.g. `task: &Task`) \
                     so #[task] can refer to it",
                ));
            }
        }
    }
    found.ok_or_else(|| {
        syn::Error::new_spanned(
            sig,
            "#[task] requires a parameter whose type is `Task` (e.g. `task: &Task`)",
        )
    })
}
#[proc_macro_attribute]
pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
let task_attr = match TaskAttr::parse(attr) {
Ok(a) => a,
Err(e) => return e.to_compile_error().into(),
};
let mut func = parse_macro_input!(item as ItemFn);
let is_async = func.sig.asyncness.is_some();
match (&task_attr.kind, is_async) {
(SpawnKind::Sync, true) => {
return syn::Error::new_spanned(
func.sig.fn_token,
"#[task(sync)] requires a non-async `fn`; remove `async` or use #[task]",
)
.to_compile_error()
.into();
}
(SpawnKind::Async, false) => {
return syn::Error::new_spanned(
func.sig.fn_token,
"#[task] requires `async fn`; use #[task(sync)] for synchronous functions",
)
.to_compile_error()
.into();
}
_ => {}
}
let task_ident = match find_task_param(&func.sig) {
Ok(ident) => ident.clone(),
Err(e) => return e.to_compile_error().into(),
};
let param_names: Vec<Ident> = func
.sig
.inputs
.iter()
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
return Some(pat_ident.ident.clone());
}
}
None
})
.collect();
for data_arg in &task_attr.data_args {
if !param_names.contains(data_arg) {
return syn::Error::new_spanned(
data_arg,
format!("data arg `{data_arg}` is not a parameter of this function"),
)
.to_compile_error()
.into();
}
if *data_arg == task_ident {
return syn::Error::new_spanned(data_arg, "cannot log the task parameter as data")
.to_compile_error()
.into();
}
}
let mut task_name = task_attr
.name_override
.unwrap_or_else(|| func.sig.ident.to_string());
for tag in &task_attr.tag_args {
task_name.push_str(&format!(" #{tag}"));
}
let data_stmts: Vec<_> = task_attr
.data_args
.iter()
.map(|arg| {
let arg_str = arg.to_string();
quote! { #task_ident.data(#arg_str, #arg); }
})
.collect();
let body = &func.block;
let new_body: syn::Block = match task_attr.kind {
SpawnKind::Async => {
syn::parse_quote!({
#task_ident.spawn(#task_name, move |#task_ident| async move {
#(#data_stmts)*
#body
}).await
})
}
SpawnKind::Sync => {
syn::parse_quote!({
#task_ident.spawn_sync(#task_name, move |#task_ident| {
#(#data_stmts)*
#body
})
})
}
};
*func.block = new_body;
quote!(#func).into()
}