1#![cfg_attr(not(feature = "chili"), allow(unused_variables))]
55
56use std::{cell::RefCell, mem::transmute};
57
// Feature sanity checks: `parallel` needs a backend (`chili` or `rayon`),
// and the two backends are mutually exclusive.
#[cfg(all(feature = "parallel", not(any(feature = "chili", feature = "rayon"))))]
compile_error!("You must enable `chili` or `rayon` feature if you want to use `parallel` feature");

#[cfg(all(feature = "rayon", feature = "chili"))]
compile_error!("You must enable `chili` or `rayon` feature, not both");
63
/// Owner of a [`Scope`] that can be lent out through [`MaybeScope::with`].
///
/// It either wraps a caller-supplied [`Scope`] (see the `From<Scope>` impl)
/// or, with the `chili` feature, a global chili scope that is created lazily
/// on first use. `MaybeScope::default()` delegates to the feature-dependent
/// `Default` impl of the inner `ScopeLike`.
#[derive(Default)]
pub struct MaybeScope<'a>(ScopeLike<'a>);
66
/// Storage behind [`MaybeScope`].
enum ScopeLike<'a> {
    /// A scope supplied from outside, e.g. via `From<Scope>` or the one that
    /// `join` installs in its thread-local while an operation runs.
    Scope(Scope<'a>),
    /// The global chili scope; `None` until `MaybeScope::with` first needs it
    /// and initialises it in place.
    #[cfg(feature = "chili")]
    Global(Option<chili::Scope<'a>>),
}
72
73impl Default for ScopeLike<'_> {
74 fn default() -> Self {
75 #[cfg(feature = "chili")]
76 {
77 ScopeLike::Global(None)
78 }
79
80 #[cfg(not(feature = "chili"))]
81 {
82 ScopeLike::Scope(Scope(std::marker::PhantomData))
83 }
84 }
85}
86
87impl<'a> From<Scope<'a>> for MaybeScope<'a> {
88 fn from(value: Scope<'a>) -> Self {
89 MaybeScope(ScopeLike::Scope(value))
90 }
91}
92
impl<'a> MaybeScope<'a> {
    /// Calls `f` with a [`Scope`] derived from this `MaybeScope` and returns
    /// its result.
    ///
    /// * With `chili`: a wrapped `Scope` is reborrowed. For the `Global`
    ///   variant, `chili::Scope::global()` is created on the first call and
    ///   cached in `self` for later calls.
    /// * Without `chili`: `self` is not inspected and `f` receives a
    ///   zero-sized `Scope(PhantomData)`.
    // NOTE(review): the lint is silenced for `|| chili::Scope::global()` below;
    // presumably the closure is kept on purpose (e.g. for lifetime inference)
    // — confirm before simplifying it to a path.
    #[allow(clippy::redundant_closure)]
    pub fn with<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(Scope<'a>) -> R,
    {
        #[cfg(feature = "chili")]
        let scope: &mut chili::Scope = match &mut self.0 {
            ScopeLike::Scope(scope) => unsafe {
                // SAFETY: same-type transmute that only re-labels lifetimes, so
                // the reborrow of the stored `&'a mut chili::Scope<'a>` can be
                // handed to `f` as `Scope<'a>`. `self` stays mutably borrowed for
                // this whole call. NOTE(review): soundness also assumes the chili
                // scope really lives for `'a` and that `f` does not keep the
                // `Scope<'a>` past this call — confirm.
                transmute::<&mut chili::Scope, &mut chili::Scope>(&mut scope.0)
            },
            // (This arm-level cfg duplicates the statement-level one; harmless.)
            #[cfg(feature = "chili")]
            ScopeLike::Global(global_scope) => {
                // Create the global chili scope lazily on first use and cache it
                // in this `MaybeScope`.
                let scope = global_scope.get_or_insert_with(|| chili::Scope::global());

                unsafe {
                    // SAFETY: same-type transmute extending the borrow of the
                    // cached global scope to `'a`. NOTE(review): assumes this
                    // `MaybeScope`, which owns the cached scope, is neither
                    // dropped nor used again while the `Scope<'a>` given to `f`
                    // is alive — confirm.
                    transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
                }
            }
        };

        #[cfg(feature = "chili")]
        let scope = Scope(scope);

        // Without chili there is no real scope; hand out the zero-sized marker.
        #[cfg(not(feature = "chili"))]
        let scope = Scope(std::marker::PhantomData);

        f(scope)
    }
}
128
/// Scope handle passed to the operations run by [`join_scoped`].
///
/// With the `chili` feature it wraps a mutable reference to a `chili::Scope`
/// that lives for `'a`.
#[cfg(feature = "chili")]
pub struct Scope<'a>(&'a mut chili::Scope<'a>);

/// Scope handle passed to the operations run by [`join_scoped`].
///
/// Without the `chili` feature it is zero-sized. It keeps the same `'a`
/// parameter as the `chili` variant.
#[cfg(not(feature = "chili"))]
pub struct Scope<'a>(std::marker::PhantomData<&'a ()>);
134
/// Runs `oper_a` and `oper_b`, in parallel when a backend feature is enabled,
/// and returns `(oper_a's result, oper_b's result)`.
///
/// Unlike [`join_scoped`], no scope is passed in. A per-thread `SCOPE` slot
/// caches a [`MaybeScope`] between calls:
/// * On entry the slot is taken, falling back to `MaybeScope::default()` when
///   it is empty.
/// * While each operation runs, the slot on the thread executing it holds that
///   operation's [`Scope`], so a nested `join` on that thread picks it up.
/// * After both operations return, the value taken on entry is put back.
///
/// NOTE(review): `RemoveScopeGuard` resets the slot to `None` rather than
/// restoring whatever it held before the operation started.
///
/// NOTE(review): if an operation panics, the final `SCOPE.set(Some(scope))` is
/// skipped, so the cached scope is dropped and the next call on this thread
/// starts from the default again.
#[inline]
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: Send + FnOnce() -> RA,
    B: Send + FnOnce() -> RB,
    RA: Send,
    RB: Send,
{
    thread_local! {
        // `'static` because a thread-local cannot name a shorter lifetime; the
        // scopes stored here are lifetime-extended by the transmutes below.
        static SCOPE: RefCell<Option<MaybeScope<'static>>> = Default::default();
    }

    /// Clears this thread's `SCOPE` slot when dropped, including on unwind.
    struct RemoveScopeGuard;

    impl Drop for RemoveScopeGuard {
        fn drop(&mut self) {
            SCOPE.set(None);
        }
    }

    // Reuse the scope cached on this thread, if any.
    let mut scope = SCOPE.take().unwrap_or_default();

    let (ra, rb) = join_maybe_scoped(
        &mut scope,
        |scope| {
            let scope = unsafe {
                // SAFETY: only the lifetime changes (to `'static`) so the scope
                // can be stored in the `'static` thread-local. `_guard` clears the
                // slot when this closure returns or unwinds, so the slot does not
                // expose it afterwards. This assumes nothing moves it out of
                // `SCOPE` and keeps it: a nested `join` takes it and then either
                // puts it back or drops it before returning.
                transmute::<Scope, Scope>(scope)
            };
            let _guard = RemoveScopeGuard;
            SCOPE.set(Some(MaybeScope(ScopeLike::Scope(scope))));

            oper_a()
        },
        |scope| {
            let scope = unsafe {
                // SAFETY: same reasoning as for `oper_a` above; `SCOPE` here is
                // the slot of whichever thread runs this closure.
                transmute::<Scope, Scope>(scope)
            };
            let _guard = RemoveScopeGuard;
            SCOPE.set(Some(MaybeScope(ScopeLike::Scope(scope))));

            oper_b()
        },
    );

    // Put the scope taken on entry back for the next call on this thread.
    SCOPE.set(Some(scope));

    (ra, rb)
}
186
187#[inline]
188pub fn join_maybe_scoped<'a, A, B, RA, RB>(
189 scope: &mut MaybeScope<'a>,
190 oper_a: A,
191 oper_b: B,
192) -> (RA, RB)
193where
194 A: Send + FnOnce(Scope<'a>) -> RA,
195 B: Send + FnOnce(Scope<'a>) -> RB,
196 RA: Send,
197 RB: Send,
198{
199 scope.with(|scope| join_scoped(scope, oper_a, oper_b))
200}
201
202#[inline]
203pub fn join_scoped<'a, A, B, RA, RB>(scope: Scope<'a>, oper_a: A, oper_b: B) -> (RA, RB)
204where
205 A: Send + FnOnce(Scope<'a>) -> RA,
206 B: Send + FnOnce(Scope<'a>) -> RB,
207 RA: Send,
208 RB: Send,
209{
210 #[cfg(feature = "chili")]
211 let (ra, rb) = scope.0.join(
212 |scope| {
213 let scope = Scope(unsafe {
214 transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
217 });
218
219 oper_a(scope)
220 },
221 |scope| {
222 let scope = Scope(unsafe {
223 transmute::<&mut chili::Scope, &mut chili::Scope>(scope)
226 });
227
228 oper_b(scope)
229 },
230 );
231
232 #[cfg(feature = "rayon")]
233 let (ra, rb) = rayon::join(
234 || oper_a(Scope(std::marker::PhantomData)),
235 || oper_b(Scope(std::marker::PhantomData)),
236 );
237
238 #[cfg(not(feature = "parallel"))]
239 let (ra, rb) = (
240 oper_a(Scope(std::marker::PhantomData)),
241 oper_b(Scope(std::marker::PhantomData)),
242 );
243
244 (ra, rb)
245}