libtest2_harness/
harness.rs

1use libtest_lexarg::OutputFormat;
2
3use crate::{cli, notify, Case, RunError, RunMode, TestContext};
4
5pub trait HarnessState: sealed::_HarnessState_is_Sealed {}
6
7pub struct Harness<State: HarnessState> {
8    state: State,
9}
10
11pub struct StateInitial {
12    start: std::time::Instant,
13}
14impl HarnessState for StateInitial {}
15impl sealed::_HarnessState_is_Sealed for StateInitial {}
16
17impl Harness<StateInitial> {
18    pub fn new() -> Self {
19        Self {
20            state: StateInitial {
21                start: std::time::Instant::now(),
22            },
23        }
24    }
25
26    pub fn with_env(self) -> std::io::Result<Harness<StateArgs>> {
27        let raw = std::env::args_os();
28        self.with_args(raw)
29    }
30
31    pub fn with_args(
32        self,
33        args: impl IntoIterator<Item = impl Into<std::ffi::OsString>>,
34    ) -> std::io::Result<Harness<StateArgs>> {
35        let raw = expand_args(args)?;
36        Ok(Harness {
37            state: StateArgs {
38                start: self.state.start,
39                raw,
40            },
41        })
42    }
43}
44
45impl Default for Harness<StateInitial> {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51pub struct StateArgs {
52    start: std::time::Instant,
53    raw: Vec<std::ffi::OsString>,
54}
55impl HarnessState for StateArgs {}
56impl sealed::_HarnessState_is_Sealed for StateArgs {}
57
58impl Harness<StateArgs> {
59    pub fn parse(&self) -> Result<Harness<StateParsed>, cli::LexError<'_>> {
60        let mut parser = cli::Parser::new(&self.state.raw);
61        let opts = parse(&mut parser)?;
62
63        #[cfg(feature = "color")]
64        match opts.color {
65            libtest_lexarg::ColorConfig::AutoColor => anstream::ColorChoice::Auto,
66            libtest_lexarg::ColorConfig::AlwaysColor => anstream::ColorChoice::Always,
67            libtest_lexarg::ColorConfig::NeverColor => anstream::ColorChoice::Never,
68        }
69        .write_global();
70
71        let notifier = notifier(&opts);
72
73        Ok(Harness {
74            state: StateParsed {
75                start: self.state.start,
76                opts,
77                notifier,
78            },
79        })
80    }
81}
82
83pub struct StateParsed {
84    start: std::time::Instant,
85    opts: libtest_lexarg::TestOpts,
86    notifier: notify::ArcNotifier,
87}
88impl HarnessState for StateParsed {}
89impl sealed::_HarnessState_is_Sealed for StateParsed {}
90
91impl Harness<StateParsed> {
92    pub fn discover(
93        self,
94        cases: impl IntoIterator<Item = impl Case + 'static>,
95    ) -> std::io::Result<Harness<StateDiscovered>> {
96        self.state.notifier.notify(
97            notify::event::DiscoverStart {
98                elapsed_s: Some(notify::Elapsed(self.state.start.elapsed())),
99            }
100            .into(),
101        )?;
102
103        let mut selected_cases = Vec::new();
104        for case in cases {
105            let selected = case_priority(&case, &self.state.opts).is_some();
106            self.state.notifier.notify(
107                notify::event::DiscoverCase {
108                    name: case.name().to_owned(),
109                    mode: RunMode::Test,
110                    selected,
111                    elapsed_s: Some(notify::Elapsed(self.state.start.elapsed())),
112                }
113                .into(),
114            )?;
115            if selected {
116                selected_cases.push(Box::new(case) as Box<dyn Case>);
117            }
118        }
119
120        selected_cases.sort_unstable_by_key(|case| {
121            let priority = case_priority(case.as_ref(), &self.state.opts);
122            let name = case.name().to_owned();
123            (priority, name)
124        });
125
126        self.state.notifier.notify(
127            notify::event::DiscoverComplete {
128                elapsed_s: Some(notify::Elapsed(self.state.start.elapsed())),
129            }
130            .into(),
131        )?;
132
133        Ok(Harness {
134            state: StateDiscovered {
135                start: self.state.start,
136                opts: self.state.opts,
137                notifier: self.state.notifier,
138                cases: selected_cases,
139            },
140        })
141    }
142}
143
144pub struct StateDiscovered {
145    start: std::time::Instant,
146    opts: libtest_lexarg::TestOpts,
147    notifier: notify::ArcNotifier,
148    cases: Vec<Box<dyn Case>>,
149}
150impl HarnessState for StateDiscovered {}
151impl sealed::_HarnessState_is_Sealed for StateDiscovered {}
152
153impl Harness<StateDiscovered> {
154    pub fn run(self) -> std::io::Result<bool> {
155        if self.state.opts.list {
156            Ok(true)
157        } else {
158            run(
159                &self.state.start,
160                &self.state.opts,
161                self.state.cases,
162                self.state.notifier,
163            )
164        }
165    }
166}
167
mod sealed {
    /// Private supertrait of `HarnessState`. Because this module is private, code outside the
    /// harness module cannot name it and so cannot implement `HarnessState`.
    // `unnameable_types`: the trait must be `pub` to serve as a supertrait of a public trait,
    // yet being unnameable from outside is exactly the intent.
    // `non_camel_case_types`: the name deliberately reads as a sentence rather than a type name.
    #[allow(unnameable_types, non_camel_case_types)]
    pub trait _HarnessState_is_Sealed {}
}
173
/// Process exit status reserved for unsuccessful runs.
///
/// Not referenced in this module; presumably callers exit with it when a run reports failure.
/// `101` is the status Rust uses when `main` panics — confirm intent before changing.
pub const ERROR_EXIT_CODE: i32 = 101;
175
/// Replaces every `@path` argument with the arguments read from the file at `path`.
///
/// Arguments that do not start with `@`, or that are not valid UTF-8, are passed through
/// unchanged. Expansion is not recursive: lines read from an argument file are used verbatim.
///
/// # Errors
/// Returns an error naming the file if an argument file cannot be read or is not UTF-8.
fn expand_args(
    args: impl IntoIterator<Item = impl Into<std::ffi::OsString>>,
) -> std::io::Result<Vec<std::ffi::OsString>> {
    let mut expanded = Vec::new();
    for arg in args {
        let arg = arg.into();
        if let Some(argfile) = arg.to_str().and_then(|s| s.strip_prefix('@')) {
            expanded.extend(parse_argfile(std::path::Path::new(argfile))?);
        } else {
            expanded.push(arg);
        }
    }
    Ok(expanded)
}

/// Reads an argument file, producing one argument per line.
///
/// Logic taken from rust-lang/rust's `compiler/rustc_driver_impl/src/args.rs`: the whole file
/// must be UTF-8, each line (with its `\n` or `\r\n` terminator removed) becomes one argument,
/// and no quoting or escaping is interpreted.
///
/// # Errors
/// Returns an error with the same kind as the underlying read failure, with a message that
/// includes `path`.
fn parse_argfile(path: &std::path::Path) -> std::io::Result<Vec<std::ffi::OsString>> {
    // A bare "No such file or directory" gives no hint which `@file` argument was at fault.
    let content = std::fs::read_to_string(path).map_err(|err| {
        std::io::Error::new(
            err.kind(),
            format!("failed to read argument file `{}`: {err}", path.display()),
        )
    })?;
    Ok(content.lines().map(|line| line.into()).collect())
}
196
/// Parses the lexed command line into [`libtest_lexarg::TestOpts`].
///
/// The first raw argument is taken as the binary name (falling back to `"test"`); it is only
/// used for the `--help` usage line. `-h`/`--help` prints usage and exits the process with
/// status 0.
///
/// After parsing, `test_threads` is resolved. It is forced to `None` when the `threads` feature
/// is disabled or the target is emscripten/wasm. Otherwise, when not given explicitly, it
/// defaults to [`std::thread::available_parallelism`] (left `None` if that query fails).
///
/// # Errors
/// Returns a [`cli::LexError`] for an unexpected value, for an argument that
/// `TestOptsBuilder::parse_next` hands back unconsumed, or for errors from
/// `TestOptsBuilder::finish`.
fn parse<'p>(parser: &mut cli::Parser<'p>) -> Result<libtest_lexarg::TestOpts, cli::LexError<'p>> {
    let mut test_opts = libtest_lexarg::TestOptsBuilder::new();

    // Per the `expect` message, no value can be pending before the first argument.
    let bin = parser
        .next_raw()
        .expect("first arg, no pending values")
        .unwrap_or(std::ffi::OsStr::new("test"));
    // Remembered so an unexpected value can be reported within the argument that preceded it.
    let mut prev_arg = cli::Arg::Value(bin);
    while let Some(arg) = parser.next_arg() {
        match arg {
            cli::Arg::Short("h") | cli::Arg::Long("help") => {
                let mut bin = std::path::Path::new(bin);
                if let Ok(current_dir) = std::env::current_dir() {
                    // abbreviate the path because cargo always uses absolute paths
                    bin = bin.strip_prefix(&current_dir).unwrap_or(bin);
                }
                let bin = bin.to_string_lossy();
                let options_help = libtest_lexarg::OPTIONS_HELP.trim();
                let after_help = libtest_lexarg::AFTER_HELP.trim();
                println!(
                    "Usage: {bin} [OPTIONS] [FILTER]...

{options_help}

{after_help}"
                );
                std::process::exit(0);
            }
            // All values are the same, whether escaped or not, so it's a no-op
            cli::Arg::Escape(_) => {
                prev_arg = arg;
                continue;
            }
            cli::Arg::Unexpected(_) => {
                return Err(cli::LexError::msg("unexpected value")
                    .unexpected(arg)
                    .within(prev_arg));
            }
            _ => {}
        }
        prev_arg = arg;

        // `parse_next` returns `Some(arg)` when it did not consume the argument.
        let arg = test_opts.parse_next(parser, arg)?;

        if let Some(arg) = arg {
            return Err(cli::LexError::msg("unexpected argument").unexpected(arg));
        }
    }

    let mut opts = test_opts.finish()?;
    // If the platform is single-threaded we're just going to run
    // the test synchronously, regardless of the concurrency
    // level.
    let supports_threads = !cfg!(target_os = "emscripten") && !cfg!(target_family = "wasm");
    opts.test_threads = if cfg!(feature = "threads") && supports_threads {
        opts.test_threads
            .or_else(|| std::thread::available_parallelism().ok())
    } else {
        None
    };
    Ok(opts)
}
259
260fn notifier(opts: &libtest_lexarg::TestOpts) -> notify::ArcNotifier {
261    #[cfg(feature = "color")]
262    let stdout = anstream::stdout();
263    #[cfg(not(feature = "color"))]
264    let stdout = std::io::stdout();
265    match opts.format {
266        OutputFormat::Json => notify::ArcNotifier::new(notify::JsonNotifier::new(stdout)),
267        _ if opts.list => notify::ArcNotifier::new(notify::TerseListNotifier::new(stdout)),
268        OutputFormat::Pretty => notify::ArcNotifier::new(notify::PrettyRunNotifier::new(stdout)),
269        OutputFormat::Terse => notify::ArcNotifier::new(notify::TerseRunNotifier::new(stdout)),
270    }
271}
272
273fn case_priority(case: &dyn Case, opts: &libtest_lexarg::TestOpts) -> Option<usize> {
274    let filtered_out =
275        !opts.skip.is_empty() && opts.skip.iter().any(|sf| matches_filter(case, sf, opts));
276    if filtered_out {
277        None
278    } else if opts.filters.is_empty() {
279        Some(0)
280    } else {
281        opts.filters
282            .iter()
283            .position(|filter| matches_filter(case, filter, opts))
284    }
285}
286
287fn matches_filter(case: &dyn Case, filter: &str, opts: &libtest_lexarg::TestOpts) -> bool {
288    let test_name = case.name();
289
290    match opts.filter_exact {
291        true => test_name == filter,
292        false => test_name.contains(filter),
293    }
294}
295
296fn run(
297    start: &std::time::Instant,
298    opts: &libtest_lexarg::TestOpts,
299    cases: Vec<Box<dyn Case>>,
300    notifier: notify::ArcNotifier,
301) -> std::io::Result<bool> {
302    notifier.notify(
303        notify::event::RunStart {
304            elapsed_s: Some(notify::Elapsed(start.elapsed())),
305        }
306        .into(),
307    )?;
308
309    if opts.no_capture {
310        return Err(std::io::Error::new(
311            std::io::ErrorKind::Unsupported,
312            "`--no-capture` is not supported at this time",
313        ));
314    }
315    if opts.show_output {
316        return Err(std::io::Error::new(
317            std::io::ErrorKind::Unsupported,
318            "`--show-output` is not supported at this time",
319        ));
320    }
321
322    let threads = opts.test_threads.map(|t| t.get()).unwrap_or(1);
323
324    let run_ignored = match opts.run_ignored {
325        libtest_lexarg::RunIgnored::Yes | libtest_lexarg::RunIgnored::Only => true,
326        libtest_lexarg::RunIgnored::No => false,
327    };
328    let mode = match (opts.run_tests, opts.bench_benchmarks) {
329        (true, true) => {
330            return Err(std::io::Error::other(
331                "`--test` and `-bench` are mutually exclusive",
332            ));
333        }
334        (true, false) => RunMode::Test,
335        (false, true) => RunMode::Bench,
336        (false, false) => unreachable!("libtest-lexarg` should always ensure at least one is set"),
337    };
338    let context = TestContext {
339        start: *start,
340        mode,
341        run_ignored,
342        notifier,
343        test_name: String::new(),
344    };
345
346    let mut success = true;
347
348    let (exclusive_cases, concurrent_cases) = if threads == 1 || cases.len() == 1 {
349        (cases, vec![])
350    } else {
351        cases
352            .into_iter()
353            .partition::<Vec<_>, _>(|c| c.exclusive(&context))
354    };
355    if !concurrent_cases.is_empty() {
356        context.notifier().threaded(true);
357
358        // Use a deterministic hasher
359        type TestMap = std::collections::HashMap<
360            String,
361            std::thread::JoinHandle<std::io::Result<bool>>,
362            std::hash::BuildHasherDefault<std::collections::hash_map::DefaultHasher>,
363        >;
364
365        let sync_success = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(success));
366        let mut running: TestMap = Default::default();
367        let (tx, rx) = std::sync::mpsc::channel::<String>();
368        let mut remaining = std::collections::VecDeque::from(concurrent_cases);
369        while !running.is_empty() || !remaining.is_empty() {
370            while running.len() < threads && !remaining.is_empty() {
371                let case = remaining.pop_front().unwrap();
372                let case = std::sync::Arc::new(case);
373                let name = case.name().to_owned();
374
375                let cfg = std::thread::Builder::new().name(name.clone());
376                let thread_tx = tx.clone();
377                let thread_case = case.clone();
378                let mut thread_context = context.clone();
379                thread_context.test_name = name.clone();
380                let thread_sync_success = sync_success.clone();
381                let join_handle = cfg.spawn(move || {
382                    let status = run_case(thread_case.as_ref().as_ref(), &thread_context);
383                    if !matches!(status, Ok(true)) {
384                        thread_sync_success.store(false, std::sync::atomic::Ordering::Relaxed);
385                    }
386                    let _ = thread_tx.send(thread_case.name().to_owned());
387                    status
388                });
389                match join_handle {
390                    Ok(join_handle) => {
391                        running.insert(name.clone(), join_handle);
392                    }
393                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
394                        // `ErrorKind::WouldBlock` means hitting the thread limit on some
395                        // platforms, so run the test synchronously here instead.
396                        let case_success = run_case(case.as_ref().as_ref(), &context)?;
397                        if !case_success {
398                            sync_success.store(case_success, std::sync::atomic::Ordering::Relaxed);
399                        }
400                    }
401                    Err(e) => {
402                        return Err(e);
403                    }
404                }
405            }
406
407            let test_name = rx.recv().unwrap();
408            let running_test = running.remove(&test_name).unwrap();
409            let _ = running_test.join();
410            success &= sync_success.load(std::sync::atomic::Ordering::SeqCst);
411            if !success && opts.fail_fast {
412                break;
413            }
414        }
415    }
416
417    if !exclusive_cases.is_empty() {
418        context.notifier().threaded(false);
419        for case in exclusive_cases {
420            success &= run_case(case.as_ref(), &context)?;
421            if !success && opts.fail_fast {
422                break;
423            }
424        }
425    }
426
427    context.notifier().notify(
428        notify::event::RunComplete {
429            elapsed_s: Some(notify::Elapsed(start.elapsed())),
430        }
431        .into(),
432    )?;
433
434    Ok(success)
435}
436
437fn run_case(case: &dyn Case, context: &TestContext) -> std::io::Result<bool> {
438    context.notifier().notify(
439        notify::event::CaseStart {
440            name: case.name().to_owned(),
441            elapsed_s: Some(context.elapased_s()),
442        }
443        .into(),
444    )?;
445
446    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
447        __rust_begin_short_backtrace(|| case.run(context))
448    }))
449    .unwrap_or_else(|e| {
450        // The `panic` information is just an `Any` object representing the
451        // value the panic was invoked with. For most panics (which use
452        // `panic!` like `println!`), this is either `&str` or `String`.
453        let payload = e
454            .downcast_ref::<String>()
455            .map(|s| s.as_str())
456            .or_else(|| e.downcast_ref::<&str>().copied());
457
458        let msg = match payload {
459            Some(payload) => format!("test panicked: {payload}"),
460            None => "test panicked".to_owned(),
461        };
462        Err(RunError::fail(msg))
463    });
464
465    let mut case_status = None;
466    if let Some(err) = outcome.as_ref().err() {
467        let kind = err.status();
468        case_status = Some(kind);
469        let message = err.cause().map(|c| c.to_string());
470        context.notifier().notify(
471            notify::event::CaseMessage {
472                name: case.name().to_owned(),
473                kind,
474                message,
475                elapsed_s: Some(context.elapased_s()),
476            }
477            .into(),
478        )?;
479    }
480
481    context.notifier().notify(
482        notify::event::CaseComplete {
483            name: case.name().to_owned(),
484            elapsed_s: Some(context.elapased_s()),
485        }
486        .into(),
487    )?;
488
489    Ok(case_status != Some(notify::MessageKind::Error))
490}
491
/// Fixed frame used to clean the backtrace with `RUST_BACKTRACE=1`.
///
/// Calls `f` and returns its result unchanged. The function is kept as a distinct, non-inlined
/// frame so that backtraces taken inside `f` have a recognisable boundary. The name presumably
/// mirrors the marker frame that `std`'s short-backtrace printing trims at — confirm against
/// std before renaming or restructuring.
#[inline(never)]
fn __rust_begin_short_backtrace<T, F: FnOnce() -> T>(f: F) -> T {
    let result = f();

    // prevent this frame from being tail-call optimised away
    std::hint::black_box(result)
}