Skip to main content

allure_test_macros/
lib.rs

1use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
2
/// Describes the `#[should_panic]` attribute found in front of a test function.
#[derive(Debug)]
struct ShouldPanicConfig {
    /// `true` when a `#[should_panic]` attribute was found.
    enabled: bool,
    /// Text of the `expected = "..."` argument, when one was supplied.
    expected: Option<String>,
}
7
/// Key/value arguments collected from an attribute's argument list.
#[derive(Debug)]
struct AttrArgs {
    /// Value of the `name = "..."` argument, if it was given.
    name: Option<String>,
    /// Value of the `id = "..."` argument, if it was given.
    id: Option<String>,
}
13
/// Arguments accepted by the `#[step(...)]` attribute.
#[derive(Debug)]
struct StepAttrArgs {
    /// Value of the `name = "..."` argument, if it was given.
    name: Option<String>,
}
18
19#[proc_macro_attribute]
20pub fn step(args: TokenStream, input: TokenStream) -> TokenStream {
21    let attrs = match parse_step_args(args) {
22        Ok(attrs) => attrs,
23        Err(err) => return compile_error(err),
24    };
25
26    transform_step_fn(attrs, input)
27}
28
29#[proc_macro_attribute]
30pub fn allure_test(args: TokenStream, input: TokenStream) -> TokenStream {
31    let attrs = match parse_args(args) {
32        Ok(attrs) => attrs,
33        Err(err) => return compile_error(err),
34    };
35
36    transform_fn(attrs, input)
37}
38
39fn parse_step_args(args: TokenStream) -> Result<StepAttrArgs, &'static str> {
40    let attrs = parse_kv_args(args, &["name"])?;
41    Ok(StepAttrArgs { name: attrs.name })
42}
43
44fn parse_args(args: TokenStream) -> Result<AttrArgs, &'static str> {
45    parse_kv_args(args, &["name", "id"])
46}
47
48fn parse_kv_args(args: TokenStream, allowed_keys: &[&str]) -> Result<AttrArgs, &'static str> {
49    let tokens: Vec<TokenTree> = args.into_iter().collect();
50    if tokens.is_empty() {
51        return Ok(AttrArgs {
52            name: None,
53            id: None,
54        });
55    }
56
57    let mut idx = 0;
58    let mut parsed = AttrArgs {
59        name: None,
60        id: None,
61    };
62
63    while idx < tokens.len() {
64        let key = match &tokens[idx] {
65            TokenTree::Ident(id) => id.to_string(),
66            _ => return Err("unsupported attribute arguments, expected key = \"value\""),
67        };
68        idx += 1;
69
70        match tokens.get(idx) {
71            Some(TokenTree::Punct(eq)) if eq.as_char() == '=' => idx += 1,
72            _ => return Err("unsupported attribute arguments, expected key = \"value\""),
73        }
74
75        let raw = match tokens.get(idx) {
76            Some(TokenTree::Literal(value)) => value.to_string(),
77            _ => return Err("unsupported attribute arguments, expected key = \"value\""),
78        };
79        idx += 1;
80
81        if !raw.starts_with('"') || !raw.ends_with('"') || raw.len() < 2 {
82            return Err("attribute values must be string literals");
83        }
84
85        let value = raw[1..raw.len() - 1].to_string();
86        match key.as_str() {
87            "name" if allowed_keys.contains(&"name") => {
88                if parsed.name.is_some() {
89                    return Err("duplicate attribute argument: name");
90                }
91                parsed.name = Some(value);
92            }
93            "id" if allowed_keys.contains(&"id") => {
94                if parsed.id.is_some() {
95                    return Err("duplicate attribute argument: id");
96                }
97                parsed.id = Some(value);
98            }
99            _ => {
100                return Err(match allowed_keys {
101                    ["name"] => "unsupported attribute argument, expected: name",
102                    _ => "unsupported attribute argument, expected one of: name, id",
103                });
104            }
105        }
106
107        if idx < tokens.len() {
108            match &tokens[idx] {
109                TokenTree::Punct(p) if p.as_char() == ',' => idx += 1,
110                _ => {
111                    return Err(
112                        "unsupported attribute arguments, expected comma-separated key/value pairs",
113                    );
114                }
115            }
116        }
117    }
118
119    Ok(parsed)
120}
121
122fn transform_fn(attrs: AttrArgs, input: TokenStream) -> TokenStream {
123    let mut tokens: Vec<TokenTree> = input.into_iter().collect();
124
125    let fn_index = tokens
126        .iter()
127        .position(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "fn"));
128    let Some(fn_index) = fn_index else {
129        return compile_error("#[allure_test] can be applied only to functions");
130    };
131
132    if tokens[..fn_index]
133        .iter()
134        .any(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "async"))
135    {
136        return compile_error("#[allure_test] does not support async functions");
137    }
138
139    let should_panic = parse_should_panic_config(&tokens[..fn_index]);
140
141    let fn_name = tokens.iter().skip(fn_index + 1).find_map(|t| match t {
142        TokenTree::Ident(id) => Some(id.to_string()),
143        _ => None,
144    });
145    let Some(fn_name) = fn_name else {
146        return compile_error("failed to parse function name");
147    };
148
149    let body_index = tokens.iter().position(
150        |t| matches!(t, TokenTree::Group(group) if group.delimiter() == Delimiter::Brace),
151    );
152    let Some(body_index) = body_index else {
153        return compile_error("failed to parse function body");
154    };
155
156    if has_return_type(&tokens[fn_index..body_index]) {
157        return compile_error(
158            "#[allure_test] currently supports only test functions that return ()",
159        );
160    }
161
162    let original_body = match &tokens[body_index] {
163        TokenTree::Group(group) => group.stream().to_string(),
164        _ => return compile_error("failed to parse function body"),
165    };
166
167    let test_name = attrs.name.unwrap_or(fn_name.clone());
168    let allure_id_setup = match attrs.id.as_ref() {
169        Some(id) => format!("allure.id({id:?});"),
170        None => String::new(),
171    };
172
173    let wrapped_body_src = if should_panic.enabled {
174        format!(
175            "{{
176  let __allure_results_dir = ::std::env::var(\"ALLURE_RESULTS_DIR\")
177    .unwrap_or_else(|_| \"target/allure-results\".to_string());
178  let __allure_reporter = ::allure_cargotest::CargoTestReporter::new(__allure_results_dir)
179    .expect(\"allure reporter should be created\");
180  if !__allure_reporter.is_selected({test_name:?}, Some({test_name:?}), None, None) {{
181    return;
182  }}
183  __allure_reporter.run_test_with_result({test_name:?}, |allure| {{
184    {allure_id_setup}
185    let __allure_result = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {{ {original_body} }}));
186    match __allure_result {{
187      Ok(()) => (
188        ::allure_cargotest::Status::Failed,
189        Some(::allure_cargotest::StatusDetails {{
190          message: Some(\"expected panic but none occurred\".to_string()),
191          trace: None,
192          actual: None,
193          expected: None,
194        }}),
195        None,
196      ),
197      Err(__allure_payload) => {{
198        let __allure_message = if let Some(__allure_msg) = __allure_payload.downcast_ref::<&str>() {{
199          (*__allure_msg).to_string()
200        }} else if let Some(__allure_msg) = __allure_payload.downcast_ref::<String>() {{
201          __allure_msg.clone()
202        }} else {{
203          \"panic without string payload\".to_string()
204        }};
205        {}
206      }}
207    }}
208  }});
209}}",
210            expected_match_arm(&should_panic.expected)
211        )
212    } else {
213        format!(
214            "{{
215  let __allure_results_dir = ::std::env::var(\"ALLURE_RESULTS_DIR\")
216    .unwrap_or_else(|_| \"target/allure-results\".to_string());
217  let __allure_reporter = ::allure_cargotest::CargoTestReporter::new(__allure_results_dir)
218    .expect(\"allure reporter should be created\");
219  let __allure_full_name = format!(\"{{}}::{{}}\", module_path!(), {fn_name:?});
220  __allure_reporter.run_test_with_metadata({test_name:?}, Some(&__allure_full_name), None, None, |allure| {{ {allure_id_setup} {original_body} }});
221}}"
222        )
223    };
224
225    let wrapped_body_stream: TokenStream = match wrapped_body_src.parse() {
226        Ok(stream) => stream,
227        Err(_) => return compile_error("failed to generate transformed test body"),
228    };
229    let wrapped_group = match wrapped_body_stream.into_iter().next() {
230        Some(TokenTree::Group(group)) => group,
231        _ => return compile_error("failed to generate transformed test body"),
232    };
233
234    tokens[body_index] = TokenTree::Group(Group::new(Delimiter::Brace, wrapped_group.stream()));
235
236    TokenStream::from_iter(tokens)
237}
238
239fn parse_should_panic_config(tokens: &[TokenTree]) -> ShouldPanicConfig {
240    let mut index = 0;
241    while index + 1 < tokens.len() {
242        let Some(TokenTree::Punct(pound)) = tokens.get(index) else {
243            index += 1;
244            continue;
245        };
246        if pound.as_char() != '#' {
247            index += 1;
248            continue;
249        }
250
251        let Some(TokenTree::Group(group)) = tokens.get(index + 1) else {
252            index += 1;
253            continue;
254        };
255        if group.delimiter() != Delimiter::Bracket {
256            index += 2;
257            continue;
258        }
259
260        let mut attr_tokens = group.stream().into_iter();
261        let Some(TokenTree::Ident(name)) = attr_tokens.next() else {
262            index += 2;
263            continue;
264        };
265        if name.to_string() != "should_panic" {
266            index += 2;
267            continue;
268        }
269
270        let expected = attr_tokens.find_map(|token| match token {
271            TokenTree::Group(arguments) if arguments.delimiter() == Delimiter::Parenthesis => {
272                parse_should_panic_expected(arguments.stream())
273            }
274            _ => None,
275        });
276        return ShouldPanicConfig {
277            enabled: true,
278            expected,
279        };
280    }
281
282    ShouldPanicConfig {
283        enabled: false,
284        expected: None,
285    }
286}
287
/// Extracts the text of `expected = "..."` from the arguments of `#[should_panic(...)]`.
///
/// Scans every run of three consecutive tokens for `expected`, `=` and a
/// literal, and returns the first one that decodes as a string literal.
/// The literal is decoded to the string it denotes (escape sequences are
/// resolved, raw strings are accepted): the result is later re-emitted with
/// `{:?}`, so keeping the raw literal source would escape it twice and the
/// generated `contains` check would look for the wrong text.
///
/// Returns `None` when no such argument is present.
fn parse_should_panic_expected(tokens: TokenStream) -> Option<String> {
    /// Turns the source text of a string literal into the string it denotes,
    /// or `None` when the text is not a well-formed string literal.
    fn decode_string_literal(raw: &str) -> Option<String> {
        // Raw strings carry no escapes: only the `r`, `#` and quote delimiters go away.
        if let Some(rest) = raw.strip_prefix('r') {
            let after_hashes = rest.trim_start_matches('#');
            let hashes = "#".repeat(rest.len() - after_hashes.len());
            return after_hashes
                .strip_prefix('"')?
                .strip_suffix(hashes.as_str())?
                .strip_suffix('"')
                .map(str::to_owned);
        }

        let body = raw.strip_prefix('"')?.strip_suffix('"')?;
        let mut decoded = String::with_capacity(body.len());
        let mut chars = body.chars().peekable();
        while let Some(ch) = chars.next() {
            if ch != '\\' {
                decoded.push(ch);
                continue;
            }
            match chars.next()? {
                'n' => decoded.push('\n'),
                'r' => decoded.push('\r'),
                't' => decoded.push('\t'),
                '0' => decoded.push('\0'),
                '\\' => decoded.push('\\'),
                '\'' => decoded.push('\''),
                '"' => decoded.push('"'),
                'x' => {
                    let hi = chars.next()?.to_digit(16)?;
                    let lo = chars.next()?.to_digit(16)?;
                    decoded.push(char::from_u32(hi * 16 + lo)?);
                }
                'u' => {
                    if chars.next()? != '{' {
                        return None;
                    }
                    let mut code: u32 = 0;
                    let mut digits = 0;
                    loop {
                        match chars.next()? {
                            '}' => break,
                            // Underscores are allowed as digit separators.
                            '_' => {}
                            digit => {
                                code = code.checked_mul(16)?.checked_add(digit.to_digit(16)?)?;
                                digits += 1;
                            }
                        }
                    }
                    if digits == 0 {
                        return None;
                    }
                    decoded.push(char::from_u32(code)?);
                }
                // Line continuation: the newline and the following whitespace are dropped.
                '\n' => {
                    while matches!(chars.peek(), Some(' ' | '\t' | '\n' | '\r')) {
                        chars.next();
                    }
                }
                _ => return None,
            }
        }
        Some(decoded)
    }

    let parsed: Vec<TokenTree> = tokens.into_iter().collect();
    parsed.windows(3).find_map(|window| match window {
        [TokenTree::Ident(name), TokenTree::Punct(eq), TokenTree::Literal(value)]
            if name.to_string() == "expected" && eq.as_char() == '=' =>
        {
            decode_string_literal(&value.to_string())
        }
        _ => None,
    })
}
305
/// Builds the source of the expression evaluated once a panic has been caught.
///
/// The generated code refers to `__allure_message` and `__allure_payload`,
/// which the surrounding template defines, and produces a
/// `(status, details, payload)` tuple. Without an expected text every panic
/// counts as `Passed`; with one, the panic message must contain it, otherwise
/// the status is `Failed` with a mismatch message.
fn expected_match_arm(expected: &Option<String>) -> String {
    let Some(expected) = expected else {
        return "(
          ::allure_cargotest::Status::Passed,
          None,
          Some(__allure_payload),
        )"
        .to_string();
    };

    format!(
        "if __allure_message.contains({expected:?}) {{
          (
            ::allure_cargotest::Status::Passed,
            None,
            Some(__allure_payload),
          )
        }} else {{
          (
            ::allure_cargotest::Status::Failed,
            Some(::allure_cargotest::StatusDetails {{
              message: Some(format!(\"panic message mismatch: expected substring {{:?}}, got {{:?}}\", {expected:?}, __allure_message)),
              trace: None,
              actual: None,
              expected: None,
            }}),
            Some(__allure_payload),
          )
        }}"
    )
}
336
/// Reports whether `tokens` contain a `-` punctuation token directly followed
/// by a `>` punctuation token, i.e. a `->` arrow.
fn has_return_type(tokens: &[TokenTree]) -> bool {
    tokens.windows(2).any(|pair| {
        matches!(
            pair,
            [TokenTree::Punct(dash), TokenTree::Punct(angle)]
                if dash.as_char() == '-' && angle.as_char() == '>'
        )
    })
}
347
/// Produces the tokens of `compile_error!("<message>");`.
///
/// The message is embedded as an escaped string literal. If the generated
/// source cannot be parsed, an empty token stream is returned.
fn compile_error(message: &str) -> TokenStream {
    let source = format!("compile_error!({message:?});");
    match source.parse::<TokenStream>() {
        Ok(stream) => stream,
        Err(_) => TokenStream::new(),
    }
}
353
354fn transform_step_fn(attrs: StepAttrArgs, input: TokenStream) -> TokenStream {
355    let mut tokens: Vec<TokenTree> = input.into_iter().collect();
356
357    let fn_index = tokens
358        .iter()
359        .position(|t| matches!(t, TokenTree::Ident(id) if id.to_string() == "fn"));
360    let Some(fn_index) = fn_index else {
361        return compile_error("#[step] can be applied only to functions");
362    };
363
364    let fn_name = tokens.iter().skip(fn_index + 1).find_map(|t| match t {
365        TokenTree::Ident(id) => Some(id.to_string()),
366        _ => None,
367    });
368    let Some(fn_name) = fn_name else {
369        return compile_error("failed to parse function name");
370    };
371
372    let body_index = tokens.iter().position(
373        |t| matches!(t, TokenTree::Group(group) if group.delimiter() == Delimiter::Brace),
374    );
375    let Some(body_index) = body_index else {
376        return compile_error("failed to parse function body");
377    };
378
379    let original_body = match &tokens[body_index] {
380        TokenTree::Group(group) => group.stream().to_string(),
381        _ => return compile_error("failed to parse function body"),
382    };
383
384    let step_name = attrs.name.unwrap_or(fn_name);
385    let wrapped_body_src = format!(
386        "{{
387  let __allure_step_name = {step_name:?};
388  match ::allure_cargotest::__private::current_allure() {{
389    Some(__allure_step_allure) => {{
390      let __allure_step_guard = __allure_step_allure.step(__allure_step_name);
391      {original_body}
392    }}
393    None => {{
394      {original_body}
395    }}
396  }}
397}}"
398    );
399
400    let wrapped_body_stream: TokenStream = match wrapped_body_src.parse() {
401        Ok(stream) => stream,
402        Err(_) => return compile_error("failed to generate transformed function body"),
403    };
404    let wrapped_group = match wrapped_body_stream.into_iter().next() {
405        Some(TokenTree::Group(group)) => group,
406        _ => return compile_error("failed to generate transformed function body"),
407    };
408
409    tokens[body_index] = TokenTree::Group(Group::new(Delimiter::Brace, wrapped_group.stream()));
410
411    TokenStream::from_iter(tokens)
412}