test_log_macros/
lib.rs

1// Copyright (C) 2019-2025 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4//! Procedural macro powering `test-log`.
5
6use std::borrow::Cow;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as Tokens;
10
11use quote::quote;
12
13use syn::parse::Parse;
14use syn::parse_macro_input;
15use syn::Attribute;
16use syn::Expr;
17use syn::ItemFn;
18use syn::Lit;
19use syn::Meta;
20
21
22// Documented in `test-log` crate's re-export.
23#[allow(missing_docs)]
24#[proc_macro_attribute]
25pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
26  let item = parse_macro_input!(item as ItemFn);
27  try_test(attr, item)
28    .unwrap_or_else(syn::Error::into_compile_error)
29    .into()
30}
31
32fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
33  let mut attribute_args = AttributeArgs::default();
34  if cfg!(feature = "unstable") {
35    let mut ignored_attrs = vec![];
36    for attr in attrs {
37      let matched = attribute_args.try_parse_attr_single(&attr)?;
38      // Keep only attrs that didn't match the #[test_log(_)] syntax.
39      if !matched {
40        ignored_attrs.push(attr);
41      }
42    }
43
44    Ok((attribute_args, ignored_attrs))
45  } else {
46    Ok((attribute_args, attrs))
47  }
48}
49
50// Check whether given attribute is a test attribute of forms:
51// * `#[test]`
52// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]`
53// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]`
54fn is_test_attribute(attr: &Attribute) -> bool {
55  let path = match &attr.meta {
56    syn::Meta::Path(path) => path,
57    _ => return false,
58  };
59  let candidates = [
60    ["core", "prelude", "*", "test"],
61    ["std", "prelude", "*", "test"],
62  ];
63  if path.leading_colon.is_none()
64    && path.segments.len() == 1
65    && path.segments[0].arguments.is_none()
66    && path.segments[0].ident == "test"
67  {
68    return true;
69  } else if path.segments.len() != candidates[0].len() {
70    return false;
71  }
72  candidates.into_iter().any(|segments| {
73    path
74      .segments
75      .iter()
76      .zip(segments)
77      .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
78  })
79}
80
81fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
82  let ItemFn {
83    attrs,
84    vis,
85    sig,
86    block,
87  } = input;
88
89  let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
90  let logging_init = expand_logging_init(&attribute_args);
91  let tracing_init = expand_tracing_init(&attribute_args);
92
93  let (inner_test, generated_test) = if attr.is_empty() {
94    let has_test = ignored_attrs.iter().any(is_test_attribute);
95    let generated_test = if has_test {
96      quote! {}
97    } else {
98      quote! { #[::core::prelude::v1::test]}
99    };
100    (quote! {}, generated_test)
101  } else {
102    let attr = Tokens::from(attr);
103    (quote! { #[#attr] }, quote! {})
104  };
105
106  let result = quote! {
107    #inner_test
108    #(#ignored_attrs)*
109    #generated_test
110    #vis #sig {
111      // We put all initialization code into a separate module here in
112      // order to prevent potential ambiguities that could result in
113      // compilation errors. E.g., client code could use traits that
114      // could have methods that interfere with ones we use as part of
115      // initialization; with a `Foo` trait that is implemented for T
116      // and that contains a `map` (or similarly common named) method
117      // that could cause an ambiguity with `Iterator::map`, for
118      // example.
119      // The alternative would be to use fully qualified call syntax in
120      // all initialization code, but that's much harder to control.
121      mod init {
122        pub fn init() {
123          #logging_init
124          #tracing_init
125        }
126      }
127
128      init::init();
129
130      #block
131    }
132  };
133  Ok(result)
134}
135
136
/// Configuration gathered from `#[test_log(...)]` attributes on a test.
#[derive(Default, Debug)]
struct AttributeArgs {
  /// Filter used by the generated logging/tracing initialization in place of
  /// its built-in `info` default, if one was given.
  default_log_filter: Option<Cow<'static, str>>,
}
141
142impl AttributeArgs {
143  fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
144    if !attr.path().is_ident("test_log") {
145      return Ok(false)
146    }
147
148    let nested_meta = attr.parse_args_with(Meta::parse)?;
149    let name_value = if let Meta::NameValue(name_value) = nested_meta {
150      name_value
151    } else {
152      return Err(syn::Error::new_spanned(
153        &nested_meta,
154        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
155      ))
156    };
157
158    let ident = if let Some(ident) = name_value.path.get_ident() {
159      ident
160    } else {
161      return Err(syn::Error::new_spanned(
162        &name_value.path,
163        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
164      ))
165    };
166
167    let arg_ref = if ident == "default_log_filter" {
168      &mut self.default_log_filter
169    } else {
170      return Err(syn::Error::new_spanned(
171        &name_value.path,
172        "Unrecognized attribute, see documentation for details.",
173      ))
174    };
175
176    if let Expr::Lit(lit) = &name_value.value {
177      if let Lit::Str(lit_str) = &lit.lit {
178        *arg_ref = Some(Cow::from(lit_str.value()));
179      }
180    }
181
182    // If we couldn't parse the value on the right-hand side because it was some
183    // unexpected type, e.g. #[test_log::log(default_log_filter=10)], return an error.
184    if arg_ref.is_none() {
185      return Err(syn::Error::new_spanned(
186        &name_value.value,
187        "Failed to parse value, expected a string",
188      ))
189    }
190
191    Ok(true)
192  }
193}
194
195
/// Expand the initialization code for the `log` crate.
///
/// The emitted block configures `env_logger` (through the `::test_log`
/// re-export) with the stderr target and `is_test(true)`. The filter is read
/// from the environment described by `Env::default()`. The configured
/// `default_log_filter`, or `"info"` if none was given, is passed to
/// `default_filter_or` as the fallback. The result of `try_init` is
/// deliberately discarded, so an initialization error is ignored. For
/// example, per `env_logger`'s docs, this happens when another test already
/// set a logger.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
  let default_filter: &str = attribute_args
    .default_log_filter
    .as_deref()
    .unwrap_or("info");

  quote! {
    {
      let _result = ::test_log::env_logger::builder()
        .parse_env(
          ::test_log::env_logger::Env::default()
            .default_filter_or(#default_filter)
        )
        .target(::test_log::env_logger::Target::Stderr)
        .is_test(true)
        .try_init();
    }
  }
}
217
218#[cfg(not(all(feature = "log", not(feature = "trace"))))]
219fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
220  quote! {}
221}
222
/// Expand the initialization code for the `tracing` crate.
///
/// The emitted block installs a `tracing_subscriber` formatting subscriber
/// (through the `::test_log` re-export) that writes via `TestWriter` to
/// stderr.
///
/// Filtering:
/// * The filter is built with `EnvFilter::builder().from_env_lossy()`.
/// * Its default directive is the configured `default_log_filter`, or `INFO`
///   when none was given.
/// * An invalid `default_log_filter` panics when the test runs.
///
/// Span lifecycle events are selected at test run time by the
/// `RUST_LOG_SPAN_EVENTS` variable:
/// * The value is a comma-separated, case-insensitive list of `new`, `enter`,
///   `exit`, `close`, `active` and `full`.
/// * Empty entries (an empty variable, or stray/trailing commas) are ignored.
/// * Any other entry, or a non-UTF-8 value, panics.
/// * An unset variable means no span events.
///
/// The result of `try_init` is deliberately discarded.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
  // Only the default directive depends on the configuration; the rest of the
  // filter construction is shared.
  let default_directive = if let Some(default_log_filter) = &attribute_args.default_log_filter {
    quote! {
      #default_log_filter
        .parse()
        .expect("test-log: default_log_filter must be valid")
    }
  } else {
    quote! {
      ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
    }
  };
  let env_filter = quote! {
    ::test_log::tracing_subscriber::EnvFilter::builder()
      .with_default_directive(#default_directive)
      .from_env_lossy()
  };

  quote! {
    {
      let __internal_event_filter = {
        use ::test_log::tracing_subscriber::fmt::format::FmtSpan;

        match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
          Some(mut value) => {
            value.make_ascii_lowercase();
            let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
            value
              .split(',')
              .map(|filter| filter.trim())
              // Tolerate an empty variable as well as stray or trailing
              // separators (e.g. `new,close,`) instead of panicking on them.
              .filter(|filter| !filter.is_empty())
              .map(|filter| match filter {
                "new" => FmtSpan::NEW,
                "enter" => FmtSpan::ENTER,
                "exit" => FmtSpan::EXIT,
                "close" => FmtSpan::CLOSE,
                "active" => FmtSpan::ACTIVE,
                "full" => FmtSpan::FULL,
                _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                  For example: `active` or `new,close`\n\t\
                  Supported filters: new, enter, exit, close, active, full\n\t\
                  Got: {}", value),
              })
              .fold(FmtSpan::NONE, |acc, filter| filter | acc)
          },
          None => FmtSpan::NONE,
        }
      };

      let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(#env_filter)
        .with_span_events(__internal_event_filter)
        .with_writer(::test_log::tracing_subscriber::fmt::TestWriter::with_stderr)
        .try_init();
    }
  }
}
283
284#[cfg(not(feature = "trace"))]
285fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
286  quote! {}
287}