use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use std::collections::HashSet;
use syn::spanned::Spanned;
use syn::{FnArg, GenericArgument, Ident, Pat, PathArguments, ReturnType, Type};
use crate::diagnostics::{DiagnosticMessage, compact_display};
use crate::source::{SourceAliasResolver, expand_source_type_alias};
use super::contract::{
ValidatorMethodContract, ValidatorReturnKind, VariantSpec,
build_validator_method_contract as build_semantic_validator_method_contract,
};
use super::type_equivalence::types_equivalent;
/// Borrowed context shared by the validator diagnostics in this module.
///
/// The `&str` fields are interpolated into diagnostic headlines of the form
/// "validator `…` for `impl {persisted_type_display}` rebuilding `{machine_name}`
/// state `{state_enum_name}::{variant_name}`".
pub(super) struct ValidatorDiagnosticContext<'a> {
    /// Display form of the persisted type, rendered as `impl <this>` in messages.
    pub(super) persisted_type_display: &'a str,
    /// Name of the machine whose state the validator rebuilds.
    pub(super) machine_name: &'a str,
    /// Name of the state enum; combined with `variant_name` as `Enum::Variant`.
    pub(super) state_enum_name: &'a str,
    /// Name of the state variant this validator targets.
    pub(super) variant_name: &'a str,
    /// Machine fields injected by bare name into validator bodies; also used to
    /// detect explicit parameters whose names collide with them.
    pub(super) machine_fields: &'a [Ident],
    /// Payload type the validator's `Result<_, _>` / `Validation<_>` must carry.
    pub(super) expected_ok_type: &'a Type,
}
/// Derives a state name from a validator method identifier by dropping its `is_` prefix.
///
/// Returns `None` when the identifier's text does not start with `is_`.
pub(super) fn validator_state_name_from_ident(ident: &Ident) -> Option<String> {
    let method_name = ident.to_string();
    method_name.strip_prefix("is_").map(str::to_string)
}
/// Validates a validator method's signature and return type, then builds its semantic contract.
///
/// # Errors
///
/// Returns the rendered compile error from the first validation step that fails.
pub(super) fn build_validator_method_contract(
    func: &syn::ImplItemFn,
    spec: &VariantSpec,
    context: &ValidatorDiagnosticContext<'_>,
) -> Result<ValidatorMethodContract, proc_macro2::TokenStream> {
    validate_validator_signature(func, context)?;
    let expected_ok_type = context.expected_ok_type;
    validate_validator_return_contract(func, expected_ok_type, context)
        .map(|return_kind| build_semantic_validator_method_contract(func, spec, return_kind))
}
fn validate_validator_signature(
func: &syn::ImplItemFn,
context: &ValidatorDiagnosticContext<'_>,
) -> Result<(), proc_macro2::TokenStream> {
if func.sig.inputs.len() != 1 {
let collision_line = explicit_param_collision_line(&func.sig.inputs, context.machine_fields);
let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must declare only `&self`.",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
))
.found(format!("`fn {}({})`", func.sig.ident, compact_display(&func.sig.inputs)))
.expected(expected_signature.clone())
.reason(collision_line.unwrap_or_else(|| {
"validator methods do not accept explicit machine-field parameters.".to_string()
}))
.note(injected_machine_fields_line(context.machine_name, context.machine_fields))
.fix("remove explicit parameters and read injected machine fields by bare name inside the validator body.".to_string())
.render();
let error = if let Some(extra_input) = func.sig.inputs.iter().nth(1) {
syn::Error::new_spanned(extra_input, message)
} else {
syn::Error::new_spanned(&func.sig.inputs, message)
};
return Err(error.to_compile_error());
}
match &func.sig.inputs[0] {
FnArg::Receiver(receiver) => {
if receiver.reference.is_none() || receiver.mutability.is_some() {
let receiver_display = receiver.to_token_stream().to_string();
let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must take `&self`, not `{}`.",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
receiver_display,
))
.found(format!("`fn {}({receiver_display})`", func.sig.ident))
.expected(expected_signature)
.note(injected_machine_fields_line(context.machine_name, context.machine_fields))
.fix("change the receiver to `&self`.".to_string())
.render();
return Err(syn::Error::new_spanned(receiver, message).to_compile_error());
}
}
FnArg::Typed(_) => {
let expected_signature = expected_validator_signature(&func.sig.ident, context.expected_ok_type);
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must take `&self` as its receiver.",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
))
.found(format!("`fn {}({})`", func.sig.ident, compact_display(&func.sig.inputs[0])))
.expected(expected_signature)
.note(injected_machine_fields_line(context.machine_name, context.machine_fields))
.fix("rewrite the method to take `&self` and no other parameters.".to_string())
.render();
return Err(syn::Error::new_spanned(&func.sig.inputs[0], message).to_compile_error());
}
}
Ok(())
}
fn validate_validator_return_contract(
func: &syn::ImplItemFn,
expected_ok_type: &Type,
context: &ValidatorDiagnosticContext<'_>,
) -> Result<ValidatorReturnKind, TokenStream> {
let ReturnType::Type(_, return_ty) = &func.sig.output else {
let expected_ok_display = expected_ok_type.to_token_stream().to_string();
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return `Result<{}, _>` or `Validation<{}>`.",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
expected_ok_display,
expected_ok_display,
))
.expected(expected_validator_signature(&func.sig.ident, expected_ok_type))
.reason(expected_state_shape(
context.state_enum_name,
context.variant_name,
&expected_ok_display,
))
.fix("add an explicit validator return type.".to_string())
.render();
return Err(syn::Error::new_spanned(&func.sig.output, message).to_compile_error());
};
let (actual_ok_ty, return_kind) =
match extract_supported_validator_ok_type(return_ty, return_ty.span()) {
Some(info) => info,
None => {
let expected_ok_display = expected_ok_type.to_token_stream().to_string();
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return a supported validator result whose payload is `{}`.",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
expected_ok_display,
))
.found(format!("`{}`", compact_display(return_ty)))
.expected(format!(
"`Result<{expected_ok_display}, _>` or `Validation<{expected_ok_display}>`"
))
.note(supported_validator_wrapper_note())
.fix("rewrite the return type to use one of the supported validator result wrappers.".to_string())
.render();
return Err(syn::Error::new_spanned(return_ty, message).to_compile_error());
}
};
if !types_equivalent(&actual_ok_ty, expected_ok_type) {
let expected_ok_display = expected_ok_type.to_token_stream().to_string();
let actual_return_type = return_ty.to_token_stream().to_string();
let actual_ok_display = actual_ok_ty.to_token_stream().to_string();
let message = DiagnosticMessage::new(format!(
"validator `{}` for `impl {}` rebuilding `{}` state `{}::{}` must return `Result<{}, _>` or `Validation<{}>` (or an equivalent supported alias).",
func.sig.ident,
context.persisted_type_display,
context.machine_name,
context.state_enum_name,
context.variant_name,
expected_ok_display,
expected_ok_display,
))
.found(format!(
"`{actual_return_type}` with payload `{actual_ok_display}`"
))
.expected(format!(
"`Result<{expected_ok_display}, _>` or `Validation<{expected_ok_display}>`"
))
.fix(format!(
"change the validator to return `{expected_ok_display}` for `{}::{}`.",
context.state_enum_name, context.variant_name
))
.render();
return Err(syn::Error::new_spanned(return_ty, message).to_compile_error());
}
Ok(return_kind)
}
/// Builds the diagnostic note listing which machine fields are injected into validator bodies.
///
/// When `machine_fields` is empty, the note says there is nothing to inject.
fn injected_machine_fields_line(machine_name: &str, machine_fields: &[Ident]) -> String {
    match machine_fields {
        [] => format!(
            "Machine `{machine_name}` has no user-defined fields to inject, so validator methods should not take any extra parameters."
        ),
        fields => {
            let quoted_fields: Vec<String> =
                fields.iter().map(|name| format!("`{name}`")).collect();
            let injected = quoted_fields.join(", ");
            format!(
                "Machine `{machine_name}` injects these fields by bare name inside validator bodies: {injected}. Remove explicit parameters and use those bindings directly."
            )
        }
    }
}
/// Renders the two accepted validator signatures for `func_ident`, both carrying `expected_ok_type`.
fn expected_validator_signature(func_ident: &Ident, expected_ok_type: &Type) -> String {
    let ok_display = expected_ok_type.to_token_stream().to_string();
    let result_form = format!("`fn {func_ident}(&self) -> Result<{ok_display}, _>`");
    let validation_form = format!("`fn {func_ident}(&self) -> Validation<{ok_display}>`");
    format!("{result_form} or {validation_form}")
}
/// Describes explicit parameters (after the first input) whose identifier bindings shadow
/// an injected machine field.
///
/// Only simple identifier patterns are checked. Returns `None` when nothing collides.
fn explicit_param_collision_line(
    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
    machine_fields: &[Ident],
) -> Option<String> {
    let mut colliding_names: Vec<String> = Vec::new();
    for input in inputs.iter().skip(1) {
        let FnArg::Typed(typed) = input else {
            continue;
        };
        if let Pat::Ident(binding) = &*typed.pat {
            if machine_fields.contains(&binding.ident) {
                colliding_names.push(binding.ident.to_string());
            }
        }
    }
    if colliding_names.is_empty() {
        return None;
    }
    let (verb, noun) = if colliding_names.len() == 1 {
        ("collides", "binding")
    } else {
        ("collide", "bindings")
    };
    let quoted = colliding_names
        .iter()
        .map(|name| format!("`{name}`"))
        .collect::<Vec<_>>()
        .join(", ");
    Some(format!(
        "Parameter name collision: {quoted} {verb} with injected machine field {noun}."
    ))
}
/// Describes the payload shape of `state_enum_name::variant_name` for diagnostics.
///
/// A payload displayed as `()` is reported as a unit state.
fn expected_state_shape(state_enum_name: &str, variant_name: &str, expected_ok_display: &str) -> String {
    let state_path = format!("`{state_enum_name}::{variant_name}`");
    match expected_ok_display {
        "()" => format!("{state_path} is a unit state"),
        payload => format!("{state_path} carries `{payload}`"),
    }
}
/// Finds the payload type and return kind of a supported validator wrapper.
///
/// Each alias-resolution context yielded by `SourceAliasResolver` is tried in turn, each
/// with a fresh visited set. The first context that resolves `return_ty` wins.
fn extract_supported_validator_ok_type(
    return_ty: &Type,
    return_ty_span: Span,
) -> Option<(Type, ValidatorReturnKind)> {
    let resolver = SourceAliasResolver::new(Some(return_ty_span));
    for context in resolver {
        let mut visited = HashSet::new();
        let resolved = extract_supported_validator_ok_type_in_context(return_ty, context, &mut visited);
        if resolved.is_some() {
            return resolved;
        }
    }
    None
}
/// Resolves `return_ty` through source-declared aliases, then falls back to matching it directly.
///
/// When `expand_source_type_alias` expands the type, the expansion is searched recursively
/// in the alias's own context. The returned visit key is removed from `visited` once that
/// search finishes, whatever its outcome.
fn extract_supported_validator_ok_type_in_context(
    return_ty: &Type,
    context: Option<&crate::source::AliasResolutionContext>,
    visited: &mut HashSet<String>,
) -> Option<(Type, ValidatorReturnKind)> {
    let via_alias = match expand_source_type_alias(return_ty, context, visited) {
        Some(expanded_alias) => {
            let (expanded_ty, nested_context, key) = expanded_alias.into_parts();
            let nested =
                extract_supported_validator_ok_type_in_context(&expanded_ty, Some(&nested_context), visited);
            visited.remove(&key);
            nested
        }
        None => None,
    };
    via_alias.or_else(|| direct_supported_validator_ok_type(return_ty))
}
fn direct_supported_validator_ok_type(return_ty: &Type) -> Option<(Type, ValidatorReturnKind)> {
let Type::Path(type_path) = return_ty else {
return None;
};
let last_segment = type_path.path.segments.last()?;
if last_segment.ident == "Validation" && path_is_supported_validation_wrapper(&type_path.path) {
let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
return None;
};
let type_args: Vec<Type> = args
.args
.iter()
.filter_map(|arg| match arg {
GenericArgument::Type(ty) => Some(ty.clone()),
_ => None,
})
.collect();
if type_args.len() != 1 || type_args.len() != args.args.len() {
return None;
}
return type_args
.first()
.cloned()
.map(|ty| (ty, ValidatorReturnKind::Diagnostic));
}
if last_segment.ident != "Result" || !path_is_supported_result_wrapper(&type_path.path) {
return None;
}
let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
return None;
};
let type_args: Vec<Type> = args
.args
.iter()
.filter_map(|arg| match arg {
GenericArgument::Type(ty) => Some(ty.clone()),
_ => None,
})
.collect();
if type_args.is_empty() || type_args.len() > 2 || type_args.len() != args.args.len() {
return None;
}
let return_kind = if type_args.get(1).is_some_and(is_rejection_type) {
ValidatorReturnKind::Diagnostic
} else {
ValidatorReturnKind::Plain
};
type_args.first().cloned().map(|ty| (ty, return_kind))
}
/// Diagnostic note listing every validator return wrapper recognised by this module.
fn supported_validator_wrapper_note() -> &'static str {
    const SUPPORTED_FORMS_NOTE: &str = "supported forms: `Result<T, E>`, `core::result::Result<T, E>`, `std::result::Result<T, E>`, `statum::Result<T>`, direct `Result<T, statum::Rejection>`, and `Validation<T>` / `statum::Validation<T>`, plus source-declared type aliases that expand to those wrappers.";
    SUPPORTED_FORMS_NOTE
}
/// Returns true when `path` is spelled `Validation` or `statum::Validation`; generic arguments are ignored.
fn path_is_supported_validation_wrapper(path: &syn::Path) -> bool {
    let owned_names: Vec<String> = path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect();
    let names: Vec<&str> = owned_names.iter().map(String::as_str).collect();
    matches!(names.as_slice(), ["Validation"] | ["statum", "Validation"])
}
/// Returns true when `path` is spelled `Result`, `statum::Result`, `core::result::Result`,
/// or `std::result::Result`; generic arguments are ignored.
fn path_is_supported_result_wrapper(path: &syn::Path) -> bool {
    let owned_names: Vec<String> = path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect();
    let names: Vec<&str> = owned_names.iter().map(String::as_str).collect();
    matches!(
        names.as_slice(),
        ["Result"]
            | ["statum", "Result"]
            | ["core", "result", "Result"]
            | ["std", "result", "Result"]
    )
}
/// Returns true when `ty` is a path type whose final segment is named `Rejection`.
///
/// Only the last segment is inspected, so any `…::Rejection` path qualifies.
fn is_rejection_type(ty: &Type) -> bool {
    if let Type::Path(type_path) = ty {
        if let Some(final_segment) = type_path.path.segments.last() {
            return final_segment.ident == "Rejection";
        }
    }
    false
}