sweet_test/native/
test_runner_rayon.rs

1use crate::prelude::*;
2use flume::Sender;
3use rayon::iter::IntoParallelIterator;
4use rayon::iter::ParallelIterator;
5use std::cell::Cell;
6use std::sync::Arc;
7use test::TestDesc;
8use thread_local::ThreadLocal;
9
10
11pub fn rayon_with_num_threads(
12	test_threads: Option<usize>,
13) -> Result<rayon::ThreadPool> {
14	let mut local_pool = rayon::ThreadPoolBuilder::new();
15	if let Some(test_threads) = test_threads {
16		local_pool = local_pool.num_threads(test_threads);
17	}
18	let pool = local_pool.build()?;
19	Ok(pool)
20}
21
/// Best runner for mostly native sync tests.
///
/// Tests are executed in parallel using rayon. Async tests are not driven by
/// this runner: the future produced for such a test is sent on the
/// `future_tx` channel passed to `TestRunner::run`, to be awaited by whoever
/// receives from that channel.
pub struct TestRunnerRayon;
26
27
28impl TestRunner for TestRunnerRayon {
29	fn run(
30		config: &TestRunnerConfig,
31		future_tx: Sender<TestDescAndFuture>,
32		result_tx: Sender<TestDescAndResult>,
33		tests: Vec<test::TestDescAndFn>,
34	) -> Result<()> {
35		// let tls = Arc::new(ThreadLocal::new());
36
37
38		let local_pool = rayon_with_num_threads(config.test_threads)?;
39
40		let tls_desc = Arc::new(ThreadLocal::<Cell<Option<TestDesc>>>::new());
41		let default_hook = std::panic::take_hook();
42
43		let tls_desc2 = tls_desc.clone();
44		let result_tx2 = result_tx.clone();
45		std::panic::set_hook(Box::new(move |info| {
46
47			if let Some(desc) = tls_desc2.get() {
48				if let Some(desc) = desc.take() {
49					let result = TestResult::from_panic(info, &desc);
50					if let Err(err) =
51						result_tx2.send(TestDescAndResult::new(desc, result))
52					{
53						eprintln!("failed to register panic: {}", err);
54					}
55				} else {
56					eprintln!("malformed thread local test description");
57				}
58			} else {
59				default_hook(info);
60			}
61		}));
62
63		let _results = local_pool
64			.install(|| {
65				tests.into_par_iter().map_with(
66					tls_desc.clone(),
67					|desc_cell, test| {
68						let tls_desc_cell =
69							desc_cell.get_or(|| Default::default());
70						tls_desc_cell.set(Some(test.desc.clone()));
71
72						let func = TestDescAndFnExt::func(&test);
73						let result =
74							SweetTestCollector::with_scope(&test.desc, || {
75								std::panic::catch_unwind(func)
76							});
77						match result {
78							Ok(Ok(result)) => {
79								result_tx
80									.send(TestDescAndResult::new(
81										test.desc.clone(),
82										TestResult::from_test_result(
83											result, &test.desc,
84										),
85									))
86									.expect("channel was dropped");
87								// None
88							}
89							Ok(Err(_payload)) => {
90								// panic result was sent in the hook
91							}
92							Err(fut) => {
93								future_tx.send(fut).unwrap();
94							}
95						};
96						let cell = desc_cell.get_or(|| Default::default());
97						cell.set(None);
98						// return output;
99					},
100				)
101			})
102			.collect::<Vec<_>>();
103
104
105		let _hook = std::panic::take_hook();
106
107		Ok(())
108	}
109}