mpc-macros 0.2.15

Arcium MPC Macros
Documentation
/*!
# TaskGetters Derive Macro Implementation

The `TaskGetters` derive macro automatically generates a `try_get_*` function for each variant
of a `TaskType` enum that wraps an `Arc<dyn Task<Output = Arc<T>>>`. This eliminates repetitive
boilerplate when retrieving typed tasks from a task map.

## Usage

```rust,ignore
use macros::TaskGetters;

#[derive(TaskGetters)]
pub enum TaskType<C: Curve> {
    ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
    BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
    PointShareTask(Arc<dyn Task<Output = Arc<PointShare<C>>>>),
    // ... more variants
}
```

This generates functions like:
- `try_get_scalar_share_task`
- `try_get_base_field_share_task`
- `try_get_point_share_task`

Each function has the signature:
```rust,ignore
pub fn try_get_<variant_name>_task<C: Curve>(
    label: Label,
    task_map: &[TaskType<C>],
) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
```

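## Requirements

- A getter is generated only for tuple variants whose first field has the type
  `Arc<dyn Task<Output = Arc<T>>>`; the macro extracts the inner type `T` from that field.
- The variant name itself is not checked. Name variants `*Task` (e.g. `ScalarShareTask`) to get
  getter names of the form `try_get_*_task`.
- Each such variant must have a corresponding `try_as_*` method (typically generated by
  `EnumTryAsInner`) whose error converts into `ProtocolError` via `?`.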

## Implementation Details

The macro performs the following steps:
1. Validates that the input is an enum
2. For each enum variant:
   - Converts the variant name from PascalCase to snake_case
   - Extracts the inner type from the `Arc<dyn Task<Output = Arc<T>>>` pattern, skipping the
     variant if its first tuple field does not match
   - Generates a getter function with the appropriate signature
3. Forwards the enum's generic parameters and `where` clause to every generated function via
   `Generics::split_for_impl()`

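## Error Handling

The macro produces a compile-time error when it is applied to a non-enum type.
Variants that do not match the expected pattern (unit variants, struct variants, or tuple
variants whose first field is not `Arc<dyn Task<Output = Arc<T>>>`) are skipped silently:
no getter is generated for them and no error is reported.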
*/

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};

/// Generates the TaskGetters derive macro implementation.
///
/// This function takes a TokenStream representing an enum and generates getter functions
/// for each variant that follows the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`.
///
/// # Arguments
/// * `input` - The TokenStream from the derive macro containing the enum definition
///
/// # Returns
/// A TokenStream containing the generated getter functions
pub fn derive_task_getters(input: TokenStream) -> TokenStream {
    // Parse the input token stream into a DeriveInput struct for analysis
    let input = parse_macro_input!(input as DeriveInput);
    let enum_name = &input.ident;

    // Validate that the input is an enum - return compile error if not
    let Data::Enum(data_enum) = &input.data else {
        return syn::Error::new_spanned(&input.ident, "TaskGetters can only be derived for enums")
            .to_compile_error()
            .into();
    };

    // The enum's generics are identical for every generated getter, so split them once
    // instead of once per variant.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // One getter per tuple variant whose first field matches
    // `Arc<dyn Task<Output = Arc<T>>>`; every other variant is skipped without an error.
    let functions: Vec<_> = data_enum
        .variants
        .iter()
        .filter_map(|variant| {
            let Fields::Unnamed(fields) = &variant.fields else {
                return None;
            };
            let output_type = extract_task_output_type(&fields.unnamed.first()?.ty)?;

            let variant_name = &variant.ident;
            // PascalCase -> snake_case. Any `_task` suffix comes from the variant name itself
            // (e.g. `ScalarShareTask` -> `try_get_scalar_share_task`).
            let variant_snake = to_snake_case(&variant_name.to_string());
            let fn_name =
                syn::Ident::new(&format!("try_get_{variant_snake}"), variant_name.span());
            // Accessor expected on the enum, typically provided by `EnumTryAsInner`.
            let try_as_method =
                syn::Ident::new(&format!("try_as_{variant_snake}"), variant_name.span());

            // Absolute paths (`::core::result::Result`, `::std::sync::Arc`) keep the expansion
            // correct even if the call site defines its own `Result` alias or `std` module.
            Some(quote! {
                pub fn #fn_name #impl_generics (
                    label: core_utils::types::Label,
                    task_map: &[#enum_name #ty_generics],
                ) -> ::core::result::Result<
                    ::std::sync::Arc<dyn crate::tasks::Task<Output = ::std::sync::Arc<#output_type>>>,
                    crate::protocol::ProtocolError,
                >
                #where_clause
                {
                    ::core::result::Result::Ok(
                        task_map
                            // Get the task at the specified label index
                            .get(label.0 as usize)
                            // Return error if index is out of bounds
                            .ok_or(crate::protocol::ProtocolError::UnorderedCircuit)?
                            // Try to cast to the specific task variant
                            .#try_as_method()?
                            // Clone the Arc to return ownership
                            .clone(),
                    )
                }
            })
        })
        .collect();

    // Combine all generated functions into a single token stream
    let expanded = quote! {
        #(#functions)*
    };

    TokenStream::from(expanded)
}

/// Converts a PascalCase string to snake_case.
///
/// This function transforms enum variant names from PascalCase (e.g., "ScalarShareTask")
/// to snake_case (e.g., "scalar_share_task") for use in generated function names.
///
/// # Arguments
/// * `s` - The PascalCase string to convert
///
/// # Returns
/// A String in snake_case format
///
/// # Examples
/// * "ScalarShareTask" -> "scalar_share_task"
/// * "BaseFieldShare" -> "base_field_share"
fn to_snake_case(s: &str) -> String {
    // Converts a PascalCase identifier to snake_case.
    //
    // An underscore is inserted before an uppercase character (never at the start, and never
    // directly after an existing underscore) when it
    //   1. follows a lowercase character:       "TaskA"   -> "task_a", "GetHTTP" -> "get_http"
    //   2. ends an acronym / starts a new word: "XMLHttp" -> "xml_http"
    //
    // Examples:
    //   "ScalarShareTask" -> "scalar_share_task"
    //   "BaseFieldShare"  -> "base_field_share"
    //   "HTTPServer"      -> "http_server"
    //
    // NOTE(review): the result is also used to build `try_as_*` method names, so it must agree
    // with the naming scheme of the derive that generates those methods.
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            // Lowercase letters, digits and underscores are copied as-is.
            result.push(ch);
            continue;
        }

        let prev_is_lower = i > 0 && chars[i - 1].is_lowercase();
        let next_is_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
        if !result.is_empty() && !result.ends_with('_') && (prev_is_lower || next_is_lower) {
            result.push('_');
        }

        // `to_lowercase` may yield several chars for some Unicode letters; keep all of them.
        result.extend(ch.to_lowercase());
    }

    result
}

/// Extracts the inner type T from Arc<dyn Task<Output = Arc<T>>> pattern.
///
/// This function performs deep type analysis to extract the final inner type from
/// the complex nested generic structure used in task definitions. It navigates through
/// multiple layers of type information to find the actual data type.
///
/// # Arguments
/// * `ty` - The Type to analyze (should be Arc<dyn Task<Output = Arc<T>>>)
///
/// # Returns
/// * `Some(&Type)` - Reference to the inner type T if extraction succeeds
/// * `None` - If the type doesn't match the expected pattern
///
/// # Type Navigation Path
/// The function follows this path through the type structure:
/// 1. `Arc<...>` -> Extract the trait object inside Arc
/// 2. `dyn Task<...>` -> Find the Task trait bound
/// 3. `Task<Output = ...>` -> Extract the Output associated type
/// 4. `Arc<T>` -> Extract the final inner type T
fn extract_task_output_type(ty: &Type) -> Option<&Type> {
    // Extracts `T` from `Arc<dyn Task<Output = Arc<T>>>`.
    //
    // Navigation path:
    //   1. `Arc<...>`            -> the trait object inside the Arc
    //   2. `dyn Task<...>`       -> the bound whose last path segment is `Task`
    //   3. `Task<Output = ...>`  -> the `Output` associated type
    //   4. `Arc<T>`              -> the final inner type `T`
    //
    // `Arc` is matched on the last path segment, so both `Arc<..>` and `std::sync::Arc<..>` are
    // accepted. Invisible groups / parentheses (e.g. from `macro_rules!` expansion) are
    // looked through. Returns `None` when the type does not match the pattern.
    let Type::TraitObject(trait_obj) = peel_type_groups(arc_inner_type(ty)?) else {
        return None;
    };

    trait_obj.bounds.iter().find_map(|bound| {
        let syn::TypeParamBound::Trait(trait_bound) = bound else {
            return None;
        };
        let segment = trait_bound.path.segments.last()?;
        if segment.ident != "Task" {
            return None;
        }
        let PathArguments::AngleBracketed(task_args) = &segment.arguments else {
            return None;
        };
        // First `Output = Arc<T>` binding wins; a non-Arc `Output` keeps searching.
        task_args.args.iter().find_map(|arg| match arg {
            GenericArgument::AssocType(assoc) if assoc.ident == "Output" => {
                arc_inner_type(&assoc.ty)
            }
            _ => None,
        })
    })
}

/// Returns the first type argument of an `Arc<...>` path type (matched on the last path
/// segment), or `None` if `ty` is not such a type.
fn arc_inner_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = peel_type_groups(ty) else {
        return None;
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != "Arc" {
        return None;
    }
    let PathArguments::AngleBracketed(args) = &segment.arguments else {
        return None;
    };
    match args.args.first()? {
        GenericArgument::Type(inner) => Some(inner),
        _ => None,
    }
}

/// Strips any invisible `Type::Group` and `Type::Paren` wrappers around a type.
fn peel_type_groups(mut ty: &Type) -> &Type {
    loop {
        match ty {
            Type::Group(group) => ty = &*group.elem,
            Type::Paren(paren) => ty = &*paren.elem,
            _ => return ty,
        }
    }
}