1extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as Tokens;
8
9use quote::quote;
10
11use syn::parse::Parse;
12use syn::parse_macro_input;
13use syn::Attribute;
14use syn::Expr;
15use syn::ItemFn;
16use syn::Lit;
17use syn::Meta;
18
19
20#[proc_macro_attribute]
21pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
22 let item = parse_macro_input!(item as ItemFn);
23 try_test(attr, item)
24 .unwrap_or_else(syn::Error::into_compile_error)
25 .into()
26}
27
28fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
29 let mut attribute_args = AttributeArgs::default();
30 if cfg!(feature = "unstable") {
31 let mut ignored_attrs = vec![];
32 for attr in attrs {
33 let matched = attribute_args.try_parse_attr_single(&attr)?;
34 if !matched {
36 ignored_attrs.push(attr);
37 }
38 }
39
40 Ok((attribute_args, ignored_attrs))
41 } else {
42 Ok((attribute_args, attrs))
43 }
44}
45
46fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
47 let inner_test = if attr.is_empty() {
48 quote! { ::core::prelude::v1::test }
49 } else {
50 attr.into()
51 };
52
53 let ItemFn {
54 attrs,
55 vis,
56 sig,
57 block,
58 } = input;
59
60 let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
61 let logging_init = expand_logging_init(&attribute_args);
62 let tracing_init = expand_tracing_init(&attribute_args);
63
64 let result = quote! {
65 #[#inner_test]
66 #(#ignored_attrs)*
67 #vis #sig {
68 mod init {
79 pub fn init() {
80 #logging_init
81 #tracing_init
82 }
83 }
84
85 init::init();
86
87 #block
88 }
89 };
90 Ok(result)
91}
92
93
94#[derive(Debug, Default)]
95struct AttributeArgs {
96 default_log_filter: Option<String>,
97}
98
99impl AttributeArgs {
100 fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
101 if !attr.path().is_ident("test_trace") {
102 return Ok(false)
103 }
104
105 let nested_meta = attr.parse_args_with(Meta::parse)?;
106 let name_value = if let Meta::NameValue(name_value) = nested_meta {
107 name_value
108 } else {
109 return Err(syn::Error::new_spanned(
110 &nested_meta,
111 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
112 ))
113 };
114
115 let ident = if let Some(ident) = name_value.path.get_ident() {
116 ident
117 } else {
118 return Err(syn::Error::new_spanned(
119 &name_value.path,
120 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
121 ))
122 };
123
124 let arg_ref = if ident == "default_log_filter" {
125 &mut self.default_log_filter
126 } else {
127 return Err(syn::Error::new_spanned(
128 &name_value.path,
129 "Unrecognized attribute, see documentation for details.",
130 ))
131 };
132
133 if let Expr::Lit(lit) = &name_value.value {
134 if let Lit::Str(lit_str) = &lit.lit {
135 *arg_ref = Some(lit_str.value());
136 }
137 }
138
139 if arg_ref.is_none() {
142 return Err(syn::Error::new_spanned(
143 &name_value.value,
144 "Failed to parse value, expected a string",
145 ))
146 }
147
148 Ok(true)
149 }
150}
151
152
/// Generates `env_logger` initialization for the test: a builder in test
/// mode whose `try_init` error is discarded.
///
/// When a `default_log_filter` was given, it is applied through
/// `Env::default().default_filter_or(..)`. Only compiled when `log` is
/// enabled without `trace`.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
    // An `Option<Tokens>` interpolates as nothing when it is `None`, so the
    // extra `parse_env` step is only emitted when a default filter exists.
    let add_default_log_filter = attribute_args
        .default_log_filter
        .as_ref()
        .map(|default_log_filter| {
            quote! {
                let env_logger_builder = env_logger_builder
                    .parse_env(::test_trace::env_logger::Env::default().default_filter_or(#default_log_filter));
            }
        });

    quote! {
        {
            let mut env_logger_builder = ::test_trace::env_logger::builder();
            #add_default_log_filter
            let _ = env_logger_builder.is_test(true).try_init();
        }
    }
}
174
175#[cfg(not(all(feature = "log", not(feature = "trace"))))]
176fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
177 quote! {}
178}
179
/// Generates `tracing_subscriber` initialization for the test.
///
/// The emitted code builds a fmt subscriber with these settings:
/// - an `EnvFilter` read from the environment (lossily);
/// - a default directive: the user's `default_log_filter`, parsed at test
///   run time, or `LevelFilter::TRACE` when none was given;
/// - span events selected by the comma-separated, case-insensitive
///   `RUST_LOG_SPAN_EVENTS` variable, which panics on unknown entries;
/// - `with_test_writer()` as the output.
///
/// The result of `try_init` is discarded. Presumably this is so that a
/// subscriber already installed by another test does not fail this one.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
    let default_directive = match &attribute_args.default_log_filter {
        Some(default_log_filter) => quote! {
            #default_log_filter
                .parse()
                .expect("test-trace: default_log_filter must be valid")
        },
        None => quote! {
            ::test_trace::tracing_subscriber::filter::LevelFilter::TRACE.into()
        },
    };

    let env_filter = quote! {
        ::test_trace::tracing_subscriber::EnvFilter::builder()
            .with_default_directive(#default_directive)
            .from_env_lossy()
    };

    // Block expression that evaluates to the `FmtSpan` mask requested via
    // `RUST_LOG_SPAN_EVENTS` (`FmtSpan::NONE` when the variable is unset).
    let span_events = quote! {
        {
            use ::test_trace::tracing_subscriber::fmt::format::FmtSpan;

            match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
                Some(mut value) => {
                    value.make_ascii_lowercase();
                    let value = value.to_str().expect("test-trace: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
                    value
                        .split(",")
                        .map(|filter| match filter.trim() {
                            "new" => FmtSpan::NEW,
                            "enter" => FmtSpan::ENTER,
                            "exit" => FmtSpan::EXIT,
                            "close" => FmtSpan::CLOSE,
                            "active" => FmtSpan::ACTIVE,
                            "full" => FmtSpan::FULL,
                            _ => panic!("test-trace: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                                For example: `active` or `new,close`\n\t\
                                Supported filters: new, enter, exit, close, active, full\n\t\
                                Got: {}", value),
                        })
                        .fold(FmtSpan::NONE, |acc, filter| filter | acc)
                },
                None => FmtSpan::NONE,
            }
        }
    };

    quote! {
        {
            let __internal_event_filter = #span_events;

            let _ = ::test_trace::tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(#env_filter)
                .with_span_events(__internal_event_filter)
                .with_test_writer()
                .try_init();
        }
    }
}
239
240#[cfg(not(feature = "trace"))]
241fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
242 quote! {}
243}