use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{ToTokens, quote};
use syn::parse::Parser;
use syn::visit_mut::VisitMut;
/// Entry point of the `throws` attribute: `attr` is the comma-separated list of error
/// paths, `item` the annotated function.
///
/// The function is split into its return type (`()` when omitted), body, signature with
/// the return type removed, visibility and attributes. With no error paths these go
/// straight to `::struct_error::__throws_impl!`. Otherwise the tokens exported for the
/// first path are forwarded to `::struct_error::__throws_cps` via
/// `macro_magic::forward_tokens!`, carrying the remaining paths and an empty accumulator
/// (`[[]]`) along. Parse errors are returned as `compile_error!` tokens.
pub(crate) fn throws(attr: TokenStream, item: TokenStream) -> TokenStream {
    let paths = match parse_error_paths(&proc_macro2::TokenStream::from(attr)) {
        Ok(paths) => paths,
        Err(err) => return err.to_compile_error().into(),
    };
    let func = match syn::parse2::<syn::ItemFn>(proc_macro2::TokenStream::from(item)) {
        Ok(func) => func,
        Err(err) => return err.to_compile_error().into(),
    };
    let syn::ItemFn {
        attrs,
        vis,
        sig,
        block,
        ..
    } = &func;

    let ret_ty = match &sig.output {
        syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
        syn::ReturnType::Default => quote! { () },
    };

    // The return type travels separately; the signature is re-emitted without it.
    let mut stripped_sig = sig.clone();
    stripped_sig.output = syn::ReturnType::Default;

    let expanded = match error_paths_head(&paths) {
        None => quote! {
            ::struct_error::__throws_impl! {
                [] [#ret_ty] [#block] [#stripped_sig] [#vis] [#(#attrs)*]
            }
        },
        Some((head, tail)) => quote! {
            ::struct_error::macro_magic::forward_tokens! {
                #head,
                ::struct_error::__throws_cps,
                ::struct_error::macro_magic,
                { [@recurse] [[]] [#(#tail),*] [#ret_ty] [#block] [#stripped_sig] [#vis] [#(#attrs)*] }
            }
        },
    };
    expanded.into()
}

/// Splits the error path list into its first element and the rest (`None` when empty).
fn error_paths_head(paths: &[syn::Path]) -> Option<(&syn::Path, &[syn::Path])> {
    paths.split_first()
}
/// Parses the attribute arguments as a comma-separated list of paths (a trailing comma is
/// accepted) and returns them in source order.
fn parse_error_paths(attr_ts: &proc_macro2::TokenStream) -> syn::Result<Vec<syn::Path>> {
    type PathList = syn::punctuated::Punctuated<syn::Path, syn::Token![,]>;
    PathList::parse_terminated
        .parse2(attr_ts.clone())
        .map(|list| list.into_iter().collect())
}
use syn::parse::{Parse, ParseStream};
/// Attribute input received by `__throws_cps`.
///
/// The expected shape is `<ident> <item> , <group>`. The leading identifier is parsed and
/// discarded, `item` is the forwarded item (for example a struct definition), and `extra`
/// is the delimited group carrying the continuation state.
struct CpsAttr {
    /// The forwarded item.
    item: syn::Item,
    /// Delimited group holding the continuation state (`{ [@recurse] [acc] ... }`).
    extra: proc_macro2::Group,
}

impl Parse for CpsAttr {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Leading identifier: required, but its value is not used.
        input.parse::<syn::Ident>()?;
        let item = input.parse::<syn::Item>()?;
        input.parse::<syn::Token![,]>()?;
        let extra = input.parse::<proc_macro2::Group>()?;
        Ok(CpsAttr { item, extra })
    }
}
pub(crate) fn __throws_cps(attr: TokenStream, _item: TokenStream) -> TokenStream {
let attr_ts = proc_macro2::TokenStream::from(attr);
let cps_attr = match syn::parse2::<CpsAttr>(attr_ts) {
Ok(a) => a,
Err(e) => {
eprintln!("__THROWS_CPS PARSE ERROR: {}", e);
return e.to_compile_error().into();
}
};
let forwarded_item = cps_attr.item;
let extra_stream = cps_attr.extra.stream();
let groups = extract_bracket_groups(extra_stream);
if groups.len() < 8 {
return syn::Error::new(
Span::call_site(),
"__throws_cps: expected format: { [@recurse] [acc] [remaining] [ret_ty] [body] [sig] [vis] [attrs] }",
)
.to_compile_error()
.into();
}
let acc_ts = &groups[1];
let remaining_ts = &groups[2];
let ret_ty_ts = &groups[3];
let body_ts = &groups[4];
let sig_ts = &groups[5];
let vis_ts = &groups[6];
let attrs_ts = &groups[7];
let new_acc = if acc_ts.is_empty() {
quote! { [(#forwarded_item)] }
} else {
quote! { #acc_ts(#forwarded_item) }
};
let remaining_paths: Vec<syn::Path> = if remaining_ts.is_empty() {
Vec::new()
} else {
match syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
.parse2(remaining_ts.clone())
{
Ok(p) => p.into_iter().collect(),
Err(e) => return e.to_compile_error().into(),
}
};
if remaining_paths.is_empty() {
let output = quote! {
::struct_error::__throws_impl! {
[#new_acc] [#ret_ty_ts] [#body_ts] [#sig_ts] [#vis_ts] [#attrs_ts]
}
};
output.into()
} else {
let next_path = &remaining_paths[0];
let rest_paths = &remaining_paths[1..];
let output = quote! {
::struct_error::macro_magic::forward_tokens! {
#next_path,
::struct_error::__throws_cps,
::struct_error::macro_magic,
{ [@recurse] [#new_acc] [#(#rest_paths),*] [#ret_ty_ts] [#body_ts] [#sig_ts] [#vis_ts] [#attrs_ts] }
}
};
output.into()
}
}
pub(crate) fn __throws_impl(input: TokenStream) -> TokenStream {
let input_ts = proc_macro2::TokenStream::from(input);
let groups = extract_bracket_groups(input_ts);
if groups.len() < 6 {
return syn::Error::new(
Span::call_site(),
"__throws_impl: expected format: [tokens...] [ret_ty] [body] [sig] [vis] [attrs]",
)
.to_compile_error()
.into();
}
let forwarded_tokens_group = &groups[0];
let ret_ty_ts = &groups[1];
let body_ts = &groups[2];
let sig_ts = &groups[3];
let vis_ts = &groups[4];
let attrs_ts = &groups[5];
let forwarded_items = extract_paren_groups(forwarded_tokens_group.clone());
let ret_ty: syn::Type = match syn::parse2(ret_ty_ts.clone()) {
Ok(t) => t,
Err(e) => return e.to_compile_error().into(),
};
let mut fn_block: syn::Block = match syn::parse2(body_ts.clone()) {
Ok(b) => b,
Err(e) => return e.to_compile_error().into(),
};
let mut fn_sig: syn::Signature = match syn::parse2(sig_ts.clone()) {
Ok(s) => s,
Err(e) => return e.to_compile_error().into(),
};
let fn_vis: syn::Visibility = match syn::parse2(vis_ts.clone()) {
Ok(v) => v,
Err(e) => return e.to_compile_error().into(),
};
let fn_attrs: Vec<syn::Attribute> = if attrs_ts.is_empty() {
Vec::new()
} else {
let dummy = quote! { #attrs_ts fn __dummy() {} };
match syn::parse2::<syn::ItemFn>(dummy) {
Ok(item) => item.attrs,
Err(_) => Vec::new(),
}
};
let mut all_error_paths: Vec<syn::Path> = Vec::new();
for tokens in &forwarded_items {
if tokens.is_empty() {
continue;
}
if let Ok(item_struct) = syn::parse2::<syn::ItemStruct>(tokens.clone()) {
let mut found_members = false;
for attr in &item_struct.attrs {
if attr
.path()
.segments
.last()
.map(|s| s.ident == "__struct_error_members")
.unwrap_or(false)
{
found_members = true;
match attr.parse_args_with(
syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
) {
Ok(members) => {
all_error_paths.extend(members);
}
Err(e) => return e.to_compile_error().into(),
}
}
}
if !found_members {
let path = syn::Path::from(item_struct.ident.clone());
all_error_paths.push(path);
}
} else {
if let Ok(path) = syn::parse2::<syn::Path>(tokens.clone()) {
all_error_paths.push(path);
}
}
}
let sorted_paths = crate::sort::sort_paths(all_error_paths);
let unique_paths = crate::sort::dedup_paths(sorted_paths);
let unt_type = crate::sort::build_unt_type(&unique_paths);
let _error_index_map: std::collections::HashMap<String, usize> = unique_paths
.iter()
.enumerate()
.map(|(i, p)| (crate::sort::path_to_string(p), i))
.collect();
let into_trait = generate_into_trait(&unique_paths);
let mut rewriter = ThrowsRewriter;
rewriter.visit_block_mut(&mut fn_block);
fn_sig.output = syn::ReturnType::Type(
syn::Token),
syn::parse2(quote! { ::core::result::Result<#ret_ty, #unt_type> }).unwrap(),
);
let is_unit_ret = matches!(ret_ty, syn::Type::Tuple(t) if t.elems.is_empty());
if is_unit_ret {
let needs_ok = if let Some(last) = fn_block.stmts.last() {
!matches!(last, syn::Stmt::Expr(_, None))
} else {
true
};
if needs_ok {
fn_block.stmts.push(
syn::parse2(quote! {
::core::result::Result::Ok(())
})
.unwrap(),
);
}
}
let into_trait_items = match syn::parse2::<syn::File>(into_trait) {
Ok(file) => file.items,
Err(_) => Vec::new(),
};
let mut trait_stmts: Vec<syn::Stmt> =
into_trait_items.into_iter().map(syn::Stmt::Item).collect();
trait_stmts.append(&mut fn_block.stmts);
fn_block.stmts = trait_stmts;
let expanded = quote! {
#(#fn_attrs)*
#fn_vis #fn_sig #fn_block
};
expanded.into()
}
fn extract_bracket_groups(ts: proc_macro2::TokenStream) -> Vec<proc_macro2::TokenStream> {
let mut result = Vec::new();
for tt in ts {
if let proc_macro2::TokenTree::Group(g) = tt
&& g.delimiter() == proc_macro2::Delimiter::Bracket
{
result.push(g.stream());
}
}
result
}
fn extract_paren_groups(ts: proc_macro2::TokenStream) -> Vec<proc_macro2::TokenStream> {
let mut result = Vec::new();
for tt in ts {
if let proc_macro2::TokenTree::Group(g) = tt
&& g.delimiter() == proc_macro2::Delimiter::Parenthesis
{
result.push(g.stream());
}
}
result
}
fn generate_into_trait(paths: &[syn::Path]) -> proc_macro2::TokenStream {
if paths.is_empty() {
return quote! {};
}
let unt_type = crate::sort::build_unt_type(paths);
let impls: Vec<_> = paths
.iter()
.enumerate()
.map(|(index, path)| {
let nested = crate::sort::build_unt_nesting(index, paths.len(), "e! { self });
quote! {
impl __StructErrorInto for #path {
fn into_unt(self) -> #unt_type {
#nested
}
}
}
})
.collect();
quote! {
trait __StructErrorInto {
fn into_unt(self) -> #unt_type;
}
#(#impls)*
impl __StructErrorInto for ::struct_error::End {
fn into_unt(self) -> #unt_type {
match self {}
}
}
impl<H, T> __StructErrorInto for ::struct_error::Unt<H, T>
where
H: __StructErrorInto,
T: __StructErrorInto,
{
fn into_unt(self) -> #unt_type {
match self {
::struct_error::Unt::Here(h) => h.into_unt(),
::struct_error::Unt::There(t) => t.into_unt(),
}
}
}
}
}
/// Rewrites a `throws` function body so it type-checks against the new
/// `Result<ret_ty, union>` return type.
///
/// Call `visit_block_mut` on the function body. Every `?` and `return` inside it is
/// rewritten by `BodyRewriter`, and the body's own tail expression is wrapped in `Ok(..)`.
/// Only the outermost tail is wrapped. Tails of nested blocks (`if`/`match` arms, loop
/// bodies, `let x = { .. };`) keep their original type; wrapping those was a type error.
struct ThrowsRewriter;

impl VisitMut for ThrowsRewriter {
    // Closures, async blocks and nested items have their own return types, so `?` and
    // `return` inside them are left untouched.
    fn visit_expr_closure_mut(&mut self, _node: &mut syn::ExprClosure) {}
    fn visit_expr_async_mut(&mut self, _node: &mut syn::ExprAsync) {}
    fn visit_item_fn_mut(&mut self, _node: &mut syn::ItemFn) {}
    fn visit_item_mod_mut(&mut self, _node: &mut syn::ItemMod) {}
    fn visit_item_mut(&mut self, _node: &mut syn::Item) {}

    fn visit_expr_mut(&mut self, node: &mut syn::Expr) {
        BodyRewriter.visit_expr_mut(node);
    }

    fn visit_stmt_mut(&mut self, node: &mut syn::Stmt) {
        BodyRewriter.visit_stmt_mut(node);
    }

    fn visit_block_mut(&mut self, node: &mut syn::Block) {
        BodyRewriter.visit_block_mut(node);
        // A tail `return x` has already become `return Ok(x)`; wrapping it again would
        // only add unreachable code.
        if let Some(syn::Stmt::Expr(tail, None)) = node.stmts.last_mut()
            && !matches!(tail, syn::Expr::Return(_))
        {
            let wrapped = self.wrap_tail_expr(tail);
            *tail = wrapped;
        }
    }
}

impl ThrowsRewriter {
    /// Wraps `expr` in `::core::result::Result::Ok(..)`, unless `crate::utils` already
    /// classifies it as an `Ok(..)` or `Err(..)` expression.
    fn wrap_expr(&self, expr: &syn::Expr) -> syn::Expr {
        if crate::utils::is_ok_expr(expr) || crate::utils::is_err_expr(expr) {
            return expr.clone();
        }
        syn::parse2(quote! {
            ::core::result::Result::Ok(#expr)
        })
        .unwrap()
    }

    /// Wraps the function body's tail expression; same rule as `wrap_expr`.
    fn wrap_tail_expr(&self, expr: &syn::Expr) -> syn::Expr {
        self.wrap_expr(expr)
    }
}

/// Rewrites `?` and `return` anywhere in a function body. Closures, async blocks and
/// nested items are not touched.
///
/// - `expr?` becomes a `match` that yields the `Ok` value, or returns
///   `Err(__StructErrorInto::into_unt(e))`.
/// - `return expr` (in statement or expression position) becomes `return Ok(expr)`, per
///   `ThrowsRewriter::wrap_expr`.
///
/// Block tails are never wrapped here.
struct BodyRewriter;

impl VisitMut for BodyRewriter {
    fn visit_expr_closure_mut(&mut self, _node: &mut syn::ExprClosure) {}
    fn visit_expr_async_mut(&mut self, _node: &mut syn::ExprAsync) {}
    fn visit_item_mut(&mut self, _node: &mut syn::Item) {}

    fn visit_expr_mut(&mut self, node: &mut syn::Expr) {
        // Rewrite children first. The `return Err(..)` generated below is never visited
        // again, so it is not wrapped in `Ok`.
        syn::visit_mut::visit_expr_mut(self, node);
        match node {
            syn::Expr::Try(try_expr) => {
                let inner = &try_expr.expr;
                *node = syn::parse_quote! {
                    match #inner {
                        ::core::result::Result::Ok(v) => v,
                        ::core::result::Result::Err(e) => {
                            return ::core::result::Result::Err(__StructErrorInto::into_unt(e))
                        }
                    }
                };
            }
            syn::Expr::Return(ret) => {
                if let Some(value) = &mut ret.expr {
                    let wrapped = ThrowsRewriter.wrap_expr(value);
                    **value = wrapped;
                }
            }
            _ => {}
        }
    }
}