1#![cfg_attr(not(feature = "chili"), allow(unused_variables))]
2
3use std::{cell::RefCell, mem::transmute};
4
// Turning on `parallel` without picking one of the two backends is a misconfiguration.
#[cfg(all(feature = "parallel", not(any(feature = "chili", feature = "rayon"))))]
compile_error!("You must enable `chili` or `rayon` feature if you want to use `parallel` feature");

// The backends are mutually exclusive: only one may provide `join`.
#[cfg(all(feature = "rayon", feature = "chili"))]
compile_error!("You must enable `chili` or `rayon` feature, not both");
10
/// A scope that is either supplied explicitly or created on demand.
///
/// Build it from a [`Scope`] via `From`, or with `Default`. With the `chili`
/// feature, `Default` defers creating a global `chili::Scope` until
/// [`MaybeScope::with`] is first called. Use [`MaybeScope::with`] to obtain the
/// [`Scope`] that work runs in.
#[derive(Default)]
pub struct MaybeScope<'a>(ScopeLike<'a>);
13
/// Internal storage behind [`MaybeScope`].
enum ScopeLike<'a> {
    /// A scope handed in by the caller.
    Scope(Scope<'a>),
    /// Slot for a `chili::Scope::global()`. It is `None` until
    /// `MaybeScope::with` fills it on first use.
    #[cfg(feature = "chili")]
    Global(Option<chili::Scope<'a>>),
}
19
20impl Default for ScopeLike<'_> {
21 fn default() -> Self {
22 #[cfg(feature = "chili")]
23 {
24 ScopeLike::Global(None)
25 }
26
27 #[cfg(not(feature = "chili"))]
28 {
29 ScopeLike::Scope(Scope(std::marker::PhantomData))
30 }
31 }
32}
33
34impl<'a> From<Scope<'a>> for MaybeScope<'a> {
35 fn from(value: Scope<'a>) -> Self {
36 MaybeScope(ScopeLike::Scope(value))
37 }
38}
39
impl<'a> MaybeScope<'a> {
    /// Runs `f` with a [`Scope`] derived from this `MaybeScope` and returns its result.
    ///
    /// With the `chili` feature:
    /// - a `ScopeLike::Scope` is re-borrowed;
    /// - a `ScopeLike::Global` creates a `chili::Scope::global()` on the first
    ///   call and reuses it afterwards.
    ///
    /// Without `chili`, `f` receives a zero-sized `Scope` and `self` is unused.
    ///
    /// NOTE(review): the `transmute`s below only change lifetimes, stretching a
    /// borrow of `self` to `'a`. Because `R` is unconstrained, `f` could return
    /// the `Scope<'a>` it was given and keep it alive after this call, while
    /// `self` is borrowed mutably again. Confirm that callers never do this.
    // `|| chili::Scope::global()` below is kept as a closure; the lint is silenced for it.
    #[allow(clippy::redundant_closure)]
    pub fn with<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(Scope<'a>) -> R,
    {
        #[cfg(feature = "chili")]
        let scope: &mut chili::Scope = match &mut self.0 {
            ScopeLike::Scope(scope) => unsafe {
                // SAFETY (review): source and target types differ only in lifetimes.
                // The result is tied back to `'a` when wrapped in `Scope` below.
                transmute::<&mut chili::Scope, &mut chili::Scope>(&mut scope.0)
            },
            #[cfg(feature = "chili")]
            ScopeLike::Global(global_scope) => {
                // Create the global chili scope on first use, then keep reusing it.
                let scope = global_scope.get_or_insert_with(|| chili::Scope::global());

                unsafe {
                    // SAFETY (review): lifetime-only transmute of a borrow into `self`.
                    // It is sound only while `self` outlives every use of the result.
                    transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
                }
            }
        };

        #[cfg(feature = "chili")]
        let scope = Scope(scope);

        #[cfg(not(feature = "chili"))]
        let scope = Scope(std::marker::PhantomData);

        f(scope)
    }
}
75
/// Handle passed to the closures of [`join_scoped`] and [`join_maybe_scoped`].
///
/// Without the `chili` feature it is a zero-sized marker that carries only `'a`.
#[cfg(not(feature = "chili"))]
pub struct Scope<'a>(std::marker::PhantomData<&'a ()>);

/// Handle passed to the closures of [`join_scoped`] and [`join_maybe_scoped`].
///
/// With the `chili` feature it wraps a mutable borrow of a `chili::Scope`, on
/// which nested joins are performed.
#[cfg(feature = "chili")]
pub struct Scope<'a>(&'a mut chili::Scope<'a>);
81
82#[inline]
83pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
84where
85 A: Send + FnOnce() -> RA,
86 B: Send + FnOnce() -> RB,
87 RA: Send,
88 RB: Send,
89{
90 thread_local! {
91 static SCOPE: RefCell<Option<MaybeScope<'static>>> = Default::default();
92 }
93
94 struct RemoveScopeGuard;
95
96 impl Drop for RemoveScopeGuard {
97 fn drop(&mut self) {
98 SCOPE.set(None);
99 }
100 }
101
102 let mut scope = SCOPE.take().unwrap_or_default();
103
104 let (ra, rb) = join_maybe_scoped(
105 &mut scope,
106 |scope| {
107 let scope = unsafe {
108 transmute::<Scope, Scope>(scope)
110 };
111 let _guard = RemoveScopeGuard;
112 SCOPE.set(Some(MaybeScope(ScopeLike::Scope(scope))));
113
114 oper_a()
115 },
116 |scope| {
117 let scope = unsafe {
118 transmute::<Scope, Scope>(scope)
120 };
121 let _guard = RemoveScopeGuard;
122 SCOPE.set(Some(MaybeScope(ScopeLike::Scope(scope))));
123
124 oper_b()
125 },
126 );
127
128 SCOPE.set(Some(scope));
130
131 (ra, rb)
132}
133
134#[inline]
135pub fn join_maybe_scoped<'a, A, B, RA, RB>(
136 scope: &mut MaybeScope<'a>,
137 oper_a: A,
138 oper_b: B,
139) -> (RA, RB)
140where
141 A: Send + FnOnce(Scope<'a>) -> RA,
142 B: Send + FnOnce(Scope<'a>) -> RB,
143 RA: Send,
144 RB: Send,
145{
146 scope.with(|scope| join_scoped(scope, oper_a, oper_b))
147}
148
149#[inline]
150pub fn join_scoped<'a, A, B, RA, RB>(scope: Scope<'a>, oper_a: A, oper_b: B) -> (RA, RB)
151where
152 A: Send + FnOnce(Scope<'a>) -> RA,
153 B: Send + FnOnce(Scope<'a>) -> RB,
154 RA: Send,
155 RB: Send,
156{
157 #[cfg(feature = "chili")]
158 let (ra, rb) = scope.0.join(
159 |scope| {
160 let scope = Scope(unsafe {
161 transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
164 });
165
166 oper_a(scope)
167 },
168 |scope| {
169 let scope = Scope(unsafe {
170 transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
173 });
174
175 oper_b(scope)
176 },
177 );
178
179 #[cfg(feature = "rayon")]
180 let (ra, rb) = rayon::join(
181 || oper_a(Scope(std::marker::PhantomData)),
182 || oper_b(Scope(std::marker::PhantomData)),
183 );
184
185 #[cfg(not(feature = "parallel"))]
186 let (ra, rb) = (
187 oper_a(Scope(std::marker::PhantomData)),
188 oper_b(Scope(std::marker::PhantomData)),
189 );
190
191 (ra, rb)
192}