use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
Fields,
FieldsNamed,
GenericParam,
Generics,
ItemFn,
ItemStruct,
Type,
TypePath,
};
/// Entry point for the task-definition macro.
///
/// Parses the input as a [`TaskDefinition`] (a struct followed by its compute function).
/// It then expands it into the generated task items. Expansion failures are turned into
/// `compile_error!` tokens so they surface as ordinary compiler diagnostics.
pub fn define_task(input: TokenStream) -> TokenStream {
    let definition = parse_macro_input!(input as TaskDefinition);
    expand_task_definition(definition)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Macro input: a struct describing the task's inputs, followed by its compute function.
struct TaskDefinition {
    /// The task struct; its named fields become the task's inputs.
    struct_def: ItemStruct,
    /// The user-supplied compute function, re-emitted on the generated task type.
    compute_fn: ItemFn,
}

impl Parse for TaskDefinition {
    /// Parses a struct item followed by a function item, in that order.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Struct-literal fields are evaluated in the order written, so the struct is
        // parsed before the function.
        Ok(Self {
            struct_def: input.parse()?,
            compute_fn: input.parse()?,
        })
    }
}
/// Expands a parsed [`TaskDefinition`] into every generated item.
///
/// Emits, in order:
/// 1. the `Unresolved*` struct holding the task inputs;
/// 2. a type alias naming the task as a `crate::tasks::TaskWrapper`;
/// 3. the `crate::tasks::Task` impl;
/// 4. the `new` constructor;
/// 5. the compute function as an associated function.
///
/// # Errors
/// Returns an error if:
/// - the compute function's return type is not a path ending in `Result<T, ..>`; or
/// - the struct does not use named fields.
fn expand_task_definition(task_def: TaskDefinition) -> syn::Result<TokenStream2> {
    let TaskDefinition {
        struct_def,
        compute_fn,
    } = task_def;
    let output_type = extract_output_type_from_compute_fn(&compute_fn)?;
    let struct_name = &struct_def.ident;
    let generics = &struct_def.generics;
    let visibility = &struct_def.vis;
    let fields = match &struct_def.fields {
        Fields::Named(fields) => fields,
        _ => {
            return Err(syn::Error::new(
                struct_def.span(),
                "Only named fields supported",
            ))
        }
    };
    let network = has_network_fields(fields);
    let unresolved_struct_name = format_ident!("Unresolved{}", struct_name);
    let task_name = struct_name;
    let wrapper_type = quote! { crate::tasks::TaskWrapper };
    // Fully qualify `Arc`. The generated `execute` stores `std::sync::Arc::new(result)`,
    // so `Output` must be exactly `std::sync::Arc<T>`. The expansion must also not
    // depend on the call site having `Arc` imported.
    let arc_output_type: Type = syn::parse2(quote! { std::sync::Arc<#output_type> })?;
    let unresolved_struct = generate_unresolved_struct_new(
        &unresolved_struct_name,
        generics,
        fields,
        network,
        visibility,
    );
    let type_alias = generate_type_alias_new(
        task_name,
        &unresolved_struct_name,
        generics,
        &wrapper_type,
        &output_type,
        visibility,
    );
    let task_impl = generate_task_impl_new(
        task_name,
        &unresolved_struct_name,
        generics,
        fields,
        &compute_fn,
        &arc_output_type,
        network,
    )?;
    let constructor = generate_constructor_new(
        task_name,
        &unresolved_struct_name,
        generics,
        fields,
        network,
        visibility,
    );
    let standalone_compute =
        generate_standalone_compute_new(task_name, &compute_fn, generics, visibility);
    Ok(quote! {
        #unresolved_struct
        #type_alias
        #task_impl
        #constructor
        #standalone_compute
    })
}
/// Generates the `Unresolved*` struct that holds the task's raw inputs.
///
/// Every named field of the input struct is copied with its type, with no attributes and
/// no field visibility. Generic parameters are emitted with their bounds, as produced by
/// `split_for_impl`, followed by the `where` clause.
fn generate_unresolved_struct_new(
    name: &Ident,
    generics: &Generics,
    fields: &FieldsNamed,
    _network: bool,
    visibility: &syn::Visibility,
) -> TokenStream2 {
    let (params, _, where_clause) = generics.split_for_impl();
    let members = fields.named.iter().map(|field| {
        let (ident, ty) = (&field.ident, &field.ty);
        quote! { #ident: #ty, }
    });
    quote! {
        #visibility struct #name #params #where_clause {
            #(#members)*
        }
    }
}
/// Builds the generic parameter lists used by the task's type alias.
///
/// Returns `(declarations, arguments)`:
/// - `declarations` is `<...>` with each parameter's bounds and default, for the alias
///   header.
/// - `arguments` is `<...>` with bare parameter names, for applying the unresolved struct.
///
/// Both are empty token streams when there are no generic parameters.
fn generate_type_alias_generics(generics: &Generics) -> (TokenStream2, TokenStream2) {
    let (declarations, arguments): (Vec<TokenStream2>, Vec<TokenStream2>) = generics
        .params
        .iter()
        .map(|param| match param {
            GenericParam::Type(type_param) => {
                let name = &type_param.ident;
                let bound_list = &type_param.bounds;
                // `: bounds` and `= default` are each emitted only when present.
                let bounds = (!bound_list.is_empty()).then(|| quote! { : #bound_list });
                let default = type_param.default.as_ref().map(|ty| quote! { = #ty });
                (quote! { #name #bounds #default }, quote! { #name })
            }
            GenericParam::Lifetime(lifetime_param) => {
                let name = &lifetime_param.lifetime;
                let bound_list = &lifetime_param.bounds;
                let bounds = (!bound_list.is_empty()).then(|| quote! { : #bound_list });
                (quote! { #name #bounds }, quote! { #name })
            }
            GenericParam::Const(const_param) => {
                let name = &const_param.ident;
                let ty = &const_param.ty;
                let default = const_param.default.as_ref().map(|expr| quote! { = #expr });
                (quote! { const #name: #ty #default }, quote! { #name })
            }
        })
        .unzip();
    let angle_wrap = |items: Vec<TokenStream2>| {
        if items.is_empty() {
            quote! {}
        } else {
            quote! { <#(#items),*> }
        }
    };
    (angle_wrap(declarations), angle_wrap(arguments))
}
fn generate_type_alias_new(
task_name: &Ident,
unresolved_name: &Ident,
generics: &Generics,
wrapper_type: &TokenStream2,
output_type: &Type,
visibility: &syn::Visibility,
) -> TokenStream2 {
let (impl_generics, ty_generics) = generate_type_alias_generics(generics);
quote! {
#[allow(type_alias_bounds)]
#visibility type #task_name #impl_generics = #wrapper_type<#unresolved_name #ty_generics, #output_type>;
}
}
/// Generates the `crate::tasks::Task` impl for the task's wrapper type.
///
/// The generated `execute` locks `self.state`. If `take_unresolved()` yields the unresolved
/// inputs, it destructures them into locals and then:
///
/// 1. spawns `execute()` of every `Arc<dyn Task>` field with `tokio::task::spawn`;
/// 2. for every `Vec<Arc<dyn Task>>` field, spawns each element, awaits them all with
///    `futures::future::try_join_all`, and unwraps each result with `Arc::unwrap_or_clone`;
/// 3. awaits the single-task handles, through `tokio::try_join!` when there is more than one;
/// 4. awaits `Vec` fields whose element type contains `Next` with `try_join_all`, then the
///    single such future fields;
/// 5. calls `Self::compute(..)` and stores `Arc::new(result)` with `state.resolve`.
///
/// It then returns `state.clone_output()`, whether or not this call performed the
/// computation.
///
/// # Argument adaptation
/// Compute arguments are matched to the compute function's parameters by name:
/// - a `&[T]` parameter receives a slice;
/// - a `&T` parameter for a single task field receives `Arc::as_ref`;
/// - everything else is passed by value.
///
/// A field named `label` is bound `mut` and passed as `&mut label`. `PhantomData` fields
/// are not passed to `compute`.
///
/// `output_type` becomes `type Output`. In this file the caller passes the `Arc`-wrapped
/// output type.
///
/// # Errors
/// Currently always returns `Ok`.
fn generate_task_impl_new(
    task_name: &Ident,
    unresolved_name: &Ident,
    generics: &Generics,
    fields: &FieldsNamed,
    compute_fn: &ItemFn,
    output_type: &Type,
    _network: bool,
) -> syn::Result<TokenStream2> {
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    // Maps each compute-fn parameter name to its declared type. This decides whether a
    // field is passed by value, as `&T`, or as `&[T]`.
    let mut field_param_types = std::collections::HashMap::new();
    for input in &compute_fn.sig.inputs {
        if let syn::FnArg::Typed(pat_type) = input {
            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                let param_name = pat_ident.ident.clone();
                let param_type = (*pat_type.ty).clone();
                field_param_types.insert(param_name.to_string(), param_type);
            }
        }
    }
    // Bindings used to destructure the unresolved struct inside `execute`.
    let mut field_names = Vec::new();
    for field in &fields.named {
        let field_name = &field.ident;
        if is_phantom_data_field(&field.ty) {
            // PhantomData markers are never passed to `compute`. Binding them by name would
            // trigger `unused_variables` warnings on the user's field, so discard them.
            field_names.push(quote! { #field_name: _ });
        } else if *field_name.as_ref().unwrap() == "label" {
            // `label` is passed to `compute` as `&mut label`, so it must be bound mutably.
            field_names.push(quote! { mut #field_name });
        } else {
            field_names.push(quote! { #field_name });
        }
    }
    let mut task_spawns = Vec::new();
    let mut task_awaitable_names = Vec::new();
    let mut vec_task_names = Vec::new();
    let mut vec_future_names = Vec::new();
    let mut single_future_names = Vec::new();
    let mut compute_args = Vec::new();
    for field in &fields.named {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;
        // Vec checks come first: `is_future_field_new` also matches `Vec<..Next..>`.
        if is_vec_task_field(field_type) {
            vec_task_names.push(quote! { #field_name });
            if let Some(param_type) = field_param_types.get(&field_name.to_string()) {
                if is_slice_reference_param(param_type) {
                    compute_args.push(quote! { &#field_name[..] });
                } else {
                    compute_args.push(quote! { #field_name });
                }
            } else {
                compute_args.push(quote! { #field_name });
            }
        } else if is_task_field_new(field_type) {
            // Spawn immediately so sibling tasks run concurrently; the handle is awaited later.
            task_spawns.push(quote! {
                let #field_name = tokio::task::spawn(async move { #field_name.execute().await });
            });
            task_awaitable_names.push(quote! { #field_name });
            if let Some(param_type) = field_param_types.get(&field_name.to_string()) {
                if is_slice_reference_param(param_type) {
                    compute_args.push(quote! { &std::sync::Arc::as_ref(&#field_name)[..] });
                } else if is_reference_param(param_type) {
                    compute_args.push(quote! { std::sync::Arc::as_ref(&#field_name) });
                } else {
                    compute_args.push(quote! { #field_name });
                }
            } else {
                compute_args.push(quote! { #field_name });
            }
        } else if is_vec_future_field(field_type) {
            vec_future_names.push(quote! { #field_name });
            compute_args.push(quote! { #field_name });
        } else if is_future_field_new(field_type) {
            single_future_names.push(quote! { #field_name });
            compute_args.push(quote! { #field_name });
        } else if is_phantom_data_field(field_type) {
            continue;
        } else {
            if *field_name == "label" {
                compute_args.push(quote! { &mut #field_name });
            } else {
                compute_args.push(quote! { #field_name });
            }
        }
    }
    let compute_call = quote! {
        let result = Self::compute(#(#compute_args),*).await?;
    };
    let mut awaiting_stmts = Vec::new();
    for vec_task_var in &vec_task_names {
        awaiting_stmts.push(quote! {
            let #vec_task_var: Vec<_> = #vec_task_var.into_iter()
                .map(|task| tokio::task::spawn(async move { task.execute().await }))
                .collect();
            let #vec_task_var: Result<Vec<_>, _> = futures::future::try_join_all(
                #vec_task_var.into_iter().map(|handle| async move {
                    handle.await.map_err(Into::<core_utils::errors::AbortError>::into)
                })
            ).await;
            let #vec_task_var: Vec<_> = #vec_task_var?
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?
                .into_iter()
                .map(std::sync::Arc::unwrap_or_clone)
                .collect();
        });
    }
    if !task_awaitable_names.is_empty() {
        if task_awaitable_names.len() == 1 {
            let task_var = &task_awaitable_names[0];
            // The outer `?` handles the join-handle error; the inner `?` propagates the
            // task's own `execute` error.
            awaiting_stmts.push(quote! {
                let #task_var = #task_var.await??;
            });
        } else {
            awaiting_stmts.push(quote! {
                let (#(#task_awaitable_names),*) = tokio::try_join!(
                    #(
                        futures::TryFutureExt::map_err(#task_awaitable_names, Into::<core_utils::errors::AbortError>::into)
                    ),*
                )?;
                let (#(#task_awaitable_names),*) = (#(#task_awaitable_names?),*);
            });
        }
    }
    for vec_future_var in &vec_future_names {
        awaiting_stmts.push(quote! {
            let #vec_future_var = futures::future::try_join_all(#vec_future_var).await?;
        });
    }
    for single_future_var in &single_future_names {
        awaiting_stmts.push(quote! {
            let #single_future_var = #single_future_var.await?;
        });
    }
    let awaiting_and_unwrapping = quote! {
        #(#awaiting_stmts)*
    };
    Ok(quote! {
        #[async_trait::async_trait]
        impl #impl_generics crate::tasks::Task for #task_name #ty_generics #where_clause {
            type Output = #output_type;
            async fn execute(&self) -> Result<Self::Output, core_utils::errors::AbortError> {
                let mut state = self.state.lock().await;
                if let Some(unresolved_state) = state.take_unresolved()? {
                    let #unresolved_name { #(#field_names),* } = unresolved_state;
                    #(#task_spawns)*
                    #awaiting_and_unwrapping
                    #compute_call
                    state.resolve(std::sync::Arc::new(result));
                }
                Ok(state.clone_output()?)
            }
        }
    })
}
/// Generates `pub fn new(..) -> Arc<Self>` for the task type.
///
/// Every non-`PhantomData` field becomes a parameter, in declaration order. `PhantomData`
/// fields are filled with `Default::default()`. The unresolved inputs are wrapped in
/// `crate::tasks::InternalTaskState::new` behind a `tokio::sync::Mutex`.
fn generate_constructor_new(
    task_name: &Ident,
    unresolved_name: &Ident,
    generics: &Generics,
    fields: &FieldsNamed,
    _network: bool,
    _visibility: &syn::Visibility,
) -> TokenStream2 {
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let params = fields
        .named
        .iter()
        .filter(|field| !is_phantom_data_field(&field.ty))
        .map(|field| {
            let (ident, ty) = (&field.ident, &field.ty);
            quote! { #ident: #ty }
        });
    let inits = fields.named.iter().map(|field| {
        let ident = &field.ident;
        if is_phantom_data_field(&field.ty) {
            quote! { #ident: Default::default() }
        } else {
            quote! { #ident }
        }
    });
    quote! {
        impl #impl_generics #task_name #ty_generics #where_clause {
            pub fn new(#(#params),*) -> std::sync::Arc<Self> {
                let unresolved = #unresolved_name {
                    #(#inits),*
                };
                std::sync::Arc::new(Self {
                    state: tokio::sync::Mutex::new(crate::tasks::InternalTaskState::new(unresolved)),
                })
            }
        }
    }
}
/// Re-emits the user's compute function as a `pub` associated function on the task type.
///
/// The generated `execute` calls `Self::compute(..)`, so the function is expected to be
/// named `compute`. The function's outer attributes are carried over so that doc comments
/// and lint attributes still apply. Examples include `#[allow(clippy::too_many_arguments)]`,
/// which is common when a task has many inputs. The function's own visibility is ignored;
/// the generated item is always `pub`.
fn generate_standalone_compute_new(
    task_name: &Ident,
    compute_fn: &ItemFn,
    generics: &Generics,
    _visibility: &syn::Visibility,
) -> TokenStream2 {
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    // Inner attributes (`#![...]`) would be invalid in front of an associated fn, so only
    // outer attributes are forwarded.
    let attrs = compute_fn
        .attrs
        .iter()
        .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer));
    let sig = &compute_fn.sig;
    let block = &compute_fn.block;
    quote! {
        impl #impl_generics #task_name #ty_generics #where_clause {
            #(#attrs)*
            pub #sig #block
        }
    }
}
/// Returns `true` if `ty` looks like a single sub-task, i.e. `Arc<dyn ... Task ...>`.
///
/// The check is syntactic. It requires:
/// - the path's last segment to be `Arc`;
/// - one of its type arguments to be a trait object;
/// - one of that object's trait bounds to have a last path segment named `Task`.
///
/// A trait bound with an empty path yields `false` instead of panicking.
fn is_task_field_new(ty: &Type) -> bool {
    let Type::Path(TypePath { path, .. }) = ty else {
        return false;
    };
    let Some(segment) = path.segments.last() else {
        return false;
    };
    if segment.ident != "Arc" {
        return false;
    }
    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
        return false;
    };
    args.args.iter().any(|arg| match arg {
        syn::GenericArgument::Type(Type::TraitObject(trait_obj)) => {
            trait_obj.bounds.iter().any(|bound| {
                matches!(
                    bound,
                    syn::TypeParamBound::Trait(trait_bound)
                        if trait_bound
                            .path
                            .segments
                            .last()
                            .is_some_and(|last| last.ident == "Task")
                )
            })
        }
        _ => false,
    })
}
/// Returns `true` if `ty` is treated as a future input.
///
/// This holds when any path segment's name contains `Next`. It also holds when any `Vec`
/// segment has a type argument that recursively satisfies this check. Non-path types are
/// never futures.
fn is_future_field_new(ty: &Type) -> bool {
    let Type::Path(TypePath { path, .. }) = ty else {
        return false;
    };
    path.segments.iter().any(|segment| {
        if segment.ident.to_string().contains("Next") {
            return true;
        }
        if segment.ident != "Vec" {
            return false;
        }
        match &segment.arguments {
            syn::PathArguments::AngleBracketed(generic_args) => {
                generic_args.args.iter().any(|arg| {
                    matches!(arg, syn::GenericArgument::Type(inner) if is_future_field_new(inner))
                })
            }
            _ => false,
        }
    })
}
/// Returns `true` if `ty` is a `Vec<..>` whose type argument is a future.
///
/// "Future" here is decided by `is_future_field_new`. Only the path's last segment is
/// checked for `Vec`.
fn is_vec_future_field(ty: &Type) -> bool {
    let Type::Path(TypePath { path, .. }) = ty else {
        return false;
    };
    match path.segments.last() {
        Some(last) if last.ident == "Vec" => match &last.arguments {
            syn::PathArguments::AngleBracketed(generic_args) => {
                generic_args.args.iter().any(|arg| {
                    matches!(arg, syn::GenericArgument::Type(elem) if is_future_field_new(elem))
                })
            }
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` if `ty` is a `Vec<..>` whose type argument is a sub-task.
///
/// "Sub-task" here means `Arc<dyn Task>`, as decided by `is_task_field_new`. Only the
/// path's last segment is checked for `Vec`.
fn is_vec_task_field(ty: &Type) -> bool {
    let Type::Path(TypePath { path, .. }) = ty else {
        return false;
    };
    match path.segments.last() {
        Some(last) if last.ident == "Vec" => match &last.arguments {
            syn::PathArguments::AngleBracketed(generic_args) => {
                generic_args.args.iter().any(|arg| {
                    matches!(arg, syn::GenericArgument::Type(elem) if is_task_field_new(elem))
                })
            }
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` if `ty` is a path whose last segment is named `PhantomData`.
fn is_phantom_data_field(ty: &Type) -> bool {
    matches!(
        ty,
        Type::Path(TypePath { path, .. })
            if path.segments.last().is_some_and(|last| last.ident == "PhantomData")
    )
}
/// Returns `true` when the struct declares both a `network_interface` field and a
/// `label` field.
fn has_network_fields(fields: &FieldsNamed) -> bool {
    let declares = |wanted: &str| {
        fields
            .named
            .iter()
            .filter_map(|field| field.ident.as_ref())
            .any(|ident| ident.to_string() == wanted)
    };
    declares("network_interface") && declares("label")
}
fn is_reference_param(ty: &Type) -> bool {
matches!(ty, Type::Reference(_))
}
/// Returns `true` if the parameter type is a reference to a slice, e.g. `&[T]` or
/// `&mut [T]`.
fn is_slice_reference_param(ty: &Type) -> bool {
    match ty {
        Type::Reference(reference) => matches!(&*reference.elem, Type::Slice(_)),
        _ => false,
    }
}
fn extract_output_type_from_compute_fn(compute_fn: &ItemFn) -> syn::Result<Type> {
if let syn::ReturnType::Type(_, return_type) = &compute_fn.sig.output {
if let syn::Type::Path(type_path) = return_type.as_ref() {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Result" {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(syn::GenericArgument::Type(output_type)) = args.args.first() {
return Ok(output_type.clone());
}
}
}
}
}
}
Err(syn::Error::new(
compute_fn.sig.span(),
"Could not extract output type from compute function return type",
))
}