1use std::fmt::Arguments;
2use std::io::Write;
3use std::sync::atomic::{AtomicBool, Ordering};
4
5use atomicbox::AtomicOptionBox;
6
7use crate::seq::VectorFn;
8use crate::unstable_sealed::UnstableSealed;
9
10pub fn no_error<T>(error: !) -> T { error }
18
19pub trait ComputationController: Clone + UnstableSealed {
69 type Abort: Send;
70
71 #[stability::unstable(feature = "enable")]
74 fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> { Ok(()) }
75
76 #[stability::unstable(feature = "enable")]
85 fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
86 where
87 F: FnOnce(Self) -> T,
88 {
89 computation(self)
90 }
91
92 #[stability::unstable(feature = "enable")]
93 fn log(&self, _description: Arguments) {}
94
95 #[stability::unstable(feature = "enable")]
99 fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
100 where
101 A: FnOnce(Self) -> RA + Send,
102 B: FnOnce(Self) -> RB + Send,
103 RA: Send,
104 RB: Send,
105 {
106 (oper_a(self.clone()), oper_b(self.clone()))
107 }
108}
109
/// Reason why a task of a short-circuiting computation stopped without producing a result.
#[derive(Debug)]
pub enum ShortCircuitingComputationAbort<E> {
    /// Another task has already finished the shared computation (by producing a result or
    /// aborting), so the work of this task is no longer needed.
    Finished,
    /// The task's controller aborted the computation with the given value.
    Abort(E),
}
119
120pub struct ShortCircuitingComputation<T, Controller>
122where
123 T: Send,
124 Controller: ComputationController,
125{
126 finished: AtomicBool,
127 abort: AtomicOptionBox<Controller::Abort>,
128 result: AtomicOptionBox<T>,
129}
130
131pub struct ShortCircuitingComputationHandle<'a, T, Controller>
133where
134 T: Send,
135 Controller: ComputationController,
136{
137 controller: Controller,
138 executor: &'a ShortCircuitingComputation<T, Controller>,
139}
140
141impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
142where
143 T: Send,
144 Controller: ComputationController,
145{
146 fn clone(&self) -> Self {
147 Self {
148 controller: self.controller.clone(),
149 executor: self.executor,
150 }
151 }
152}
153
154impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
155where
156 T: Send,
157 Controller: ComputationController,
158{
159 #[stability::unstable(feature = "enable")]
160 pub fn controller(&self) -> &Controller { &self.controller }
161
162 #[stability::unstable(feature = "enable")]
163 pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> {
164 if self.executor.finished.load(Ordering::Relaxed) {
165 return Err(ShortCircuitingComputationAbort::Finished);
166 } else if let Err(e) = self.controller.checkpoint(description) {
167 return Err(ShortCircuitingComputationAbort::Abort(e));
168 } else {
169 return Ok(());
170 }
171 }
172
173 #[stability::unstable(feature = "enable")]
174 pub fn log(&self, description: Arguments) { self.controller.log(description) }
175
176 #[stability::unstable(feature = "enable")]
177 pub fn join_many<V, F>(self, operations: V)
178 where
179 V: VectorFn<F> + Sync,
180 F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>,
181 {
182 fn join_many_internal<'a, T, V, F, Controller>(
183 controller: Controller,
184 executor: &'a ShortCircuitingComputation<T, Controller>,
185 tasks: &V,
186 from: usize,
187 to: usize,
188 batch_tasks: usize,
189 ) where
190 T: Send,
191 Controller: ComputationController,
192 V: VectorFn<F> + Sync,
193 F: FnOnce(
194 ShortCircuitingComputationHandle<'a, T, Controller>,
195 ) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>,
196 {
197 if executor.finished.load(Ordering::Relaxed) {
198 return;
199 } else if from == to {
200 return;
201 } else if from + batch_tasks >= to {
202 for i in from..to {
203 match tasks.at(i)(ShortCircuitingComputationHandle {
204 controller: controller.clone(),
205 executor,
206 }) {
207 Ok(Some(result)) => {
208 executor.finished.store(true, Ordering::Relaxed);
209 executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
210 }
211 Err(ShortCircuitingComputationAbort::Abort(abort)) => {
212 executor.finished.store(true, Ordering::Relaxed);
213 executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
214 }
215 Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
216 }
217 }
218 } else {
219 let mid = (from + to) / 2;
220 _ = controller.join(
221 move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks),
222 move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks),
223 );
224 }
225 }
226 join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
227 }
228
229 #[stability::unstable(feature = "enable")]
230 pub fn join<A, B>(self, oper_a: A, oper_b: B)
231 where
232 A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
233 B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
234 {
235 let success_fn = |value: T| {
236 self.executor.finished.store(true, Ordering::Relaxed);
237 self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
238 };
239 let abort_fn = |abort: Controller::Abort| {
240 self.executor.finished.store(true, Ordering::Relaxed);
241 self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
242 };
243 _ = self.controller.join(
244 |controller| {
245 if self.executor.finished.load(Ordering::Relaxed) {
246 return;
247 }
248 match oper_a(ShortCircuitingComputationHandle {
249 controller,
250 executor: self.executor,
251 }) {
252 Ok(Some(result)) => success_fn(result),
253 Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
254 Err(ShortCircuitingComputationAbort::Finished) => {}
255 Ok(None) => {}
256 }
257 },
258 |controller| {
259 if self.executor.finished.load(Ordering::Relaxed) {
260 return;
261 }
262 match oper_b(ShortCircuitingComputationHandle {
263 controller,
264 executor: self.executor,
265 }) {
266 Ok(Some(result)) => success_fn(result),
267 Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
268 Err(ShortCircuitingComputationAbort::Finished) => {}
269 Ok(None) => {}
270 }
271 },
272 );
273 }
274}
275
276impl<T, Controller> ShortCircuitingComputation<T, Controller>
277where
278 T: Send,
279 Controller: ComputationController,
280{
281 #[stability::unstable(feature = "enable")]
282 pub fn new() -> Self {
283 Self {
284 finished: AtomicBool::new(false),
285 abort: AtomicOptionBox::none(),
286 result: AtomicOptionBox::none(),
287 }
288 }
289
290 #[stability::unstable(feature = "enable")]
291 pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
292 ShortCircuitingComputationHandle {
293 controller,
294 executor: self,
295 }
296 }
297
298 #[stability::unstable(feature = "enable")]
299 pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
300 if let Some(abort) = self.abort.swap(None, Ordering::AcqRel) {
301 return Err(*abort);
302 } else if let Some(result) = self.result.swap(None, Ordering::AcqRel) {
303 return Ok(Some(*result));
304 } else {
305 return Ok(None);
306 }
307 }
308}
309
/// Calls `checkpoint` on the given controller (or handle) and propagates a failure with `?`.
///
/// `checkpoint!(controller)` passes an empty description. `checkpoint!(controller, "fmt", args...)`
/// formats the description like `format_args!`. The macro must be used inside a function whose
/// error type the checkpoint's error converts into via `?`.
///
/// `::core::format_args!` is referenced by absolute path, so the expansion does not depend on
/// what the name `std` resolves to at the call site.
#[macro_export]
macro_rules! checkpoint {
    ($controller:expr $(,)?) => {
        ($controller).checkpoint(::core::format_args!(""))?
    };
    ($controller:expr, $($args:tt)+) => {
        ($controller).checkpoint(::core::format_args!($($args)+))?
    };
}
319
/// Reports a progress message through the given controller's (or handle's) `log` method.
///
/// The arguments after the controller are passed to `format_args!`. The macro refers to it
/// by the absolute path `::core::format_args!`, so the expansion does not depend on what
/// `std` resolves to at the call site.
#[macro_export]
macro_rules! log_progress {
    ($controller:expr, $($args:tt)+) => {
        ($controller).log(::core::format_args!($($args)+))
    };
}
326
/// A [`ComputationController`] that prints progress messages to stdout.
///
/// `inner_comp` records whether this controller runs inside a computation started via
/// `run_computation`. It decides whether the closing "done." message also ends the line.
#[derive(Clone, Copy, Debug)]
pub struct LogProgress {
    inner_comp: bool,
}

/// The [`LogProgress`] controller for a top-level (non-nested) computation.
pub const LOG_PROGRESS: LogProgress = LogProgress { inner_comp: false };

/// Controller for use in this crate's tests; currently the same value as [`LOG_PROGRESS`].
#[cfg(test)]
pub(crate) const TEST_LOG_PROGRESS: LogProgress = LOG_PROGRESS;
338
339impl UnstableSealed for LogProgress {}
340
341impl ComputationController for LogProgress {
342 type Abort = !;
343
344 #[stability::unstable(feature = "enable")]
345 fn log(&self, description: Arguments) {
346 print!("{}", description);
347 std::io::stdout().flush().unwrap();
348 }
349
350 #[stability::unstable(feature = "enable")]
351 fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
352 where
353 F: FnOnce(Self) -> T,
354 {
355 self.log(description);
356 let result = computation(Self { inner_comp: true });
357 if self.inner_comp {
358 self.log(format_args!("done."));
359 } else {
360 self.log(format_args!("done.\n"));
361 }
362 return result;
363 }
364
365 #[stability::unstable(feature = "enable")]
366 fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
367 self.log(description);
368 Ok(())
369 }
370}
371
372#[derive(Clone, Copy, Debug)]
373pub struct DontObserve;
374
375impl UnstableSealed for DontObserve {}
376
377impl ComputationController for DontObserve {
378 type Abort = !;
379}
380
#[cfg(feature = "parallel")]
mod parallel_controller {

    use super::*;

    /// A [`ComputationController`] that executes the two operations of `join` in parallel via
    /// `rayon::join`, and forwards checkpoints, logging and sub-computations to the wrapped
    /// controller `rest`.
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest,
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self {
                rest: self.rest.clone(),
            }
        }
    }

    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        /// Forwards the checkpoint to the wrapped controller.
        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> { self.rest.checkpoint(description) }

        /// Forwards progress messages to the wrapped controller.
        ///
        /// Without this override, the trait's no-op default would silently drop every `log`
        /// message, e.g. `RunMultithreadedLogProgress` would print nothing via `log_progress!`.
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) { self.rest.log(description) }

        /// Runs the computation through the wrapped controller. The computation again
        /// receives an `ExecuteMultithreaded`, so nested joins also run in parallel.
        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
        where
            F: FnOnce(Self) -> T,
        {
            self.rest
                .run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        /// Runs `oper_a` and `oper_b`, potentially in parallel, using `rayon::join`.
        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
        where
            A: FnOnce(Self) -> RA + Send,
            B: FnOnce(Self) -> RB + Send,
            RA: Send,
            RB: Send,
        {
            let self1 = self.clone();
            let self2 = self;
            rayon::join(|| oper_a(self1), || oper_b(self2))
        }
    }

    /// Runs `join`s in parallel and prints progress like `LOG_PROGRESS`.
    #[stability::unstable(feature = "enable")]
    // Public name kept in CamelCase for API compatibility.
    #[allow(non_upper_case_globals)]
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> =
        ExecuteMultithreaded { rest: LOG_PROGRESS };

    /// Runs `join`s in parallel without reporting progress.
    #[stability::unstable(feature = "enable")]
    // Public name kept in CamelCase for API compatibility.
    #[allow(non_upper_case_globals)]
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}
440
441#[cfg(feature = "parallel")]
442pub use parallel_controller::*;