facet_macros_impl/
on_error.rs1use crate::{Delimiter, Group, Ident, Span, TokenStream, TokenTree};
2use quote::quote;
3
4pub fn on_error(attr: TokenStream, item: TokenStream) -> TokenStream {
16 let cleanup_expr = attr;
18
19 let tokens: Vec<TokenTree> = item.into_iter().collect();
20
21 let fn_pos = tokens
23 .iter()
24 .position(|tt| matches!(tt, TokenTree::Ident(id) if *id == "fn"))
25 .expect("Method must have 'fn' keyword");
26
27 let fn_name = if let TokenTree::Ident(id) = &tokens[fn_pos + 1] {
28 id.clone()
29 } else {
30 panic!("Expected function name after 'fn'");
31 };
32
33 let body_idx = tokens
35 .iter()
36 .rposition(|tt| matches!(tt, TokenTree::Group(g) if g.delimiter() == Delimiter::Brace))
37 .expect("Method must have a body");
38
39 let body = if let TokenTree::Group(g) = &tokens[body_idx] {
40 g.clone()
41 } else {
42 unreachable!()
43 };
44
45 let returns_mut_self = {
47 let before_body: TokenStream = tokens[..body_idx].iter().cloned().collect();
48 let s = before_body.to_string();
49 if let Some(arrow_pos) = s.rfind("->") {
50 let ret_type = &s[arrow_pos..];
51 ret_type.contains("& mut Self") || ret_type.contains("&mut Self")
52 } else {
53 false
54 }
55 };
56
57 let inner_name = Ident::new(&format!("__{fn_name}_inner"), Span::call_site());
59
60 let mut inner_sig_tokens: Vec<TokenTree> = Vec::new();
62 for (i, tt) in tokens[..body_idx].iter().enumerate() {
63 if i == fn_pos + 1 {
64 inner_sig_tokens.push(TokenTree::Ident(inner_name.clone()));
66 } else {
67 inner_sig_tokens.push(tt.clone());
68 }
69 }
70 let inner_sig: TokenStream = inner_sig_tokens.into_iter().collect();
71
72 let wrapper_body = if returns_mut_self {
74 quote! {
75 {
76 match self.#inner_name() {
77 Ok(_discarded) => Ok(self),
78 Err(__e) => {
79 #cleanup_expr;
80 Err(__e)
81 }
82 }
83 }
84 }
85 } else {
86 quote! {
87 {
88 let __result = self.#inner_name();
89 if __result.is_err() {
90 #cleanup_expr;
91 }
92 __result
93 }
94 }
95 };
96
97 let wrapper_sig: TokenStream = tokens[..body_idx].iter().cloned().collect();
99
100 let wrapper_body_group = TokenTree::Group(Group::new(Delimiter::Brace, wrapper_body));
102
103 quote! {
104 #[doc(hidden)]
105 #[inline(always)]
106 #inner_sig
107 #body
108
109 #wrapper_sig
110 #wrapper_body_group
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the macro and renders the output via `TokenStream`'s `Display`.
    /// The expected fragments below use that space-separated form, e.g.
    /// `self . cleanup`.
    fn expand(attr: TokenStream, item: TokenStream) -> String {
        on_error(attr, item).to_string()
    }

    /// Asserts that every fragment occurs somewhere in the expansion.
    fn assert_contains_all(expansion: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(expansion.contains(fragment));
        }
    }

    #[test]
    fn test_on_error_basic() {
        let expansion = expand(
            quote! { self.cleanup() },
            quote! {
                pub fn do_something(&mut self) -> Result<i32, Error> {
                    self.inner_work()?;
                    Ok(42)
                }
            },
        );

        // The inner method is generated and the cleanup expression is spliced in.
        assert_contains_all(&expansion, &["__do_something_inner", "self . cleanup"]);
    }

    #[test]
    fn test_on_error_mut_self_return() {
        let expansion = expand(
            quote! { self.poison_and_cleanup() },
            quote! {
                pub fn begin_some(&mut self) -> Result<&mut Self, ReflectError> {
                    self.require_active()?;
                    Ok(self)
                }
            },
        );

        // For a `&mut Self` return, the inner `Ok` value is discarded and `self` is returned.
        assert_contains_all(&expansion, &["__begin_some_inner", "_discarded", "Ok (self)"]);
    }
}