1use std::{fmt::Arguments, io::Write, sync::atomic::{AtomicBool, Ordering}};
2
3use atomicbox::AtomicOptionBox;
4
5use crate::{seq::VectorFn, unstable_sealed::UnstableSealed};
6
7pub fn no_error<T>(error: !) -> T {
17 error
18}
19
20pub trait ComputationController: Clone + UnstableSealed {
72
73 type Abort: Send;
74
75 #[stability::unstable(feature = "enable")]
80 fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> {
81 Ok(())
82 }
83
84 #[stability::unstable(feature = "enable")]
95 fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
96 where F: FnOnce(Self) -> T
97 {
98 computation(self)
99 }
100
101 #[stability::unstable(feature = "enable")]
102 fn log(&self, _description: Arguments) {}
103
104 #[stability::unstable(feature = "enable")]
111 fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
112 where
113 A: FnOnce(Self) -> RA + Send,
114 B: FnOnce(Self) -> RB + Send,
115 RA: Send,
116 RB: Send
117 {
118 (oper_a(self.clone()), oper_b(self.clone()))
119 }
120}
121
/// Reason why a task of a short-circuiting computation stopped early.
pub enum ShortCircuitingComputationAbort<E> {
    /// Some task of the shared computation has already stored a result or an abort
    /// value, so the remaining work is no longer needed.
    Finished,
    /// The controller's own `checkpoint` requested an abort with this value.
    Abort(E),
}
133
134pub struct ShortCircuitingComputation<T, Controller>
138 where T: Send,
139 Controller: ComputationController
140{
141 finished: AtomicBool,
142 abort: AtomicOptionBox<Controller::Abort>,
143 result: AtomicOptionBox<T>,
144}
145
146pub struct ShortCircuitingComputationHandle<'a, T, Controller>
150 where T: Send,
151 Controller: ComputationController
152{
153 controller: Controller,
154 executor: &'a ShortCircuitingComputation<T, Controller>
155}
156
157impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
158 where T: Send,
159 Controller: ComputationController
160{
161 fn clone(&self) -> Self {
162 Self {
163 controller: self.controller.clone(),
164 executor: self.executor
165 }
166 }
167}
168
169impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
170 where T: Send,
171 Controller: ComputationController
172{
173 #[stability::unstable(feature = "enable")]
174 pub fn controller(&self) -> &Controller {
175 &self.controller
176 }
177
178 #[stability::unstable(feature = "enable")]
179 pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> {
180 if self.executor.finished.load(Ordering::Relaxed) {
181 return Err(ShortCircuitingComputationAbort::Finished);
182 } else if let Err(e) = self.controller.checkpoint(description) {
183 return Err(ShortCircuitingComputationAbort::Abort(e));
184 } else {
185 return Ok(());
186 }
187 }
188
189 #[stability::unstable(feature = "enable")]
190 pub fn log(&self, description: Arguments) {
191 self.controller.log(description)
192 }
193
194 #[stability::unstable(feature = "enable")]
195 pub fn join_many<V, F>(self, operations: V)
196 where V: VectorFn<F> + Sync,
197 F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
198 {
199 fn join_many_internal<'a, T, V, F, Controller>(controller: Controller, executor: &'a ShortCircuitingComputation<T, Controller>, tasks: &V, from: usize, to: usize, batch_tasks: usize)
200 where T: Send,
201 Controller: ComputationController,
202 V: VectorFn<F> + Sync,
203 F: FnOnce(ShortCircuitingComputationHandle<'a, T, Controller>) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
204 {
205 if executor.finished.load(Ordering::Relaxed) {
206 return;
207 } else if from == to {
208 return;
209 } else if from + batch_tasks >= to {
210 for i in from..to {
211 match tasks.at(i)(ShortCircuitingComputationHandle {
212 controller: controller.clone(),
213 executor: executor
214 }) {
215 Ok(Some(result)) => {
216 executor.finished.store(true, Ordering::Relaxed);
217 executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
218 },
219 Err(ShortCircuitingComputationAbort::Abort(abort)) => {
220 executor.finished.store(true, Ordering::Relaxed);
221 executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
222 },
223 Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
224 }
225 }
226 } else {
227 let mid = (from + to) / 2;
228 controller.join(move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks), move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks));
229 }
230 }
231 join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
232 }
233
234 #[stability::unstable(feature = "enable")]
235 pub fn join<A, B>(self, oper_a: A, oper_b: B)
236 where
237 A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
238 B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send
239 {
240 let success_fn = |value: T| {
241 self.executor.finished.store(true, Ordering::Relaxed);
242 self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
243 };
244 let abort_fn = |abort: Controller::Abort| {
245 self.executor.finished.store(true, Ordering::Relaxed);
246 self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
247 };
248 self.controller.join(
249 |controller| {
250 if self.executor.finished.load(Ordering::Relaxed) {
251 return;
252 }
253 match oper_a(ShortCircuitingComputationHandle {
254 controller,
255 executor: self.executor
256 }) {
257 Ok(Some(result)) => success_fn(result),
258 Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
259 Err(ShortCircuitingComputationAbort::Finished) => {},
260 Ok(None) => {}
261 }
262 },
263 |controller| {
264 if self.executor.finished.load(Ordering::Relaxed) {
265 return;
266 }
267 match oper_b(ShortCircuitingComputationHandle {
268 controller,
269 executor: self.executor
270 }) {
271 Ok(Some(result)) => success_fn(result),
272 Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
273 Err(ShortCircuitingComputationAbort::Finished) => {},
274 Ok(None) => {}
275 }
276 }
277 );
278 }
279}
280
281impl<T, Controller> ShortCircuitingComputation<T, Controller>
282 where T: Send,
283 Controller: ComputationController
284{
285 #[stability::unstable(feature = "enable")]
286 pub fn new() -> Self {
287 Self {
288 finished: AtomicBool::new(false),
289 abort: AtomicOptionBox::none(),
290 result: AtomicOptionBox::none()
291 }
292 }
293
294 #[stability::unstable(feature = "enable")]
295 pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
296 ShortCircuitingComputationHandle {
297 controller: controller,
298 executor: self
299 }
300 }
301
302 #[stability::unstable(feature = "enable")]
303 pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
304 if let Some(abort) = self.abort.swap(None, Ordering::AcqRel) {
305 return Err(*abort);
306 } else if let Some(result) = self.result.swap(None, Ordering::AcqRel) {
307 return Ok(Some(*result));
308 } else {
309 return Ok(None);
310 }
311 }
312}
313
/// Reports a checkpoint to a controller (or handle) and propagates an abort with `?`.
///
/// `checkpoint!(c)` passes an empty description; `checkpoint!(c, "fmt", args...)` passes
/// `format_args!("fmt", args...)`. The enclosing function must return a `Result` whose
/// error type can be converted from the error of `c.checkpoint(..)`.
#[macro_export]
macro_rules! checkpoint {
    ($ctrl:expr) => {
        ($ctrl).checkpoint(std::format_args!(""))?
    };
    ($ctrl:expr, $($fmt:tt)*) => {
        ($ctrl).checkpoint(std::format_args!($($fmt)*))?
    };
}
323
/// Sends a formatted progress message to a controller (or handle):
/// `log_progress!(c, "fmt", args...)` calls `c.log(format_args!("fmt", args...))`.
#[macro_export]
macro_rules! log_progress {
    ($ctrl:expr, $($fmt:tt)*) => {
        ($ctrl).log(std::format_args!($($fmt)*))
    };
}
330
/// Controller that writes progress messages to stdout and never aborts.
///
/// `inner_comp` is `true` for the controller handed to a nested computation by
/// `run_computation`; it only decides whether a newline follows the closing "done.".
#[derive(Debug, Clone, Copy)]
pub struct LogProgress {
    inner_comp: bool
}

/// Progress-logging controller for a top-level computation.
pub const LOG_PROGRESS: LogProgress = LogProgress { inner_comp: false };

/// Progress-logging controller for this crate's tests; the same value as `LOG_PROGRESS`.
#[cfg(test)]
pub(crate) const TEST_LOG_PROGRESS: LogProgress = LOG_PROGRESS;
344
345impl UnstableSealed for LogProgress {}
346
347impl ComputationController for LogProgress {
348
349 type Abort = !;
350
351 #[stability::unstable(feature = "enable")]
352 fn log(&self, description: Arguments) {
353 print!("{}", description);
354 std::io::stdout().flush().unwrap();
355 }
356
357 #[stability::unstable(feature = "enable")]
358 fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
359 where F: FnOnce(Self) -> T
360 {
361 self.log(description);
362 let result = computation(Self { inner_comp: true });
363 if self.inner_comp {
364 self.log(format_args!("done."));
365 } else {
366 self.log(format_args!("done.\n"));
367 }
368 return result;
369 }
370
371 #[stability::unstable(feature = "enable")]
372 fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
373 self.log(description);
374 Ok(())
375 }
376}
377
/// Controller that neither logs nor aborts; every operation uses the trait defaults.
#[derive(Debug, Copy, Clone)]
pub struct DontObserve;
380
381impl UnstableSealed for DontObserve {}
382
383impl ComputationController for DontObserve {
384
385 type Abort = !;
386}
387
#[cfg(feature = "parallel")]
mod parallel_controller {

    use super::*;

    /// Controller that runs the two operations of every `join` via `rayon::join` and
    /// forwards all other operations (checkpoints, logging, `run_computation`) to `rest`.
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self { rest: self.rest.clone() }
        }
    }

    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        /// Forwards to `rest`.
        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
            self.rest.checkpoint(description)
        }

        /// Forwards to `rest`. Without this override the trait's default `log`, which
        /// discards the message, would be used, and e.g. `RunMultithreadedLogProgress`
        /// would silently drop all progress messages.
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) {
            self.rest.log(description)
        }

        /// Delegates to `rest.run_computation`, re-wrapping the controller that `rest`
        /// hands to the computation so nested joins stay multithreaded.
        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
            where F: FnOnce(Self) -> T
        {
            self.rest.run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        /// Runs both operations via `rayon::join`, each with its own copy of this controller.
        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
            where
                A: FnOnce(Self) -> RA + Send,
                B: FnOnce(Self) -> RB + Send,
                RA: Send,
                RB: Send
        {
            let self1 = self.clone();
            let self2 = self;
            rayon::join(|| oper_a(self1), || oper_b(self2))
        }
    }

    /// Multithreaded controller that logs progress to stdout.
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> = ExecuteMultithreaded { rest: LOG_PROGRESS };
    /// Multithreaded controller that neither logs nor aborts.
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)]
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}
444
445#[cfg(feature = "parallel")]
446pub use parallel_controller::*;