// statum-macros 0.9.0
//
// Proc macros for representing legal workflow and protocol states explicitly in Rust.
// Documentation
use proc_macro2::TokenStream;
use std::collections::HashSet;
use syn::{GenericArgument, PathArguments, Type};

use crate::contracts::TransitionContract;
use crate::diagnostics::compact_display;

use super::super::diagnostics::{
    invalid_introspect_return_error, invalid_return_type_error, mismatched_introspect_return_error,
};
use super::super::parse::TransitionFn;
use super::super::resolve::{
    AliasResolutionContext, SourceAliasResolver, SupportedWrapper,
    collect_machine_and_states_in_context, collect_machine_and_states_strict,
    expand_source_type_alias, extract_first_generic_type_ref, extract_generic_type_refs,
    machine_segment_matching_target, parse_machine_and_state_in_context,
    parse_primary_machine_and_state_strict, supported_wrapper, type_path,
};

pub(crate) fn build_transition_contract(
    func: &TransitionFn,
    target_type: &syn::Type,
) -> Result<TransitionContract, TokenStream> {
    let return_contract = validate_transition_return_contract(func, target_type)?;
    Ok(TransitionContract {
        primary_next_state: return_contract.primary_next_state,
        next_states: return_contract.next_states,
    })
}

/// What a transition's written return type looks like, gathered for
/// diagnostics (see `observed_return_shape`).
pub(super) struct ObservedReturnShape {
    /// The outermost supported wrapper (`Option`, `Result`, `Branch`) around
    /// the return type, if any.
    pub(super) wrapper: Option<SupportedWrapper>,
    /// Compact display of the primary branch: the first generic argument of a
    /// supported wrapper, or the whole return type when it is unwrapped.
    pub(super) primary_branch: Option<String>,
    /// State name used when suggesting a canonical signature; `None` makes the
    /// suggestions fall back to the `NextState` placeholder.
    pub(super) canonical_state: Option<String>,
    /// Resolved `Machine<State>` targets other than the canonical one.
    pub(super) secondary_machine_branches: Vec<String>,
}

impl ObservedReturnShape {
    /// Renders `machine_name<State>` for the canonical state, using the
    /// placeholder `NextState` when no canonical state was observed.
    pub(super) fn canonical_machine_target(&self, machine_name: &str) -> String {
        let state = self.canonical_state.as_deref().unwrap_or("NextState");
        format!("{machine_name}<{state}>")
    }

    /// Renders the full canonical return type: the canonical machine target,
    /// re-wrapped in the observed wrapper (with placeholder `E` / `OtherBranch`
    /// for the secondary slot) when one was present.
    pub(super) fn canonical_annotation(&self, machine_name: &str) -> String {
        let target = self.canonical_machine_target(machine_name);
        let Some(wrapper) = self.wrapper.as_ref() else {
            return target;
        };
        match wrapper {
            SupportedWrapper::Option => format!("::core::option::Option<{target}>"),
            SupportedWrapper::Result => format!("::core::result::Result<{target}, E>"),
            SupportedWrapper::Branch => format!("::statum::Branch<{target}, OtherBranch>"),
        }
    }

    /// Renders a backtick-quoted example signature `fn name(self) -> ...` that
    /// uses the canonical annotation as its return type.
    pub(super) fn canonical_wrapped_signature(
        &self,
        func_name: &syn::Ident,
        machine_name: &str,
    ) -> String {
        let annotation = self.canonical_annotation(machine_name);
        format!("`fn {func_name}(self) -> {annotation}`")
    }

    /// Builds the "how to fix" hint. Wrapped returns get an example signature
    /// with the machine in the primary branch; bare returns are simply told to
    /// return the canonical machine target.
    pub(super) fn fix_message(&self, func_name: &syn::Ident, machine_name: &str) -> String {
        let target = self.canonical_machine_target(machine_name);
        match self.wrapper {
            None => format!("return `{target}` directly."),
            Some(
                SupportedWrapper::Option | SupportedWrapper::Result | SupportedWrapper::Branch,
            ) => {
                let example = self.canonical_wrapped_signature(func_name, machine_name);
                format!(
                    "move `{target}` into the primary branch, for example with {example}, or return `{target}` directly if you do not need the wrapper."
                )
            }
        }
    }
}

/// Result of `validate_transition_return_contract`: the next states a
/// transition may produce, with its primary next state singled out.
struct ValidatedTransitionReturnContract {
    /// Every next state named by the validated return type, in resolution order.
    next_states: Vec<String>,
    /// The next state named by the primary branch of the return type.
    primary_next_state: String,
}

/// Targets resolved from one return type. The resolvers that build it only do
/// so when a primary state was found and `next_states` is non-empty.
struct ResolvedTransitionTargets {
    /// All next states named by the type, in resolution order.
    next_states: Vec<String>,
    /// The next state named by the type's primary branch.
    primary_next_state: String,
}

/// Raw, unvalidated strict-mode observation of a return type: either part may
/// be missing or empty. `resolve_transition_targets_strict` turns it into a
/// `ResolvedTransitionTargets` only when both are present.
struct ObservedTransitionTargets {
    /// All next states collected from the type (possibly empty).
    next_states: Vec<String>,
    /// The primary next state, if the primary branch parsed strictly.
    primary_next_state: Option<String>,
}

fn validate_transition_return_contract(
    func: &TransitionFn,
    target_type: &Type,
) -> Result<ValidatedTransitionReturnContract, TokenStream> {
    let Some(written_return_type) = func.return_type.as_ref() else {
        return Err(invalid_return_type_error(
            func,
            target_type,
            "missing return type",
        ));
    };

    let strict_introspection =
        crate::strict_introspection_enabled() || func.introspection.is_some();
    if let Some(introspection) = func.introspection.as_ref() {
        let introspection_targets = resolve_transition_targets_strict(&introspection.return_type, target_type)
            .ok_or_else(|| {
                invalid_introspect_return_error(
                    introspection,
                    func,
                    "expected a direct machine path or a supported `Option`, `Result`, or `statum::Branch` wrapper around that machine path",
                )
            })?;

        let written_targets = resolve_transition_targets(
            written_return_type,
            target_type,
            false,
            func.return_type_span,
        )
        .ok_or_else(|| {
            invalid_return_type_error(
                func,
                target_type,
                "even with `#[introspect(return = ...)]`, the written return type must still resolve to the impl target machine path or a supported wrapper around it",
            )
        })?;
        if written_targets.primary_next_state != introspection_targets.primary_next_state
            || written_targets.next_states != introspection_targets.next_states
        {
            return Err(mismatched_introspect_return_error(
                introspection,
                func,
                written_return_type,
                target_type,
            ));
        }

        return Ok(ValidatedTransitionReturnContract {
            primary_next_state: introspection_targets.primary_next_state,
            next_states: introspection_targets.next_states,
        });
    }

    let reason = if strict_introspection {
        "expected the impl target machine path directly, or that same machine path wrapped in a supported `Option`, `Result`, or `Branch` shape; aliases require an explicit `#[introspect(return = ...)]` annotation in strict mode"
    } else {
        "expected the impl target machine path directly, a source-backed type alias that expands to it, or that same machine path wrapped in a supported `Option`, `Result`, or `Branch` shape"
    };
    let targets = resolve_transition_targets(
        written_return_type,
        target_type,
        strict_introspection,
        func.return_type_span,
    )
    .ok_or_else(|| invalid_return_type_error(func, target_type, reason))?;

    Ok(ValidatedTransitionReturnContract {
        primary_next_state: targets.primary_next_state,
        next_states: targets.next_states,
    })
}

/// Resolves the transition targets named by `ty` relative to `target_type`.
///
/// When `strict` is set, only the literal machine path (optionally inside a
/// supported wrapper) is accepted. Otherwise `SourceAliasResolver` is asked
/// for the first alias-resolution context in which `ty` yields targets.
fn resolve_transition_targets(
    ty: &Type,
    target_type: &Type,
    strict: bool,
    return_type_span: Option<proc_macro2::Span>,
) -> Option<ResolvedTransitionTargets> {
    if strict {
        resolve_transition_targets_strict(ty, target_type)
    } else {
        SourceAliasResolver::new(return_type_span).find_map(|alias_context| {
            resolve_transition_targets_in_context(ty, target_type, alias_context)
        })
    }
}

fn resolve_transition_targets_strict(
    ty: &Type,
    target_type: &Type,
) -> Option<ResolvedTransitionTargets> {
    let observed = observe_transition_targets_strict(ty, target_type);
    let primary_next_state = observed.primary_next_state?;
    let next_states = observed.next_states;
    (!next_states.is_empty()).then_some(ResolvedTransitionTargets {
        primary_next_state,
        next_states,
    })
}

fn resolve_transition_targets_in_context(
    ty: &Type,
    target_type: &Type,
    context: Option<&super::super::resolve::AliasResolutionContext>,
) -> Option<ResolvedTransitionTargets> {
    let (_, primary_next_state) = parse_machine_and_state_in_context(ty, target_type, context)?;
    let next_states = collect_machine_and_states_in_context(ty, target_type, context)
        .into_iter()
        .map(|(_, state)| state)
        .collect::<Vec<_>>();
    (!next_states.is_empty()).then_some(ResolvedTransitionTargets {
        primary_next_state,
        next_states,
    })
}

/// Collects, without validating, what strict parsing finds in `ty`: the
/// primary next state (if any) and every next state named by the type.
fn observe_transition_targets_strict(ty: &Type, target_type: &Type) -> ObservedTransitionTargets {
    let primary_next_state =
        parse_primary_machine_and_state_strict(ty, target_type).map(|(_machine, state)| state);

    let next_states: Vec<String> = collect_machine_and_states_strict(ty, target_type)
        .into_iter()
        .map(|(_machine, state)| state)
        .collect();

    ObservedTransitionTargets {
        next_states,
        primary_next_state,
    }
}

/// Suggests a strictly-parseable spelling of `func`'s return type for use in
/// a `#[introspect(return = ...)]` hint.
///
/// Returns `None` when there is no return type or when no alias-resolution
/// context yields an expansion that parses strictly.
pub(super) fn strict_introspect_return_suggestion(
    func: &TransitionFn,
    target_type: &Type,
) -> Option<String> {
    let written = func.return_type.as_ref()?;
    let expanded = SourceAliasResolver::new(func.return_type_span).find_map(|alias_context| {
        strict_diagnostic_expanded_return_type(written, target_type, alias_context)
    })?;
    Some(compact_display(&expanded))
}

/// Entry point for the alias-expanding walk over `ty` within one
/// alias-resolution `context`.
///
/// Each call starts with an empty set of in-progress alias keys, so
/// bookkeeping from one context never carries over into another.
fn strict_diagnostic_expanded_return_type(
    ty: &Type,
    target_type: &Type,
    context: Option<&AliasResolutionContext>,
) -> Option<Type> {
    strict_diagnostic_expanded_return_type_inner(ty, target_type, context, &mut HashSet::new())
}

/// Recursive worker for `strict_diagnostic_expanded_return_type`.
///
/// Walks `ty` looking for the impl target machine and expands source-backed
/// type aliases along the way. It returns a copy of `ty` with those aliases
/// replaced by their expansions, but only if the rewritten type is accepted by
/// `parse_primary_machine_and_state_strict`; otherwise it returns `None`.
///
/// Parameters:
/// - `ty`: the type to inspect (the written return type, or one of its
///   generic type arguments on recursive calls).
/// - `target_type`: the impl target machine type.
/// - `context`: alias-resolution context for `ty`. Alias expansions recurse
///   with the context produced by that expansion.
/// - `visited`: keys of aliases currently being expanded. A key is presumably
///   inserted by `expand_source_type_alias` (opaque here) and is removed again
///   below once that alias's expansion has been explored.
fn strict_diagnostic_expanded_return_type_inner(
    ty: &Type,
    target_type: &Type,
    context: Option<&AliasResolutionContext>,
    visited: &mut HashSet<String>,
) -> Option<Type> {
    let type_path = type_path(ty)?;

    // Already the impl target machine path itself: nothing to expand.
    if machine_segment_matching_target(&type_path.path, target_type).is_some() {
        return Some(ty.clone());
    }

    // `ty` is a source-backed alias: walk its expansion using the alias's own
    // resolution context. If that walk finds nothing, fall back to the raw
    // expansion when it already parses strictly as a machine return type.
    if let Some(expanded_alias) = expand_source_type_alias(ty, context, visited) {
        let (expanded, alias_context, visit_key) = expanded_alias.into_parts();
        let result = strict_diagnostic_expanded_return_type_inner(
            &expanded,
            target_type,
            Some(&alias_context),
            visited,
        )
        .or_else(|| {
            parse_primary_machine_and_state_strict(&expanded, target_type)
                .is_some()
                .then_some(expanded.clone())
        });
        // Drop this alias's key once its expansion has been explored, so later
        // lookups (e.g. a sibling generic argument) are not blocked by it.
        visited.remove(&visit_key);
        return result;
    }

    // Otherwise only a supported wrapper is worth descending into; any other
    // path yields no suggestion.
    let segment = type_path.path.segments.last()?;
    supported_wrapper(&type_path.path)?;

    // Read generic arguments from the original `ty` while rewriting a clone of
    // it in place.
    let original_types = extract_generic_type_refs(&segment.arguments)?;
    let mut expanded_ty = ty.clone();
    let Type::Path(expanded_type_path) = &mut expanded_ty else {
        return None;
    };
    let expanded_segment = expanded_type_path.path.segments.last_mut()?;
    let PathArguments::AngleBracketed(args) = &mut expanded_segment.arguments else {
        return None;
    };

    // `type_index` advances only on type arguments (lifetime/const arguments
    // are skipped). This assumes `extract_generic_type_refs` yields exactly the
    // type arguments, in order. TODO confirm against its definition.
    let mut expanded_any = false;
    let mut type_index = 0usize;
    for arg in &mut args.args {
        let GenericArgument::Type(inner_ty) = arg else {
            continue;
        };
        let original_inner = original_types.get(type_index)?;
        if let Some(expanded_inner) = strict_diagnostic_expanded_return_type_inner(
            original_inner,
            target_type,
            context,
            visited,
        ) {
            *inner_ty = expanded_inner;
            expanded_any = true;
        }
        type_index += 1;
    }

    // Only report a rewrite when something was actually expanded and the
    // result now parses strictly as a machine return type.
    if expanded_any && parse_primary_machine_and_state_strict(&expanded_ty, target_type).is_some() {
        Some(expanded_ty)
    } else {
        None
    }
}

/// Summarizes the shape of `func`'s written return type for diagnostics.
///
/// The canonical state is taken from a strict parse of the primary branch
/// when possible. Otherwise it comes from the first resolved machine branch.
/// The canonical `machine_name<state>` entry is then removed from the
/// secondary branches. Returns `None` only when `func` has no return type.
pub(super) fn observed_return_shape(
    func: &TransitionFn,
    target_type: &Type,
) -> Option<ObservedReturnShape> {
    let return_type = func.return_type.as_ref()?;
    let wrapper = raw_wrapper_kind(return_type);
    let primary_branch = primary_branch_display(return_type);
    let mut secondary_machine_branches = resolved_machine_branches(func, target_type);

    let canonical_state = match parse_primary_machine_and_state_strict(return_type, target_type) {
        Some((_machine, state)) => Some(state),
        None => secondary_machine_branches
            .first()
            .map(|branch| state_name_from_machine_target(branch).to_owned()),
    };

    if let Some(state) = &canonical_state {
        let canonical_machine = format!("{}<{state}>", func.machine_name);
        secondary_machine_branches.retain(|branch| *branch != canonical_machine);
    }

    Some(ObservedReturnShape {
        primary_branch,
        secondary_machine_branches,
        wrapper,
        canonical_state,
    })
}

/// Lists every machine target named by `func`'s return type as
/// `"Machine<State>"` strings.
///
/// Strict collection is used when strict introspection is enabled globally or
/// the function carries `#[introspect]`. Otherwise the first alias-resolution
/// context that yields any targets is used. Returns an empty list when there
/// is no return type or nothing resolves.
fn resolved_machine_branches(func: &TransitionFn, target_type: &Type) -> Vec<String> {
    let return_type = match func.return_type.as_ref() {
        Some(return_type) => return_type,
        None => return Vec::new(),
    };

    let strict = crate::strict_introspection_enabled() || func.introspection.is_some();
    let targets = if strict {
        collect_machine_and_states_strict(return_type, target_type)
    } else {
        SourceAliasResolver::new(func.return_type_span)
            .find_map(|alias_context| {
                let found =
                    collect_machine_and_states_in_context(return_type, target_type, alias_context);
                if found.is_empty() {
                    None
                } else {
                    Some(found)
                }
            })
            .unwrap_or_default()
    };

    let mut branches = Vec::new();
    for (machine, state) in targets {
        branches.push(format!("{machine}<{state}>"));
    }
    branches
}

/// Returns the supported wrapper kind of `ty`'s outermost path, if `ty` is a
/// path type wrapped in one.
fn raw_wrapper_kind(ty: &Type) -> Option<SupportedWrapper> {
    type_path(ty).and_then(|path_ty| supported_wrapper(&path_ty.path))
}

/// Compact display of the primary branch of a return type.
///
/// For a supported wrapper this is its first generic type argument; for any
/// other path type it is the whole type. Returns `None` for non-path types,
/// empty paths, or wrappers without a generic type argument.
pub(super) fn primary_branch_display(ty: &Type) -> Option<String> {
    let path = &type_path(ty)?.path;
    let last_segment = path.segments.last()?;
    if supported_wrapper(path).is_none() {
        return Some(compact_display(ty));
    }
    extract_first_generic_type_ref(&last_segment.arguments).map(compact_display)
}

/// Extracts `State` from a `"Machine<State>"` string. Everything after the
/// first `<` is taken, minus one trailing `>`. Falls back to the placeholder
/// `"NextState"` when there is no `<` or no trailing `>`.
fn state_name_from_machine_target(machine_target: &str) -> &str {
    let Some((_machine, after_open)) = machine_target.split_once('<') else {
        return "NextState";
    };
    match after_open.strip_suffix('>') {
        Some(state) => state,
        None => "NextState",
    }
}