use crate::visitor::MatchFinder;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Arm, Expr, ExprMatch, ExprPath, ItemFn, Pat, Path, spanned::Spanned, visit::Visit};
/// Finds the `match` expression in `func` and checks that it describes an
/// enum-to-enum mapping.
///
/// The search is delegated to `MatchFinder`. Which `match` it records when a
/// function contains several is decided by `MatchFinder`, not here.
///
/// Returns the arms of that `match`, borrowed from `func`.
///
/// # Panics
///
/// * If no `match` expression is found. The message names `attr`.
/// * If the arms fail `validate_enum_to_enum_arms`: no arms, a guard, or a
///   pattern/body that is not a plain path.
pub fn find_and_validate<'f>(attr: &str, func: &'f ItemFn) -> &'f [Arm] {
    let mut finder = MatchFinder { found: None };
    finder.visit_item_fn(func);

    let Some(ExprMatch { arms, .. }) = finder.found else {
        panic!("{attr} can only be used on functions containing a match expression")
    };

    validate_enum_to_enum_arms(arms);
    arms
}
/// Checks that every arm of a `match` maps one enum variant to another.
///
/// Each arm must satisfy three conditions:
/// * no guard;
/// * a path pattern (`Pat::Path`, e.g. `Enum::Variant`);
/// * a path body (`Expr::Path`).
///
/// The other helpers in this module rely on these shapes and treat violations
/// as unreachable.
///
/// # Panics
///
/// An empty arm list is rejected before any arm is inspected. Otherwise this
/// panics on the first violation, with a `bijective:`-prefixed message. For
/// each arm, the guard is checked first, then the pattern, then the body.
fn validate_enum_to_enum_arms(arms: &[Arm]) {
    if arms.is_empty() {
        panic!("bijective: match must have at least one arm");
    }

    for arm in arms.iter() {
        if arm.guard.is_some() {
            panic!("bijective: match guards are not supported");
        }
        if !matches!(arm.pat, Pat::Path(_)) {
            panic!("bijective: every arm pattern must be an enum variant path (e.g. `Enum::Variant`)");
        }
        if !matches!(arm.body.as_ref(), Expr::Path(_)) {
            panic!("bijective: every arm body must be an enum variant path (e.g. `Enum::Variant`)");
        }
    }
}
/// Checks that no two arms of the `match` produce the same output variant.
///
/// Outputs are compared by the text of their token streams
/// (`quote!(#output).to_string()`). As a result, `Enum::A` and
/// `crate::Enum::A` count as different outputs.
///
/// Returns `None` when every arm body is distinct. Otherwise it returns a
/// `compile_error!` token stream spanned at the first arm whose output repeats
/// an earlier arm's output.
///
/// # Panics
///
/// Panics if an arm body is not an `Expr::Path`. Callers are expected to have
/// validated the arms first.
pub fn check_injectivity(arms: &[Arm]) -> Option<TokenStream2> {
    use std::collections::HashSet;

    // Only the keys are needed to detect a repeat. A set avoids the linear
    // scan per arm.
    let mut seen: HashSet<String> = HashSet::with_capacity(arms.len());
    for arm in arms {
        let Expr::Path(output) = arm.body.as_ref() else {
            unreachable!("already validated")
        };
        let key = quote!(#output).to_string();
        if seen.contains(&key) {
            return Some(
                syn::Error::new(
                    output.span(),
                    format!(
                        "injective: `{key}` is produced by more than one arm; \
                         the mapping is not injective"
                    ),
                )
                .to_compile_error(),
            );
        }
        seen.insert(key);
    }
    None
}
/// Builds one `Variant => (),` match arm for each distinct output of `arms`.
///
/// Outputs are deduplicated by the text of their token streams. They keep the
/// order in which they first appear. Presumably the caller splices these into a
/// `match` over the output enum so the compiler reports any variant that no arm
/// produces. That usage is not visible here.
///
/// # Panics
///
/// Panics if an arm body is not an `Expr::Path`. Callers are expected to have
/// validated the arms first.
pub fn surjectivity_check_arms(arms: &[Arm]) -> Vec<TokenStream2> {
    use std::collections::HashSet;

    let mut seen: HashSet<String> = HashSet::with_capacity(arms.len());
    arms.iter()
        .map(|arm| {
            let Expr::Path(output) = arm.body.as_ref() else {
                unreachable!("already validated")
            };
            output
        })
        // `insert` returns false for a repeat, which drops duplicate outputs.
        // Duplicates would otherwise become unreachable patterns.
        .filter(|output| seen.insert(quote!(#output).to_string()))
        .map(|output| quote! { #output => (), })
        .collect()
}
fn enum_type_of_path(path: &Path) -> Path {
let n = path.segments.len();
assert!(
n >= 2,
"bijective: enum path must have at least 2 segments (e.g. `Enum::Variant`), got: `{}`",
quote::quote!(#path),
);
let mut segments = syn::punctuated::Punctuated::new();
for seg in path.segments.iter().take(n - 1) {
segments.push(seg.clone());
}
Path {
leading_colon: path.leading_colon,
segments,
}
}
pub fn enum_type_of_expr(expr: &ExprPath) -> Path {
enum_type_of_path(&expr.path)
}
/// Borrows the `ExprPath` out of an expression that is expected to be a path.
pub trait AsExprPath {
    /// Returns the inner `ExprPath`.
    ///
    /// Implementations may panic when `self` is not a path expression.
    fn as_expr_path(&self) -> &ExprPath;
}

impl AsExprPath for Expr {
    /// Returns the wrapped `ExprPath`.
    ///
    /// # Panics
    ///
    /// Panics with `"expected Expr::Path"` unless `self` is `Expr::Path`.
    fn as_expr_path(&self) -> &ExprPath {
        match self {
            Self::Path(inner) => inner,
            _ => panic!("expected Expr::Path"),
        }
    }
}