statum-macros 0.8.8

Proc macros for representing legal workflow and protocol states explicitly in Rust
Documentation
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use std::collections::HashSet;
use syn::spanned::Spanned;
use syn::{FnArg, GenericArgument, Ident, Pat, PathArguments, ReturnType, Type};

use crate::diagnostics::{DiagnosticMessage, compact_display};
use crate::source::{candidate_alias_resolution_contexts, expand_source_type_alias};

use super::contract::{
    ValidatorMethodContract, ValidatorReturnKind, VariantSpec,
    build_validator_method_contract as build_semantic_validator_method_contract,
};
use super::type_equivalence::types_equivalent;

/// Naming context threaded through validator diagnostics so every error message can say
/// which `impl` block, machine, and state the offending validator belongs to.
pub(super) struct ValidatorDiagnosticContext<'a> {
    /// Display text for the type whose `impl` block holds the validator (shown as `impl <this>`).
    pub(super) persisted_type_display: &'a str,
    /// Name of the state machine being rebuilt.
    pub(super) machine_name: &'a str,
    /// Name of the state enum that `variant_name` belongs to (rendered as `Enum::Variant`).
    pub(super) state_enum_name: &'a str,
    /// Name of the state variant this validator rebuilds.
    pub(super) variant_name: &'a str,
    /// Machine fields injected by bare name into validator bodies; explicit parameters with
    /// these names are reported as collisions.
    pub(super) machine_fields: &'a [Ident],
    /// Success payload type that the validator's `Result`/`Validation` return must carry.
    pub(super) expected_ok_type: &'a Type,
}

/// Derives a state name from a validator method identifier by removing its `is_` prefix.
///
/// Returns `None` when the identifier text does not start with `is_`. A bare `is_` yields
/// `Some("")`.
pub(super) fn validator_state_name_from_ident(ident: &Ident) -> Option<String> {
    let text = ident.to_string();
    text.strip_prefix("is_").map(str::to_owned)
}

/// Validates a validator method's signature and return type, then builds its semantic contract.
///
/// The signature is checked first; the return type is only inspected once the signature is
/// accepted. The first failure is returned as a rendered compile-error token stream.
pub(super) fn build_validator_method_contract(
    func: &syn::ImplItemFn,
    spec: &VariantSpec,
    context: &ValidatorDiagnosticContext<'_>,
) -> Result<ValidatorMethodContract, proc_macro2::TokenStream> {
    validate_validator_signature(func, context)
        .and_then(|()| {
            validate_validator_return_contract(func, context.expected_ok_type, context)
        })
        .map(|return_kind| build_semantic_validator_method_contract(func, spec, return_kind))
}

fn validate_validator_signature(
    func: &syn::ImplItemFn,
    context: &ValidatorDiagnosticContext<'_>,
) -> Result<(), proc_macro2::TokenStream> {
    if func.sig.inputs.len() != 1 {
        let collision_line = explicit_param_collision_line(&func.sig.inputs, context.machine_fields);
        let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
        let message = DiagnosticMessage::new(format!(
            "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must declare only `&self`.",
            func.sig.ident,
            context.persisted_type_display,
            context.machine_name,
            context.state_enum_name,
            context.variant_name,
        ))
        .found(format!("`fn {}({})`", func.sig.ident, compact_display(&func.sig.inputs)))
        .expected(expected_signature.clone())
        .reason(collision_line.unwrap_or_else(|| {
            "validator methods do not accept explicit machine-field parameters.".to_string()
        }))
        .note(injected_machine_fields_line(context.machine_name, context.machine_fields))
        .fix("remove explicit parameters and read injected machine fields by bare name inside the validator body.".to_string())
        .render();
        let error = if let Some(extra_input) = func.sig.inputs.iter().nth(1) {
            syn::Error::new_spanned(extra_input, message)
        } else {
            syn::Error::new_spanned(&func.sig.inputs, message)
        };
        return Err(error.to_compile_error());
    }
    match &func.sig.inputs[0] {
        FnArg::Receiver(receiver) => {
            if receiver.reference.is_none() || receiver.mutability.is_some() {
                let receiver_display = receiver.to_token_stream().to_string();
                let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
                let message = DiagnosticMessage::new(format!(
                    "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must take `&self`, not `{}`.",
                    func.sig.ident,
                    context.persisted_type_display,
                    context.machine_name,
                    context.state_enum_name,
                    context.variant_name,
                    receiver_display,
                ))
                .found(format!("`fn {}({receiver_display})`", func.sig.ident))
                .expected(expected_signature)
                .note(injected_machine_fields_line(context.machine_name, context.machine_fields))
                .fix("change the receiver to `&self`.".to_string())
                .render();
                return Err(syn::Error::new_spanned(receiver, message).to_compile_error());
            }
        }
        FnArg::Typed(_) => {
            let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
            let message = DiagnosticMessage::new(format!(
                "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must take `&self` as its receiver.",
                func.sig.ident,
                context.persisted_type_display,
                context.machine_name,
                context.state_enum_name,
                context.variant_name,
            ))
            .found(format!("`fn {}({})`", func.sig.ident, compact_display(&func.sig.inputs[0])))
            .expected(expected_signature)
            .note(injected_machine_fields_line(context.machine_name, context.machine_fields))
            .fix("rewrite the method to take `&self` and no other parameters.".to_string())
            .render();
            return Err(syn::Error::new_spanned(&func.sig.inputs[0], message).to_compile_error());
        }
    }
    Ok(())
}

fn validate_validator_return_contract(
    func: &syn::ImplItemFn,
    expected_ok_type: &Type,
    context: &ValidatorDiagnosticContext<'_>,
) -> Result<ValidatorReturnKind, TokenStream> {
    let ReturnType::Type(_, return_ty) = &func.sig.output else {
        let expected_ok_display = expected_ok_type.to_token_stream().to_string();
        let message = DiagnosticMessage::new(format!(
            "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return `Result<{}, _>` or `Validation<{}>`.",
            func.sig.ident,
            context.persisted_type_display,
            context.machine_name,
            context.state_enum_name,
            context.variant_name,
            expected_ok_display,
            expected_ok_display,
        ))
        .expected(expected_validator_signature(&func.sig.ident, expected_ok_type))
        .reason(expected_state_shape(
            context.state_enum_name,
            context.variant_name,
            &expected_ok_display,
        ))
        .fix("add an explicit validator return type.".to_string())
        .render();
        return Err(syn::Error::new_spanned(&func.sig.output, message).to_compile_error());
    };

    let (actual_ok_ty, return_kind) =
        match extract_supported_validator_ok_type(return_ty, return_ty.span()) {
        Some(info) => info,
        None => {
            let expected_ok_display = expected_ok_type.to_token_stream().to_string();
            let message = DiagnosticMessage::new(format!(
                "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return a supported validator result whose payload is `{}`.",
                func.sig.ident,
                context.persisted_type_display,
                context.machine_name,
                context.state_enum_name,
                context.variant_name,
                expected_ok_display,
            ))
            .found(format!("`{}`", compact_display(return_ty)))
            .expected(format!(
                "`Result<{expected_ok_display}, _>` or `Validation<{expected_ok_display}>`"
            ))
            .note(supported_validator_wrapper_note())
            .fix("rewrite the return type to use one of the supported validator result wrappers.".to_string())
            .render();
            return Err(syn::Error::new_spanned(return_ty, message).to_compile_error());
        }
    };

    if !types_equivalent(&actual_ok_ty, expected_ok_type) {
        let expected_ok_display = expected_ok_type.to_token_stream().to_string();
        let actual_return_type = return_ty.to_token_stream().to_string();
        let actual_ok_display = actual_ok_ty.to_token_stream().to_string();
        let message = DiagnosticMessage::new(format!(
            "validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return `Result<{}, _>` or `Validation<{}>` (or an equivalent supported alias).",
            func.sig.ident,
            context.persisted_type_display,
            context.machine_name,
            context.state_enum_name,
            context.variant_name,
            expected_ok_display,
            expected_ok_display,
        ))
        .found(format!(
            "`{actual_return_type}` with payload `{actual_ok_display}`"
        ))
        .expected(format!(
            "`Result<{expected_ok_display}, _>` or `Validation<{expected_ok_display}>`"
        ))
        .fix(format!(
            "change the validator to return `{expected_ok_display}` for `{}::{}`.",
            context.state_enum_name, context.variant_name
        ))
        .render();
        return Err(syn::Error::new_spanned(return_ty, message).to_compile_error());
    }

    Ok(return_kind)
}

/// Builds the diagnostic note describing which machine fields are injected into validator
/// bodies by bare name, or stating that there are none.
fn injected_machine_fields_line(machine_name: &str, machine_fields: &[Ident]) -> String {
    match machine_fields {
        [] => format!(
            "Machine `{machine_name}` has no user-defined fields to inject, so validator methods should not take any extra parameters."
        ),
        fields => {
            let quoted: Vec<String> = fields.iter().map(|name| format!("`{name}`")).collect();
            let injected = quoted.join(", ");
            format!(
                "Machine `{machine_name}` injects these fields by bare name inside validator bodies: {injected}. Remove explicit parameters and use those bindings directly."
            )
        }
    }
}

/// Renders the two accepted validator signatures (`Result` and `Validation` forms) for the
/// given method name and success payload type, for use in diagnostics.
fn expected_validator_signature(func_ident: &Ident, expected_ok_type: &Type) -> String {
    let ok = expected_ok_type.to_token_stream();
    format!(
        "`fn {func_ident}(&self) -> Result<{ok}, _>` or `fn {func_ident}(&self) -> Validation<{ok}>`"
    )
}

/// Builds a diagnostic sentence naming every explicit parameter whose identifier binding
/// shadows an injected machine field, or returns `None` when there is no such collision.
///
/// Every input is inspected. Receivers are excluded by the pattern match itself, so a leading
/// typed parameter on a method without any `self` receiver is also checked. Only simple
/// identifier patterns (`name: T`) are considered.
fn explicit_param_collision_line(
    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
    machine_fields: &[Ident],
) -> Option<String> {
    let collisions: Vec<String> = inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Typed(arg) => match &*arg.pat {
                Pat::Ident(ident) if machine_fields.iter().any(|field| field == &ident.ident) => {
                    Some(format!("`{}`", ident.ident))
                }
                _ => None,
            },
            FnArg::Receiver(_) => None,
        })
        .collect();

    if collisions.is_empty() {
        return None;
    }

    let (verb, noun) = if collisions.len() == 1 {
        ("collides", "binding")
    } else {
        ("collide", "bindings")
    };
    Some(format!(
        "Parameter name collision: {} {verb} with injected machine field {noun}.",
        collisions.join(", ")
    ))
}

/// Describes the payload shape of `state_enum_name::variant_name` for diagnostics.
///
/// A rendered payload of exactly `()` is reported as a unit state; anything else is reported
/// as the payload the state carries.
fn expected_state_shape(state_enum_name: &str, variant_name: &str, expected_ok_display: &str) -> String {
    let state = format!("`{state_enum_name}::{variant_name}`");
    match expected_ok_display {
        "()" => format!("{state} is a unit state"),
        payload => format!("{state} carries `{payload}`"),
    }
}

fn extract_supported_validator_ok_type(
    return_ty: &Type,
    return_ty_span: Span,
) -> Option<(Type, ValidatorReturnKind)> {
    let contexts = candidate_alias_resolution_contexts(Some(return_ty_span));
    for context in &contexts {
        let mut visited = HashSet::new();
        if let Some(info) =
            extract_supported_validator_ok_type_in_context(return_ty, Some(context), &mut visited)
        {
            return Some(info);
        }
    }

    let mut visited = HashSet::new();
    extract_supported_validator_ok_type_in_context(return_ty, None, &mut visited)
}

/// Resolves a validator return type, first through source-declared alias expansion and then
/// directly.
///
/// When `expand_source_type_alias` expands `return_ty`, the expansion is resolved recursively
/// in the alias's own context. Afterwards, the visit key it returned is removed from `visited`
/// so sibling resolutions are not blocked. If the expansion produces no result, `return_ty`
/// itself is checked with `direct_supported_validator_ok_type`.
fn extract_supported_validator_ok_type_in_context(
    return_ty: &Type,
    context: Option<&crate::source::AliasResolutionContext>,
    visited: &mut HashSet<String>,
) -> Option<(Type, ValidatorReturnKind)> {
    let via_alias = expand_source_type_alias(return_ty, context, visited).and_then(
        |(expanded, alias_context, visit_key)| {
            let resolved = extract_supported_validator_ok_type_in_context(
                &expanded,
                Some(&alias_context),
                visited,
            );
            visited.remove(&visit_key);
            resolved
        },
    );

    via_alias.or_else(|| direct_supported_validator_ok_type(return_ty))
}

fn direct_supported_validator_ok_type(return_ty: &Type) -> Option<(Type, ValidatorReturnKind)> {
    let Type::Path(type_path) = return_ty else {
        return None;
    };

    let last_segment = type_path.path.segments.last()?;
    if last_segment.ident == "Validation" && path_is_supported_validation_wrapper(&type_path.path) {
        let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
            return None;
        };

        let type_args: Vec<Type> = args
            .args
            .iter()
            .filter_map(|arg| match arg {
                GenericArgument::Type(ty) => Some(ty.clone()),
                _ => None,
            })
            .collect();

        if type_args.len() != 1 || type_args.len() != args.args.len() {
            return None;
        }

        return type_args
            .first()
            .cloned()
            .map(|ty| (ty, ValidatorReturnKind::Diagnostic));
    }

    if last_segment.ident != "Result" || !path_is_supported_result_wrapper(&type_path.path) {
        return None;
    }

    let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
        return None;
    };

    let type_args: Vec<Type> = args
        .args
        .iter()
        .filter_map(|arg| match arg {
            GenericArgument::Type(ty) => Some(ty.clone()),
            _ => None,
        })
        .collect();

    if type_args.is_empty() || type_args.len() > 2 || type_args.len() != args.args.len() {
        return None;
    }

    let return_kind = if type_args.get(1).is_some_and(is_rejection_type) {
        ValidatorReturnKind::Diagnostic
    } else {
        ValidatorReturnKind::Plain
    };

    type_args.first().cloned().map(|ty| (ty, return_kind))
}

/// Diagnostic note listing every validator return-type form this module accepts.
fn supported_validator_wrapper_note() -> &'static str {
    const NOTE: &str = "supported forms: `Result<T, E>`, `core::result::Result<T, E>`, `std::result::Result<T, E>`, `statum::Result<T>`, direct `Result<T, statum::Rejection>`, and `Validation<T>` / `statum::Validation<T>`, plus source-declared type aliases that expand to those wrappers.";
    NOTE
}

fn path_is_supported_validation_wrapper(path: &syn::Path) -> bool {
    let segments = path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect::<Vec<_>>();
    matches!(segments.as_slice(), [validation] if validation == "Validation")
        || segments.as_slice() == ["statum", "Validation"]
}

fn path_is_supported_result_wrapper(
    path: &syn::Path,
) -> bool {
    let segments = path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect::<Vec<_>>();
    matches!(segments.as_slice(), [result] if result == "Result")
        || segments.as_slice() == ["statum", "Result"]
        || segments.as_slice() == ["core", "result", "Result"]
        || segments.as_slice() == ["std", "result", "Result"]
}

/// Returns `true` when `ty` is a path type whose last segment is named `Rejection`
/// (for example `Rejection` or `statum::Rejection`); the leading segments are not checked.
fn is_rejection_type(ty: &Type) -> bool {
    matches!(
        ty,
        Type::Path(type_path)
            if type_path
                .path
                .segments
                .last()
                .is_some_and(|segment| segment.ident == "Rejection")
    )
}