facet_macros_impl/
on_error.rs

1use crate::{Delimiter, Group, Ident, Span, TokenStream, TokenTree};
2use quote::quote;
3
4/// Entry point for the on_error attribute macro.
5///
6/// Usage: `#[on_error(self.poison_and_cleanup())]`
7///
8/// This wraps methods that return `Result<_, E>` to run cleanup code on error.
9/// For methods returning `Result<&mut Self, E>`, it properly handles the borrow
10/// by discarding the returned reference and returning a fresh `Ok(self)`.
11///
12/// The macro generates two methods:
13/// - `__method_name_inner`: contains the original body
14/// - `method_name`: wrapper that calls inner and handles errors
15pub fn on_error(attr: TokenStream, item: TokenStream) -> TokenStream {
16    // The attribute contains the cleanup expression
17    let cleanup_expr = attr;
18
19    let tokens: Vec<TokenTree> = item.into_iter().collect();
20
21    // Find the function name - it's the identifier after "fn"
22    let fn_pos = tokens
23        .iter()
24        .position(|tt| matches!(tt, TokenTree::Ident(id) if *id == "fn"))
25        .expect("Method must have 'fn' keyword");
26
27    let fn_name = if let TokenTree::Ident(id) = &tokens[fn_pos + 1] {
28        id.clone()
29    } else {
30        panic!("Expected function name after 'fn'");
31    };
32
33    // Find the body - it's the last BraceGroup
34    let body_idx = tokens
35        .iter()
36        .rposition(|tt| matches!(tt, TokenTree::Group(g) if g.delimiter() == Delimiter::Brace))
37        .expect("Method must have a body");
38
39    let body = if let TokenTree::Group(g) = &tokens[body_idx] {
40        g.clone()
41    } else {
42        unreachable!()
43    };
44
45    // Check if return type contains `&mut Self`
46    let returns_mut_self = {
47        let before_body: TokenStream = tokens[..body_idx].iter().cloned().collect();
48        let s = before_body.to_string();
49        if let Some(arrow_pos) = s.rfind("->") {
50            let ret_type = &s[arrow_pos..];
51            ret_type.contains("& mut Self") || ret_type.contains("&mut Self")
52        } else {
53            false
54        }
55    };
56
57    // Generate inner method name
58    let inner_name = Ident::new(&format!("__{fn_name}_inner"), Span::call_site());
59
60    // Build the signature for the inner method (everything before the body, with renamed fn)
61    let mut inner_sig_tokens: Vec<TokenTree> = Vec::new();
62    for (i, tt) in tokens[..body_idx].iter().enumerate() {
63        if i == fn_pos + 1 {
64            // Replace function name with inner name
65            inner_sig_tokens.push(TokenTree::Ident(inner_name.clone()));
66        } else {
67            inner_sig_tokens.push(tt.clone());
68        }
69    }
70    let inner_sig: TokenStream = inner_sig_tokens.into_iter().collect();
71
72    // Build the wrapper body
73    let wrapper_body = if returns_mut_self {
74        quote! {
75            {
76                match self.#inner_name() {
77                    Ok(_discarded) => Ok(self),
78                    Err(__e) => {
79                        #cleanup_expr;
80                        Err(__e)
81                    }
82                }
83            }
84        }
85    } else {
86        quote! {
87            {
88                let __result = self.#inner_name();
89                if __result.is_err() {
90                    #cleanup_expr;
91                }
92                __result
93            }
94        }
95    };
96
97    // Build the wrapper signature (original, without doc comments to avoid duplication)
98    let wrapper_sig: TokenStream = tokens[..body_idx].iter().cloned().collect();
99
100    // Combine: inner method + wrapper method
101    let wrapper_body_group = TokenTree::Group(Group::new(Delimiter::Brace, wrapper_body));
102
103    quote! {
104        #[doc(hidden)]
105        #[inline(always)]
106        #inner_sig
107        #body
108
109        #wrapper_sig
110        #wrapper_body_group
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the `on_error` expansion and renders the generated tokens as text.
    fn expand(attr: TokenStream, item: TokenStream) -> String {
        on_error(attr, item).to_string()
    }

    #[test]
    fn test_on_error_basic() {
        let expanded = expand(
            quote! { self.cleanup() },
            quote! {
                pub fn do_something(&mut self) -> Result<i32, Error> {
                    self.inner_work()?;
                    Ok(42)
                }
            },
        );

        // The renamed inner method is emitted and the cleanup expression is spliced in.
        for needle in ["__do_something_inner", "self . cleanup"] {
            assert!(expanded.contains(needle), "missing `{needle}` in: {expanded}");
        }
    }

    #[test]
    fn test_on_error_mut_self_return() {
        let expanded = expand(
            quote! { self.poison_and_cleanup() },
            quote! {
                pub fn begin_some(&mut self) -> Result<&mut Self, ReflectError> {
                    self.require_active()?;
                    Ok(self)
                }
            },
        );

        // A `&mut Self` return selects the branch that discards the returned
        // reference and hands back a fresh `Ok(self)`.
        for needle in ["__begin_some_inner", "_discarded", "Ok (self)"] {
            assert!(expanded.contains(needle), "missing `{needle}` in: {expanded}");
        }
    }
}