use proc_macro::TokenStream;
use quote::quote;
fn parse_match_expr(input: proc_macro2::TokenStream) -> syn::Result<syn::ExprMatch> {
if let Ok(m) = syn::parse2::<syn::ExprMatch>(input.clone()) {
return Ok(m);
}
let mut expr_tokens = Vec::new();
let mut brace_group = None;
for tt in input {
if brace_group.is_some() {
continue;
}
match &tt {
proc_macro2::TokenTree::Group(g) if g.delimiter() == proc_macro2::Delimiter::Brace => {
brace_group = Some(g.clone());
}
_ => expr_tokens.push(tt),
}
}
let brace = brace_group.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"match_error: expected `{ ... }` block after expression",
)
})?;
let expr: proc_macro2::TokenStream = expr_tokens.into_iter().collect();
let reconstructed = quote! {
match #expr #brace
};
syn::parse2(reconstructed)
}
pub(crate) fn match_error(input: TokenStream) -> TokenStream {
let input_ts = proc_macro2::TokenStream::from(input);
let match_expr = match parse_match_expr(input_ts) {
Ok(m) => m,
Err(e) => return e.to_compile_error().into(),
};
let expr = &match_expr.expr;
let arms = &match_expr.arms;
let mut ok_arms: Vec<&syn::Arm> = Vec::new();
let mut error_arms: Vec<&syn::Arm> = Vec::new();
for arm in arms {
match classify_pat(&arm.pat) {
ArmType::Ok => ok_arms.push(arm),
ArmType::Error => error_arms.push(arm),
ArmType::CatchAll => {
return syn::Error::new_spanned(
&arm.pat,
"match_error: catch-all patterns (`_` or bare bindings) are not allowed; \
all error types must be matched explicitly",
)
.to_compile_error()
.into();
}
ArmType::Unknown => {
return syn::Error::new_spanned(
&arm.pat,
"match_error: unsupported pattern type",
)
.to_compile_error()
.into();
}
}
}
let mut error_paths: Vec<syn::Path> = Vec::new();
let mut error_arm_data: Vec<ErrorArmData> = Vec::new();
for arm in &error_arms {
if let Some(path) = extract_error_path(&arm.pat) {
error_paths.push(path.clone());
error_arm_data.push(ErrorArmData {
path,
pat: arm.pat.clone(),
guard: arm.guard.clone(),
body: arm.body.clone(),
});
}
}
let sorted_paths = crate::sort::sort_paths(error_paths.clone());
let unique_paths = crate::sort::dedup_paths(sorted_paths);
let path_index_map: std::collections::HashMap<String, usize> = unique_paths
.iter()
.enumerate()
.map(|(i, p)| (crate::sort::path_to_string(p), i))
.collect();
let mut new_arms: Vec<proc_macro2::TokenStream> = Vec::new();
for arm in &ok_arms {
let pat = &arm.pat;
let guard_ts = guard_to_tokens(&arm.guard);
let body = &arm.body;
new_arms.push(quote! {
#pat #guard_ts => #body,
});
}
for data in &error_arm_data {
let path_key = crate::sort::path_to_string(&data.path);
if let Some(&index) = path_index_map.get(&path_key) {
let stripped_pat = strip_err_wrapper(&data.pat);
let new_pat = crate::sort::build_unt_pattern(index, unique_paths.len(), &stripped_pat);
let guard_ts = guard_to_tokens(&data.guard);
let body = &data.body;
new_arms.push(quote! {
#new_pat #guard_ts => #body,
});
}
}
let expanded = quote! {
match #expr {
#(#new_arms)*
}
};
expanded.into()
}
/// Category assigned by `classify_pat` to a single match-arm pattern.
enum ArmType {
    /// An `Ok(..)` arm; `match_error` emits it unchanged.
    Ok,
    /// An arm naming a specific error type, optionally wrapped in `Err(..)`; rewritten.
    Error,
    /// `_` or a non-PascalCase binding; rejected because it would swallow unmatched errors.
    CatchAll,
    /// Any other pattern shape; rejected as unsupported.
    Unknown,
}
/// An error arm copied out of the input, paired with the error type its pattern names.
struct ErrorArmData {
    /// The arm's original pattern, possibly still wrapped in `Err(..)`.
    pat: syn::Pat,
    /// The error type named by `pat`, as produced by `extract_error_path`.
    path: syn::Path,
    /// The arm's optional `if` guard, copied from the source arm.
    guard: Option<(syn::token::If, Box<syn::Expr>)>,
    /// The arm's body expression, copied from the source arm.
    body: Box<syn::Expr>,
}
fn classify_pat(pat: &syn::Pat) -> ArmType {
match pat {
syn::Pat::TupleStruct(ts) => {
if is_path_ok(&ts.path) {
ArmType::Ok
} else if is_path_err(&ts.path) {
ArmType::Error
} else {
if is_error_path(&ts.path) {
ArmType::Error
} else {
ArmType::Unknown
}
}
}
syn::Pat::Path(pp) => {
if is_path_ok(&pp.path) {
ArmType::Ok
} else if is_path_err(&pp.path) || is_error_path(&pp.path) {
ArmType::Error
} else {
ArmType::Unknown
}
}
syn::Pat::Wild(_) => ArmType::CatchAll,
syn::Pat::Ident(pi) => {
if crate::sort::is_pascal_case(&pi.ident) {
ArmType::Error
} else {
ArmType::CatchAll
}
}
syn::Pat::Struct(ps) => {
if is_error_path(&ps.path) {
ArmType::Error
} else {
ArmType::Unknown
}
}
_ => ArmType::Unknown,
}
}
/// Removes one `Err(..)` layer from a pattern, returning its first inner element.
///
/// Patterns that are not an `Err` tuple-struct, or an `Err` with no elements, are returned
/// as an unchanged clone.
fn strip_err_wrapper(pat: &syn::Pat) -> syn::Pat {
    let inner = match pat {
        syn::Pat::TupleStruct(ts) if is_path_err(&ts.path) => ts.elems.first(),
        _ => None,
    };
    inner.unwrap_or(pat).clone()
}
/// Reads the error-type path named by an error-arm pattern.
///
/// `Err(inner)` is looked through recursively via its first element. Other tuple-struct, path
/// and struct patterns yield their own path. A PascalCase bare identifier becomes a
/// single-segment path. Anything else, including wildcards and non-PascalCase bindings, yields
/// `None`.
fn extract_error_path(pat: &syn::Pat) -> Option<syn::Path> {
    match pat {
        syn::Pat::TupleStruct(ts) => match ts.elems.first() {
            Some(inner) if is_path_err(&ts.path) => extract_error_path(inner),
            _ => Some(ts.path.clone()),
        },
        syn::Pat::Path(pp) => Some(pp.path.clone()),
        syn::Pat::Ident(pi) => crate::sort::is_pascal_case(&pi.ident)
            .then(|| syn::Path::from(pi.ident.clone())),
        syn::Pat::Struct(ps) => Some(ps.path.clone()),
        _ => None,
    }
}
/// Returns `true` if `path` names the `Result::Ok` variant.
///
/// Accepted spellings are `Ok`, `Result::Ok`, `result::Result::Ok`, `std::result::Result::Ok`
/// and `core::result::Result::Ok`. An absolute path (leading `::`) must be spelled in full. The
/// `Ok` segment itself must carry no generic arguments.
fn is_path_ok(path: &syn::Path) -> bool {
    is_result_variant(path, "Ok")
}

/// Returns `true` if `path` names the `Result::Err` variant; spellings as for [`is_path_ok`].
fn is_path_err(path: &syn::Path) -> bool {
    is_result_variant(path, "Err")
}

/// Shared check for [`is_path_ok`] / [`is_path_err`].
///
/// The final segment must be `variant` with no generic arguments. Any qualifying segments
/// must be a suffix of `std::result::Result` or `core::result::Result`. Without this,
/// `Result::Ok(..)` would fall through to `is_error_path` and be mistaken for an error type.
fn is_result_variant(path: &syn::Path, variant: &str) -> bool {
    // Walk backwards so the qualifiers line up with the innermost-first spelling below.
    let mut segments = path.segments.iter().rev();
    match segments.next() {
        Some(last)
            if last.ident == variant && matches!(last.arguments, syn::PathArguments::None) => {}
        _ => return false,
    }
    let qualifiers: Vec<&syn::Ident> = segments.map(|segment| &segment.ident).collect();
    ["std", "core"].iter().any(|root| {
        let expected = ["Result", "result", *root];
        qualifiers.len() <= expected.len()
            // `::Ok` / `::Result::Ok` do not name the std variant; absolute paths must be full.
            && (path.leading_colon.is_none() || qualifiers.len() == expected.len())
            && qualifiers
                .iter()
                .zip(expected.iter())
                .all(|(ident, name)| **ident == *name)
    })
}
/// Renders an arm's optional guard back into tokens: `if <cond>` when present, nothing otherwise.
///
/// The original `if` token is discarded, and a freshly quoted `if` keyword is emitted in its place.
fn guard_to_tokens(guard: &Option<(syn::token::If, Box<syn::Expr>)>) -> proc_macro2::TokenStream {
    if let Some((_, condition)) = guard {
        quote! { if #condition }
    } else {
        proc_macro2::TokenStream::new()
    }
}
/// Returns `true` when the final segment of `path` is a PascalCase identifier.
///
/// Qualifiers and generic arguments are ignored: `MyError`, `errors::MyError` and `MyError::<T>`
/// are judged solely by `MyError`. An empty path yields `false`.
fn is_error_path(path: &syn::Path) -> bool {
    match path.segments.last() {
        Some(segment) => crate::sort::is_pascal_case(&segment.ident),
        None => false,
    }
}