sweet_test/native/
test_runner_rayon.rs1use crate::prelude::*;
2use flume::Sender;
3use rayon::iter::IntoParallelIterator;
4use rayon::iter::ParallelIterator;
5use std::cell::Cell;
6use std::sync::Arc;
7use test::TestDesc;
8use thread_local::ThreadLocal;
9
10
11pub fn rayon_with_num_threads(
12 test_threads: Option<usize>,
13) -> Result<rayon::ThreadPool> {
14 let mut local_pool = rayon::ThreadPoolBuilder::new();
15 if let Some(test_threads) = test_threads {
16 local_pool = local_pool.num_threads(test_threads);
17 }
18 let pool = local_pool.build()?;
19 Ok(pool)
20}
21
/// `TestRunner` implementation that executes tests concurrently on a
/// dedicated rayon thread pool (see `rayon_with_num_threads`).
pub struct TestRunnerRayon;
26
27
28impl TestRunner for TestRunnerRayon {
29 fn run(
30 config: &TestRunnerConfig,
31 future_tx: Sender<TestDescAndFuture>,
32 result_tx: Sender<TestDescAndResult>,
33 tests: Vec<test::TestDescAndFn>,
34 ) -> Result<()> {
35 let local_pool = rayon_with_num_threads(config.test_threads)?;
39
40 let tls_desc = Arc::new(ThreadLocal::<Cell<Option<TestDesc>>>::new());
41 let default_hook = std::panic::take_hook();
42
43 let tls_desc2 = tls_desc.clone();
44 let result_tx2 = result_tx.clone();
45 std::panic::set_hook(Box::new(move |info| {
46
47 if let Some(desc) = tls_desc2.get() {
48 if let Some(desc) = desc.take() {
49 let result = TestResult::from_panic(info, &desc);
50 if let Err(err) =
51 result_tx2.send(TestDescAndResult::new(desc, result))
52 {
53 eprintln!("failed to register panic: {}", err);
54 }
55 } else {
56 eprintln!("malformed thread local test description");
57 }
58 } else {
59 default_hook(info);
60 }
61 }));
62
63 let _results = local_pool
64 .install(|| {
65 tests.into_par_iter().map_with(
66 tls_desc.clone(),
67 |desc_cell, test| {
68 let tls_desc_cell =
69 desc_cell.get_or(|| Default::default());
70 tls_desc_cell.set(Some(test.desc.clone()));
71
72 let func = TestDescAndFnExt::func(&test);
73 let result =
74 SweetTestCollector::with_scope(&test.desc, || {
75 std::panic::catch_unwind(func)
76 });
77 match result {
78 Ok(Ok(result)) => {
79 result_tx
80 .send(TestDescAndResult::new(
81 test.desc.clone(),
82 TestResult::from_test_result(
83 result, &test.desc,
84 ),
85 ))
86 .expect("channel was dropped");
87 }
89 Ok(Err(_payload)) => {
90 }
92 Err(fut) => {
93 future_tx.send(fut).unwrap();
94 }
95 };
96 let cell = desc_cell.get_or(|| Default::default());
97 cell.set(None);
98 },
100 )
101 })
102 .collect::<Vec<_>>();
103
104
105 let _hook = std::panic::take_hook();
106
107 Ok(())
108 }
109}