1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::time::Duration;
7use std::time::Instant;
8
9use parking_lot::Mutex;
10use rayon::ThreadPool;
11
12use crate::NO_CAPTURE;
13use crate::collection::CollectedCategoryOrTest;
14use crate::collection::CollectedTest;
15use crate::collection::CollectedTestCategory;
16use crate::reporter::LogReporter;
17use crate::reporter::Reporter;
18use crate::reporter::ReporterContext;
19use crate::reporter::ReporterFailure;
20
/// Shared callback that runs one collected test and produces its [`TestResult`].
///
/// It is reference counted and `Send + Sync` so every worker task can hold its own handle.
type RunTestFunc<TData> =
    Arc<dyn (Fn(&CollectedTest<TData>) -> TestResult) + Send + Sync>;
23
/// Mutable state threaded through a single test run.
struct Context<TData: Clone + Send + 'static> {
    /// Failures collected so far, in the order their results were received.
    failures: Vec<ReporterFailure<TData>>,
    /// Number of worker tasks started for each category.
    parallelism: Parallelism,
    /// Callback that executes a single test.
    run_test: RunTestFunc<TData>,
    /// Receives category/test start and end notifications and the final failure list.
    reporter: Arc<dyn Reporter<TData>>,
    /// Thread pool the worker and feeder tasks are spawned on.
    pool: ThreadPool,
}
31
/// Number of `TestResult::from_maybe_panic_or_result` calls currently in progress.
///
/// The process-wide forwarding panic hook is installed when this goes from 0 to 1 and is
/// taken down again when it returns to 0.
static GLOBAL_PANIC_HOOK_COUNT: Mutex<usize> = Mutex::new(0);

/// Boxed panic hook as stored in [`LOCAL_PANIC_HOOK`].
type PanicHook = Box<dyn Fn(&std::panic::PanicHookInfo) + Sync + Send>;

thread_local! {
    // Hook for panics raised on the current thread; the process-wide hook forwards to it
    // and does nothing when it is `None`.
    static LOCAL_PANIC_HOOK: RefCell<Option<PanicHook>> = RefCell::new(None);
}
39
/// Outcome of one named step nested inside a [`TestResult::SubTests`] result.
#[derive(Debug, Clone)]
pub struct SubTestResult {
    /// Name identifying the sub test.
    pub name: String,
    /// Outcome of the sub test; may itself contain further sub tests.
    pub result: TestResult,
}
45
/// Outcome of running a single test.
#[must_use]
#[derive(Debug, Clone)]
pub enum TestResult {
    /// The test passed.
    Passed {
        /// Duration attached to the result, if one was recorded.
        duration: Option<Duration>,
    },
    /// The test was not run.
    Ignored,
    /// The test failed.
    Failed {
        /// Duration attached to the result, if one was recorded.
        duration: Option<Duration>,
        /// Bytes describing the failure, e.g. a captured panic message.
        output: Vec<u8>,
    },
    /// The test is made up of sub tests; it counts as failed when any sub test failed.
    SubTests {
        /// Duration attached to the result, if one was recorded.
        duration: Option<Duration>,
        /// Results of the individual sub tests.
        sub_tests: Vec<SubTestResult>,
    },
}
70
71impl TestResult {
72 pub fn duration(&self) -> Option<Duration> {
73 match self {
74 TestResult::Passed { duration } => *duration,
75 TestResult::Ignored => None,
76 TestResult::Failed { duration, .. } => *duration,
77 TestResult::SubTests { duration, .. } => *duration,
78 }
79 }
80
81 pub fn with_duration(self, duration: Duration) -> Self {
82 match self {
83 TestResult::Passed { duration: _ } => TestResult::Passed {
84 duration: Some(duration),
85 },
86 TestResult::Ignored => TestResult::Ignored,
87 TestResult::Failed {
88 duration: _,
89 output,
90 } => TestResult::Failed {
91 duration: Some(duration),
92 output,
93 },
94 TestResult::SubTests {
95 duration: _,
96 sub_tests,
97 } => TestResult::SubTests {
98 duration: Some(duration),
99 sub_tests,
100 },
101 }
102 }
103
104 pub fn is_failed(&self) -> bool {
105 match self {
106 TestResult::Passed { .. } | TestResult::Ignored => false,
107 TestResult::Failed { .. } => true,
108 TestResult::SubTests { sub_tests, .. } => {
109 sub_tests.iter().any(|s| s.result.is_failed())
110 }
111 }
112 }
113
114 pub fn from_maybe_panic(
119 func: impl FnOnce() + std::panic::UnwindSafe,
120 ) -> Self {
121 Self::from_maybe_panic_or_result(|| {
122 func();
123 TestResult::Passed { duration: None }
124 })
125 }
126
127 pub fn from_maybe_panic_or_result(
133 func: impl FnOnce() -> TestResult + std::panic::UnwindSafe,
134 ) -> Self {
135 {
137 let mut hook_count = GLOBAL_PANIC_HOOK_COUNT.lock();
138 if *hook_count == 0 {
139 let _ = std::panic::take_hook();
140 std::panic::set_hook(Box::new(|info| {
141 LOCAL_PANIC_HOOK.with(|hook| {
142 if let Some(hook) = &*hook.borrow() {
143 hook(info);
144 }
145 });
146 }));
147 }
148 *hook_count += 1;
149 drop(hook_count); }
151
152 let panic_message = Arc::new(Mutex::new(Vec::<u8>::new()));
153
154 let previous_panic_hook = LOCAL_PANIC_HOOK.with(|hook| {
155 let panic_message = panic_message.clone();
156 hook.borrow_mut().replace(Box::new(move |info| {
157 let backtrace = capture_backtrace();
158 panic_message.lock().extend(
159 format!(
160 "{}{}",
161 info,
162 backtrace
163 .map(|trace| format!("\n{}", trace))
164 .unwrap_or_default()
165 )
166 .into_bytes(),
167 );
168 }))
169 });
170
171 let result = std::panic::catch_unwind(func);
172
173 LOCAL_PANIC_HOOK.with(|hook| {
175 *hook.borrow_mut() = previous_panic_hook;
176 });
177
178 {
180 let mut hook_count = GLOBAL_PANIC_HOOK_COUNT.lock();
181 *hook_count -= 1;
182 if *hook_count == 0 {
183 let _ = std::panic::take_hook();
184 }
185 drop(hook_count); }
187
188 result.unwrap_or_else(|_| TestResult::Failed {
189 duration: None,
190 output: panic_message.lock().clone(),
191 })
192 }
193}
194
/// Captures a backtrace of the current thread as text, for inclusion in a failure report.
///
/// Returns `None` unless `Backtrace::capture` reports a `Captured` status. When the text
/// contains a line mentioning `core::panicking::panic_fmt`, everything up to and including
/// that line and the line after it (presumably that frame's source location) is dropped, so
/// the report starts at the code that panicked. Otherwise the full text is returned.
fn capture_backtrace() -> Option<String> {
    const PANIC_ENTRY_MARKER: &str = "core::panicking::panic_fmt";

    let backtrace = std::backtrace::Backtrace::capture();
    if backtrace.status() != std::backtrace::BacktraceStatus::Captured {
        return None;
    }
    let text = backtrace.to_string();

    let mut lines = text.lines();
    // `any` stops right after consuming the first marker line.
    if lines.by_ref().any(|line| line.contains(PANIC_ENTRY_MARKER)) {
        // Skip the line following the marker. Walking the iterator instead of slicing at
        // `position + 2` cannot go out of bounds when the marker is the final line, which
        // matters because this runs inside a panic hook.
        lines.next();
        return Some(lines.collect::<Vec<_>>().join("\n"));
    }
    Some(text)
}
211
212#[derive(Debug, Copy, Clone)]
213pub struct Parallelism(NonZeroUsize);
214
215impl Default for Parallelism {
216 fn default() -> Self {
217 Self::from_usize(if *NO_CAPTURE {
218 1
219 } else {
220 std::env::var("FILE_TEST_RUNNER_PARALLELISM")
221 .ok()
222 .and_then(|v| v.parse().ok())
223 .unwrap_or_else(|| {
224 std::thread::available_parallelism()
225 .map(|v| v.get())
226 .unwrap_or(2)
227 - 1
228 })
229 })
230 }
231}
232
233impl Parallelism {
234 pub fn from_bool(value: bool) -> Self {
235 if value {
236 Default::default()
237 } else {
238 Self::from_usize(1)
239 }
240 }
241
242 pub fn from_usize(value: usize) -> Self {
243 Self(NonZeroUsize::new(value).unwrap_or(NonZeroUsize::new(1).unwrap()))
244 }
245
246 pub fn get(&self) -> usize {
247 self.0.get()
248 }
249}
250
251#[derive(Clone)]
252pub struct RunOptions<TData> {
253 pub parallelism: Parallelism,
254 pub reporter: Arc<dyn Reporter<TData>>,
255}
256
257impl<TData> Default for RunOptions<TData> {
258 fn default() -> Self {
259 Self {
260 parallelism: Default::default(),
261 reporter: Arc::new(LogReporter::default()),
262 }
263 }
264}
265
/// Totals for a finished test run.
pub struct TestRunSummary {
    /// Number of tests that failed.
    pub failure_count: usize,
    /// Number of tests in the run.
    pub tests_count: usize,
}

impl TestRunSummary {
    /// Panics with `"<failures> failed of <total>"` when at least one test failed.
    pub fn panic_on_failures(&self) {
        assert!(
            self.failure_count == 0,
            "{} failed of {}",
            self.failure_count,
            self.tests_count
        );
    }
}
280
281pub fn run_tests<TData: Clone + Send + 'static>(
283 category: &CollectedTestCategory<TData>,
284 options: RunOptions<TData>,
285 run_test: impl (Fn(&CollectedTest<TData>) -> TestResult) + Send + Sync + 'static,
286) {
287 run_tests_summary(category, options, run_test).panic_on_failures();
288}
289
290pub fn run_tests_summary<TData: Clone + Send + 'static>(
292 category: &CollectedTestCategory<TData>,
293 options: RunOptions<TData>,
294 run_test: impl (Fn(&CollectedTest<TData>) -> TestResult) + Send + Sync + 'static,
295) -> TestRunSummary {
296 let total_tests = category.test_count();
297 if total_tests == 0 {
298 return TestRunSummary {
299 failure_count: 0,
300 tests_count: 0,
301 };
302 }
303
304 let run_test = Arc::new(run_test);
305
306 let pool = rayon::ThreadPoolBuilder::new()
308 .num_threads(options.parallelism.get() + 1)
310 .build()
311 .expect("Failed to create thread pool");
312
313 let mut context = Context {
314 failures: Vec::new(),
315 run_test,
316 parallelism: options.parallelism,
317 reporter: options.reporter,
318 pool,
319 };
320 run_category(category, &mut context);
321
322 context
323 .reporter
324 .report_failures(&context.failures, total_tests);
325
326 TestRunSummary {
327 failure_count: context.failures.len(),
328 tests_count: total_tests,
329 }
330}
331
332fn run_category<TData: Clone + Send>(
333 category: &CollectedTestCategory<TData>,
334 context: &mut Context<TData>,
335) {
336 let mut tests = Vec::new();
337 let mut categories = Vec::new();
338 for child in &category.children {
339 match child {
340 CollectedCategoryOrTest::Category(c) => {
341 categories.push(c);
342 }
343 CollectedCategoryOrTest::Test(t) => {
344 tests.push(t.clone());
345 }
346 }
347 }
348
349 if !tests.is_empty() {
350 run_tests_for_category(category, tests, context);
351 }
352
353 for category in categories {
354 run_category(category, context);
355 }
356}
357
358fn run_tests_for_category<TData: Clone + Send>(
359 category: &CollectedTestCategory<TData>,
360 tests: Vec<CollectedTest<TData>>,
361 context: &mut Context<TData>,
362) {
363 enum SendMessage<TData> {
364 Start {
365 test: CollectedTest<TData>,
366 },
367 Result {
368 test: CollectedTest<TData>,
369 duration: Duration,
370 result: TestResult,
371 },
372 }
373
374 if tests.is_empty() {
375 return; }
377
378 let reporter = &context.reporter;
379 let max_parallelism = context.parallelism.get();
380 let reporter_context = ReporterContext {
381 is_parallel: max_parallelism > 1,
382 };
383 reporter.report_category_start(category, &reporter_context);
384
385 let receive_receiver = {
386 let (receiver_sender, receive_receiver) =
387 crossbeam_channel::unbounded::<SendMessage<TData>>();
388 let (send_sender, send_receiver) =
389 crossbeam_channel::bounded::<CollectedTest<TData>>(max_parallelism);
390 for _ in 0..max_parallelism {
391 let send_receiver = send_receiver.clone();
392 let sender = receiver_sender.clone();
393 let run_test = context.run_test.clone();
394 context.pool.spawn(move || {
395 let run_test = &run_test;
396 while let Ok(test) = send_receiver.recv() {
397 let start = Instant::now();
398 _ = sender.send(SendMessage::Start { test: test.clone() });
401 let result = (run_test)(&test);
402 if sender
403 .send(SendMessage::Result {
404 test,
405 duration: start.elapsed(),
406 result,
407 })
408 .is_err()
409 {
410 return;
411 }
412 }
413 });
414 }
415
416 context.pool.spawn(move || {
417 for test in tests {
418 if send_sender.send(test).is_err() {
419 return; }
421 }
422 });
423
424 receive_receiver
425 };
426
427 while let Ok(message) = receive_receiver.recv() {
428 match message {
429 SendMessage::Start { test } => {
430 reporter.report_test_start(&test, &reporter_context)
431 }
432 SendMessage::Result {
433 test,
434 duration,
435 result,
436 } => {
437 reporter.report_test_end(&test, duration, &result, &reporter_context);
438 let is_failure = result.is_failed();
439 let failure_output = collect_failure_output(result);
440 if is_failure {
441 context.failures.push(ReporterFailure {
442 test,
443 output: failure_output,
444 });
445 }
446 }
447 }
448 }
449
450 reporter.report_category_end(category, &reporter_context);
451}
452
453fn collect_failure_output(result: TestResult) -> Vec<u8> {
454 fn output_sub_tests(
455 sub_tests: &[SubTestResult],
456 failure_output: &mut Vec<u8>,
457 ) {
458 for sub_test in sub_tests {
459 match &sub_test.result {
460 TestResult::Passed { .. } | TestResult::Ignored => {}
461 TestResult::Failed { output, .. } => {
462 if !failure_output.is_empty() {
463 failure_output.push(b'\n');
464 }
465 failure_output.extend(output);
466 }
467 TestResult::SubTests { sub_tests, .. } => {
468 if !sub_tests.is_empty() {
469 output_sub_tests(sub_tests, failure_output);
470 }
471 }
472 }
473 }
474 }
475
476 let mut failure_output = Vec::new();
477 match result {
478 TestResult::Passed { .. } | TestResult::Ignored => {}
479 TestResult::Failed { output, .. } => {
480 failure_output = output;
481 }
482 TestResult::SubTests { sub_tests, .. } => {
483 output_sub_tests(&sub_tests, &mut failure_output);
484 }
485 }
486
487 failure_output
488}
489
#[cfg(test)]
mod test {
    use super::*;

    /// A passing result without a duration.
    fn passed() -> TestResult {
        TestResult::Passed { duration: None }
    }

    /// A failing result without a duration carrying `output`.
    fn failed(output: &[u8]) -> TestResult {
        TestResult::Failed {
            duration: None,
            output: output.to_vec(),
        }
    }

    /// A sub test entry called `name` with the given result.
    fn step(name: &str, result: TestResult) -> SubTestResult {
        SubTestResult {
            name: name.to_string(),
            result,
        }
    }

    #[test]
    fn test_collect_failure_output_failed() {
        let failure_output = collect_failure_output(failed(b"error"));
        assert_eq!(failure_output, b"error");
    }

    #[test]
    fn test_collect_failure_output_sub_tests() {
        let nested = TestResult::SubTests {
            duration: None,
            sub_tests: vec![
                step("sub-step1", passed()),
                step("sub-step2", failed(b"error3")),
            ],
        };
        let result = TestResult::SubTests {
            duration: None,
            sub_tests: vec![
                step("step1", passed()),
                step("step2", failed(b"error1")),
                step("step3", failed(b"error2")),
                step("step4", nested),
            ],
        };

        let failure_output = collect_failure_output(result);

        assert_eq!(
            String::from_utf8(failure_output).unwrap(),
            "error1\nerror2\nerror3"
        );
    }
}