use crate::{Delimiter, Group, Ident, Span, TokenStream, TokenTree};
use quote::quote;
/// Expands an `on_error`-style attribute: rewrites a method so that the
/// attribute argument runs as a cleanup expression whenever the method
/// returns an `Err`.
///
/// `attr` is the cleanup expression; it is spliced verbatim (followed by `;`)
/// into the generated error path. `item` is the annotated method.
///
/// The method is split into two functions, emitted in this order:
///
/// 1. `__<name>_inner` — the original signature (everything before the body,
///    with only the function name replaced) and the original body, prefixed
///    with `#[doc(hidden)]` and `#[inline(always)]`.
/// 2. A wrapper with the original signature whose body calls
///    `self.__<name>_inner()`, evaluates the cleanup expression if the result
///    is an `Err`, and returns the result.
///
/// If the signature text following the last `->` contains `&mut Self`, the
/// wrapper discards the `Ok` payload of the inner call and returns `Ok(self)`
/// instead of passing the inner result through.
///
/// # Limitations
///
/// The wrapper invokes the inner function with no arguments, so the generated
/// code only compiles for methods whose sole parameter is the receiver.
///
/// # Panics
///
/// Panics if `item` has no top-level `fn` keyword, if the token after `fn` is
/// missing or is not an identifier, or if there is no top-level brace-delimited
/// body.
pub fn on_error(attr: TokenStream, item: TokenStream) -> TokenStream {
    let cleanup_expr = attr;
    let tokens: Vec<TokenTree> = item.into_iter().collect();

    let fn_pos = tokens
        .iter()
        .position(|tt| matches!(tt, TokenTree::Ident(id) if *id == "fn"))
        .expect("Method must have 'fn' keyword");

    // `get` rather than indexing: when `fn` is the final token, indexing would
    // abort with a generic out-of-bounds message instead of the one below.
    let fn_name = match tokens.get(fn_pos + 1) {
        Some(TokenTree::Ident(id)) => id.clone(),
        _ => panic!("Expected function name after 'fn'"),
    };

    // The body is taken to be the last top-level `{ ... }` group.
    let body_idx = tokens
        .iter()
        .rposition(|tt| matches!(tt, TokenTree::Group(g) if g.delimiter() == Delimiter::Brace))
        .expect("Method must have a body");
    let body = tokens[body_idx].clone();

    // Everything before the body (attributes, visibility, signature). Reused
    // unchanged as the wrapper's signature.
    let wrapper_sig: TokenStream = tokens[..body_idx].iter().cloned().collect();

    // Textual check of the return type. Both spellings are tested because the
    // stringified stream may or may not put a space between `&` and `mut`.
    let sig_text = wrapper_sig.to_string();
    let returns_mut_self = match sig_text.rfind("->") {
        Some(arrow_pos) => {
            let ret_type = &sig_text[arrow_pos..];
            ret_type.contains("& mut Self") || ret_type.contains("&mut Self")
        }
        None => false,
    };

    // A raw identifier displays as `r#name`; embedding that in the middle of a
    // new name would not be a valid identifier string, so drop the prefix.
    // NOTE(review): assumes `Ident` is the proc-macro identifier type — confirm.
    let fn_name_text = fn_name.to_string();
    let base_name = fn_name_text.strip_prefix("r#").unwrap_or(&fn_name_text);
    let inner_name = Ident::new(&format!("__{base_name}_inner"), Span::call_site());

    // Same pre-body tokens as the wrapper, with only the name token swapped.
    let inner_sig: TokenStream = tokens[..body_idx]
        .iter()
        .enumerate()
        .map(|(i, tt)| {
            if i == fn_pos + 1 {
                TokenTree::Ident(inner_name.clone())
            } else {
                tt.clone()
            }
        })
        .collect();

    let wrapper_body = if returns_mut_self {
        quote! {
            {
                match self.#inner_name() {
                    Ok(_discarded) => Ok(self),
                    Err(__e) => {
                        #cleanup_expr;
                        Err(__e)
                    }
                }
            }
        }
    } else {
        quote! {
            {
                let __result = self.#inner_name();
                if __result.is_err() {
                    #cleanup_expr;
                }
                __result
            }
        }
    };
    let wrapper_body_group = TokenTree::Group(Group::new(Delimiter::Brace, wrapper_body));

    quote! {
        #[doc(hidden)]
        #[inline(always)]
        #inner_sig
        #body
        #wrapper_sig
        #wrapper_body_group
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A method returning a plain `Result` expands to output that names the
    /// renamed inner function and contains the cleanup expression.
    #[test]
    fn test_on_error_basic() {
        let method = quote! {
            pub fn do_something(&mut self) -> Result<i32, Error> {
                self.inner_work()?;
                Ok(42)
            }
        };
        let cleanup = quote! { self.cleanup() };

        let expanded = on_error(cleanup, method).to_string();

        let has_inner = expanded.contains("__do_something_inner");
        let has_cleanup = expanded.contains("self . cleanup");
        assert!(has_inner);
        assert!(has_cleanup);
    }

    /// A method returning `Result<&mut Self, _>` expands to the variant that
    /// binds the inner `Ok` payload to `_discarded` and returns `Ok(self)`.
    #[test]
    fn test_on_error_mut_self_return() {
        let method = quote! {
            pub fn begin_some(&mut self) -> Result<&mut Self, ReflectError> {
                self.require_active()?;
                Ok(self)
            }
        };
        let cleanup = quote! { self.poison_and_cleanup() };

        let expanded = on_error(cleanup, method).to_string();

        let has_inner = expanded.contains("__begin_some_inner");
        let has_discard = expanded.contains("_discarded");
        let has_ok_self = expanded.contains("Ok (self)");
        assert!(has_inner);
        assert!(has_discard);
        assert!(has_ok_self);
    }
}