use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
/// Implementation of the `TaskGetters` derive.
///
/// Only enums are accepted; any other input produces a compile error spanned on the type name.
/// For every tuple variant whose first field has the shape
/// `Arc<dyn Task<Output = Arc<T>>>` (as recognised by `extract_task_output_type`), a free
/// function is emitted:
///
/// ```text
/// pub fn try_get_<variant_snake><generics>(label: Label, task_map: &[Enum<generics>])
///     -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
/// ```
///
/// The function indexes `task_map` with `label.0`. It returns
/// `ProtocolError::UnorderedCircuit` when the index is out of range. Otherwise it calls
/// `try_as_<variant_snake>()` on the element, propagates that call's error with `?`, and
/// clones the resulting `Arc`. Variants that do not match the expected shape are skipped
/// silently.
pub fn derive_task_getters(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let enum_name = &input.ident;
    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(&input.ident, "TaskGetters can only be derived for enums")
            .to_compile_error()
            .into();
    };

    // The enum's generics are identical for every getter, so split them once up front.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let functions = data_enum
        .variants
        .iter()
        .filter_map(|variant| {
            // Only tuple variants whose first field is an `Arc<dyn Task<Output = Arc<T>>>`.
            let Fields::Unnamed(fields) = &variant.fields else {
                return None;
            };
            let output_type = extract_task_output_type(&fields.unnamed.first()?.ty)?;

            let variant_name = &variant.ident;
            let variant_snake = to_snake_case(&variant_name.to_string());
            let fn_name =
                syn::Ident::new(&format!("try_get_{variant_snake}"), variant_name.span());
            // Must match the accessor name generated for the enum elsewhere.
            let try_as_method =
                syn::Ident::new(&format!("try_as_{variant_snake}"), variant_name.span());

            // `Result` and `Arc` are fully qualified: these tokens resolve at the derive
            // site, where a local `type Result<T> = ...` alias or a module named `std` would
            // otherwise change their meaning and break the expansion.
            Some(quote! {
                pub fn #fn_name #impl_generics (
                    label: core_utils::types::Label,
                    task_map: &[#enum_name #ty_generics],
                ) -> ::core::result::Result<
                    ::std::sync::Arc<dyn crate::tasks::Task<Output = ::std::sync::Arc<#output_type>>>,
                    crate::protocol::ProtocolError,
                >
                #where_clause
                {
                    // NOTE(review): `as usize` truncates if `Label`'s field is wider than
                    // `usize` on the target — confirm the field type.
                    ::core::result::Result::Ok(task_map
                        .get(label.0 as usize)
                        .ok_or(crate::protocol::ProtocolError::UnorderedCircuit)?
                        .#try_as_method()?
                        .clone())
                }
            })
        })
        .collect::<Vec<_>>();

    let expanded = quote! {
        #(#functions)*
    };
    TokenStream::from(expanded)
}
/// Converts an `UpperCamelCase` identifier into `snake_case`.
///
/// Each uppercase character is lowercased. An `_` is placed in front of it only when
/// something has already been written to the output AND the character right after it is
/// lowercase. All other characters (lowercase letters, digits, `_`) are copied unchanged.
///
/// Consequences of that rule:
/// - acronym runs stay together until the last capital: `HTTPServer` → `http_server`;
/// - a trailing capital, or a capital run at the end, is NOT separated: `FooB` → `foob`,
///   `FooBAR` → `foobar`.
///
/// NOTE(review): the `try_as_*` names built from this must match the names of the enum's
/// `try_as_*` methods, so the rule must agree with whatever generates those — confirm
/// before changing it.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut snake = String::with_capacity(s.len());

    for (idx, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            snake.push(current);
            continue;
        }

        let followed_by_lower = chars.get(idx + 1).is_some_and(|next| next.is_lowercase());
        if followed_by_lower && !snake.is_empty() {
            snake.push('_');
        }

        // `to_lowercase` always yields at least one char; keep only the first, as before.
        if let Some(lower) = current.to_lowercase().next() {
            snake.push(lower);
        }
    }

    snake
}
/// Returns `T` when `ty` has the shape `Arc<dyn Task<Output = Arc<T>> [+ other bounds]>`.
///
/// Every path is matched on its LAST segment, so `Arc`, `std::sync::Arc`, `Task` and
/// `crate::tasks::Task` are all recognised. This is a purely syntactic check: the name
/// `Arc`/`Task` is not resolved to a particular item. Bounds are scanned in order, and so
/// are the `Task` generic arguments. The first `Output = Arc<T>` found wins; an `Output`
/// binding that is not an `Arc<..>` is skipped, not treated as a failure. Returns `None`
/// for anything else.
fn extract_task_output_type(ty: &Type) -> Option<&Type> {
    let Type::TraitObject(trait_obj) = arc_inner_type(ty)? else {
        return None;
    };

    trait_obj.bounds.iter().find_map(|bound| {
        let syn::TypeParamBound::Trait(trait_bound) = bound else {
            return None;
        };
        let task_segment = trait_bound.path.segments.last()?;
        if task_segment.ident != "Task" {
            return None;
        }
        let PathArguments::AngleBracketed(task_args) = &task_segment.arguments else {
            return None;
        };
        task_args.args.iter().find_map(|arg| match arg {
            GenericArgument::AssocType(assoc) if assoc.ident == "Output" => {
                arc_inner_type(&assoc.ty)
            }
            _ => None,
        })
    })
}

/// If `ty` is a path type whose last segment is `Arc<X, ..>`, returns `X` (the first
/// generic argument, when it is a type); otherwise `None`.
fn arc_inner_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    // Last segment, so `std::sync::Arc<..>` is accepted as well as a bare `Arc<..>`.
    let segment = type_path.path.segments.last()?;
    if segment.ident != "Arc" {
        return None;
    }
    let PathArguments::AngleBracketed(args) = &segment.arguments else {
        return None;
    };
    match args.args.first()? {
        GenericArgument::Type(inner) => Some(inner),
        _ => None,
    }
}