1use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
2
/// `#[should_panic]` settings found among a test function's attributes.
struct ShouldPanicConfig {
    /// `true` when the function carries a `#[should_panic]` attribute.
    enabled: bool,
    /// Substring the panic message must contain, when one was given.
    expected: Option<String>,
}
7
/// Arguments accepted by `#[allure_test]`.
#[derive(Debug)]
struct AttrArgs {
    /// Display name for the test; the function name is used when absent.
    name: Option<String>,
    /// Value passed to `allure.id(..)` before the test body runs.
    id: Option<String>,
}
13
/// Arguments accepted by `#[step]`.
#[derive(Debug)]
struct StepAttrArgs {
    /// Step name; the function name is used when absent.
    name: Option<String>,
}
18
19#[proc_macro_attribute]
20pub fn step(args: TokenStream, input: TokenStream) -> TokenStream {
21 let attrs = match parse_step_args(args) {
22 Ok(attrs) => attrs,
23 Err(err) => return compile_error(err),
24 };
25
26 transform_step_fn(attrs, input)
27}
28
29#[proc_macro_attribute]
30pub fn allure_test(args: TokenStream, input: TokenStream) -> TokenStream {
31 let attrs = match parse_args(args) {
32 Ok(attrs) => attrs,
33 Err(err) => return compile_error(err),
34 };
35
36 transform_fn(attrs, input)
37}
38
39fn parse_step_args(args: TokenStream) -> Result<StepAttrArgs, &'static str> {
40 let attrs = parse_kv_args(args, &["name"])?;
41 Ok(StepAttrArgs { name: attrs.name })
42}
43
44fn parse_args(args: TokenStream) -> Result<AttrArgs, &'static str> {
45 parse_kv_args(args, &["name", "id"])
46}
47
48fn parse_kv_args(args: TokenStream, allowed_keys: &[&str]) -> Result<AttrArgs, &'static str> {
49 let tokens: Vec<TokenTree> = args.into_iter().collect();
50 if tokens.is_empty() {
51 return Ok(AttrArgs {
52 name: None,
53 id: None,
54 });
55 }
56
57 let mut idx = 0;
58 let mut parsed = AttrArgs {
59 name: None,
60 id: None,
61 };
62
63 while idx < tokens.len() {
64 let key = match &tokens[idx] {
65 TokenTree::Ident(id) => id.to_string(),
66 _ => return Err("unsupported attribute arguments, expected key = \"value\""),
67 };
68 idx += 1;
69
70 match tokens.get(idx) {
71 Some(TokenTree::Punct(eq)) if eq.as_char() == '=' => idx += 1,
72 _ => return Err("unsupported attribute arguments, expected key = \"value\""),
73 }
74
75 let raw = match tokens.get(idx) {
76 Some(TokenTree::Literal(value)) => value.to_string(),
77 _ => return Err("unsupported attribute arguments, expected key = \"value\""),
78 };
79 idx += 1;
80
81 if !raw.starts_with('"') || !raw.ends_with('"') || raw.len() < 2 {
82 return Err("attribute values must be string literals");
83 }
84
85 let value = raw[1..raw.len() - 1].to_string();
86 match key.as_str() {
87 "name" if allowed_keys.contains(&"name") => {
88 if parsed.name.is_some() {
89 return Err("duplicate attribute argument: name");
90 }
91 parsed.name = Some(value);
92 }
93 "id" if allowed_keys.contains(&"id") => {
94 if parsed.id.is_some() {
95 return Err("duplicate attribute argument: id");
96 }
97 parsed.id = Some(value);
98 }
99 _ => {
100 return Err(match allowed_keys {
101 ["name"] => "unsupported attribute argument, expected: name",
102 _ => "unsupported attribute argument, expected one of: name, id",
103 });
104 }
105 }
106
107 if idx < tokens.len() {
108 match &tokens[idx] {
109 TokenTree::Punct(p) if p.as_char() == ',' => idx += 1,
110 _ => {
111 return Err(
112 "unsupported attribute arguments, expected comma-separated key/value pairs",
113 );
114 }
115 }
116 }
117 }
118
119 Ok(parsed)
120}
121
122fn transform_fn(attrs: AttrArgs, input: TokenStream) -> TokenStream {
123 let mut tokens: Vec<TokenTree> = input.into_iter().collect();
124
125 let fn_index = tokens
126 .iter()
127 .position(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "fn"));
128 let Some(fn_index) = fn_index else {
129 return compile_error("#[allure_test] can be applied only to functions");
130 };
131
132 if tokens[..fn_index]
133 .iter()
134 .any(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "async"))
135 {
136 return compile_error("#[allure_test] does not support async functions");
137 }
138
139 let should_panic = parse_should_panic_config(&tokens[..fn_index]);
140
141 let fn_name = tokens.iter().skip(fn_index + 1).find_map(|t| match t {
142 TokenTree::Ident(id) => Some(id.to_string()),
143 _ => None,
144 });
145 let Some(fn_name) = fn_name else {
146 return compile_error("failed to parse function name");
147 };
148
149 let body_index = tokens.iter().position(
150 |t| matches!(t, TokenTree::Group(group) if group.delimiter() == Delimiter::Brace),
151 );
152 let Some(body_index) = body_index else {
153 return compile_error("failed to parse function body");
154 };
155
156 if has_return_type(&tokens[fn_index..body_index]) {
157 return compile_error(
158 "#[allure_test] currently supports only test functions that return ()",
159 );
160 }
161
162 let original_body = match &tokens[body_index] {
163 TokenTree::Group(group) => group.stream().to_string(),
164 _ => return compile_error("failed to parse function body"),
165 };
166
167 let test_name = attrs.name.unwrap_or(fn_name.clone());
168 let allure_id_setup = match attrs.id.as_ref() {
169 Some(id) => format!("allure.id({id:?});"),
170 None => String::new(),
171 };
172
173 let wrapped_body_src = if should_panic.enabled {
174 format!(
175 "{{
176 let __allure_results_dir = ::std::env::var(\"ALLURE_RESULTS_DIR\")
177 .unwrap_or_else(|_| \"target/allure-results\".to_string());
178 let __allure_reporter = ::allure_cargotest::CargoTestReporter::new(__allure_results_dir)
179 .expect(\"allure reporter should be created\");
180 if !__allure_reporter.is_selected({test_name:?}, Some({test_name:?}), None, None) {{
181 return;
182 }}
183 __allure_reporter.run_test_with_result({test_name:?}, |allure| {{
184 {allure_id_setup}
185 let __allure_result = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {{ {original_body} }}));
186 match __allure_result {{
187 Ok(()) => (
188 ::allure_cargotest::Status::Failed,
189 Some(::allure_cargotest::StatusDetails {{
190 message: Some(\"expected panic but none occurred\".to_string()),
191 trace: None,
192 actual: None,
193 expected: None,
194 }}),
195 None,
196 ),
197 Err(__allure_payload) => {{
198 let __allure_message = if let Some(__allure_msg) = __allure_payload.downcast_ref::<&str>() {{
199 (*__allure_msg).to_string()
200 }} else if let Some(__allure_msg) = __allure_payload.downcast_ref::<String>() {{
201 __allure_msg.clone()
202 }} else {{
203 \"panic without string payload\".to_string()
204 }};
205 {}
206 }}
207 }}
208 }});
209}}",
210 expected_match_arm(&should_panic.expected)
211 )
212 } else {
213 format!(
214 "{{
215 let __allure_results_dir = ::std::env::var(\"ALLURE_RESULTS_DIR\")
216 .unwrap_or_else(|_| \"target/allure-results\".to_string());
217 let __allure_reporter = ::allure_cargotest::CargoTestReporter::new(__allure_results_dir)
218 .expect(\"allure reporter should be created\");
219 let __allure_full_name = format!(\"{{}}::{{}}\", module_path!(), {fn_name:?});
220 __allure_reporter.run_test_with_metadata({test_name:?}, Some(&__allure_full_name), None, None, |allure| {{ {allure_id_setup} {original_body} }});
221}}"
222 )
223 };
224
225 let wrapped_body_stream: TokenStream = match wrapped_body_src.parse() {
226 Ok(stream) => stream,
227 Err(_) => return compile_error("failed to generate transformed test body"),
228 };
229 let wrapped_group = match wrapped_body_stream.into_iter().next() {
230 Some(TokenTree::Group(group)) => group,
231 _ => return compile_error("failed to generate transformed test body"),
232 };
233
234 tokens[body_index] = TokenTree::Group(Group::new(Delimiter::Brace, wrapped_group.stream()));
235
236 TokenStream::from_iter(tokens)
237}
238
239fn parse_should_panic_config(tokens: &[TokenTree]) -> ShouldPanicConfig {
240 let mut index = 0;
241 while index + 1 < tokens.len() {
242 let Some(TokenTree::Punct(pound)) = tokens.get(index) else {
243 index += 1;
244 continue;
245 };
246 if pound.as_char() != '#' {
247 index += 1;
248 continue;
249 }
250
251 let Some(TokenTree::Group(group)) = tokens.get(index + 1) else {
252 index += 1;
253 continue;
254 };
255 if group.delimiter() != Delimiter::Bracket {
256 index += 2;
257 continue;
258 }
259
260 let mut attr_tokens = group.stream().into_iter();
261 let Some(TokenTree::Ident(name)) = attr_tokens.next() else {
262 index += 2;
263 continue;
264 };
265 if name.to_string() != "should_panic" {
266 index += 2;
267 continue;
268 }
269
270 let expected = attr_tokens.find_map(|token| match token {
271 TokenTree::Group(arguments) if arguments.delimiter() == Delimiter::Parenthesis => {
272 parse_should_panic_expected(arguments.stream())
273 }
274 _ => None,
275 });
276 return ShouldPanicConfig {
277 enabled: true,
278 expected,
279 };
280 }
281
282 ShouldPanicConfig {
283 enabled: false,
284 expected: None,
285 }
286}
287
288fn parse_should_panic_expected(tokens: TokenStream) -> Option<String> {
289 let parsed: Vec<TokenTree> = tokens.into_iter().collect();
290 for window in parsed.windows(3) {
291 match window {
292 [TokenTree::Ident(name), TokenTree::Punct(eq), TokenTree::Literal(value)]
293 if name.to_string() == "expected" && eq.as_char() == '=' =>
294 {
295 let raw = value.to_string();
296 if raw.starts_with('"') && raw.ends_with('"') && raw.len() >= 2 {
297 return Some(raw[1..raw.len() - 1].to_string());
298 }
299 }
300 _ => {}
301 }
302 }
303 None
304}
305
/// Builds the Rust source for the body of the `Err(__allure_payload)` arm in
/// the generated `#[should_panic]` wrapper.
///
/// The generated expression evaluates to a `(Status, Option<StatusDetails>,
/// Option<payload>)` tuple:
/// - With no expected substring, any panic yields `Status::Passed`.
/// - With one, `__allure_message.contains(expected)` decides between
///   `Passed` and a `Failed` status whose details describe the mismatch.
///
/// The payload is always forwarded.
fn expected_match_arm(expected: &Option<String>) -> String {
    // Tuple produced whenever the panic is accepted as the expected one.
    const PANIC_ACCEPTED: &str = "(
    ::allure_cargotest::Status::Passed,
    None,
    Some(__allure_payload),
)";

    let Some(needle) = expected else {
        return PANIC_ACCEPTED.to_string();
    };

    let accepted = PANIC_ACCEPTED;
    format!(
        "if __allure_message.contains({needle:?}) {{
    {accepted}
}} else {{
    (
        ::allure_cargotest::Status::Failed,
        Some(::allure_cargotest::StatusDetails {{
            message: Some(format!(\"panic message mismatch: expected substring {{:?}}, got {{:?}}\", {needle:?}, __allure_message)),
            trace: None,
            actual: None,
            expected: None,
        }}),
        Some(__allure_payload),
    )
}}"
    )
}
336
/// Reports whether a function signature declares a return type other than `()`.
///
/// `tokens` should cover the signature from the `fn` keyword up to, but not
/// including, the body.
///
/// Two cases are not treated as a return type:
/// - An explicit `-> ()`, which still returns unit.
/// - A `-> T` inside `Fn() -> T` bounds of a where clause. Scanning stops at
///   the `where` keyword so these are never seen.
fn has_return_type(tokens: &[TokenTree]) -> bool {
    let end = tokens
        .iter()
        .position(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "where"))
        .unwrap_or(tokens.len());
    let signature = &tokens[..end];

    signature.windows(2).enumerate().any(|(index, pair)| {
        let [TokenTree::Punct(first), TokenTree::Punct(second)] = pair else {
            return false;
        };
        if first.as_char() != '-' || second.as_char() != '>' {
            return false;
        }
        // `-> ()` as the final element of the signature is still a unit return.
        let explicit_unit = signature.len() == index + 3
            && matches!(
                &signature[index + 2],
                TokenTree::Group(group)
                    if group.delimiter() == Delimiter::Parenthesis && group.stream().is_empty()
            );
        !explicit_unit
    })
}
347
/// Expands to `compile_error!("<message>");`, so a macro can report `message`
/// as a compiler diagnostic instead of emitting code.
fn compile_error(message: &str) -> TokenStream {
    let invocation = format!("compile_error!({:?});", message);
    invocation.parse::<TokenStream>().unwrap_or_default()
}
353
354fn transform_step_fn(attrs: StepAttrArgs, input: TokenStream) -> TokenStream {
355 let mut tokens: Vec<TokenTree> = input.into_iter().collect();
356
357 let fn_index = tokens
358 .iter()
359 .position(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "fn"));
360 let Some(fn_index) = fn_index else {
361 return compile_error("#[step] can be applied only to functions");
362 };
363
364 let fn_name = tokens.iter().skip(fn_index + 1).find_map(|t| match t {
365 TokenTree::Ident(id) => Some(id.to_string()),
366 _ => None,
367 });
368 let Some(fn_name) = fn_name else {
369 return compile_error("failed to parse function name");
370 };
371
372 let body_index = tokens.iter().position(
373 |t| matches!(t, TokenTree::Group(group) if group.delimiter() == Delimiter::Brace),
374 );
375 let Some(body_index) = body_index else {
376 return compile_error("failed to parse function body");
377 };
378
379 let original_body = match &tokens[body_index] {
380 TokenTree::Group(group) => group.stream().to_string(),
381 _ => return compile_error("failed to parse function body"),
382 };
383
384 let step_name = attrs.name.unwrap_or(fn_name);
385 let wrapped_body_src = format!(
386 "{{
387 let __allure_step_name = {step_name:?};
388 match ::allure_cargotest::__private::current_allure() {{
389 Some(__allure_step_allure) => {{
390 let __allure_step_guard = __allure_step_allure.step(__allure_step_name);
391 {original_body}
392 }}
393 None => {{
394 {original_body}
395 }}
396 }}
397}}"
398 );
399
400 let wrapped_body_stream: TokenStream = match wrapped_body_src.parse() {
401 Ok(stream) => stream,
402 Err(_) => return compile_error("failed to generate transformed function body"),
403 };
404 let wrapped_group = match wrapped_body_stream.into_iter().next() {
405 Some(TokenTree::Group(group)) => group,
406 _ => return compile_error("failed to generate transformed function body"),
407 };
408
409 tokens[body_index] = TokenTree::Group(Group::new(Delimiter::Brace, wrapped_group.stream()));
410
411 TokenStream::from_iter(tokens)
412}