use std::{fmt::Arguments, io::Write, sync::atomic::{AtomicBool, Ordering}};
use atomicbox::AtomicOptionBox;
use crate::{seq::VectorFn, unstable_sealed::UnstableSealed};
/// Turns an impossible error into a value of any type.
///
/// Since `!` has no values, this function can never actually be called at runtime; it exists so
/// that e.g. `Result<T, !>` can be unwrapped via `unwrap_or_else(no_error)`.
pub fn no_error<T>(error: !) -> T {
    // An empty match on an uninhabited type is exhaustive and has type `!`, which coerces to `T`.
    match error {}
}
/// Controls how a computation is executed and observed.
///
/// A controller is passed through a computation and consulted at checkpoints (where it may
/// abort by returning `Err(Self::Abort)`), when logging, and when two independent
/// sub-computations should be run via [`ComputationController::join`].
///
/// Every method has a default: checkpoints never abort, logging does nothing,
/// `run_computation` just calls the computation, and `join` runs both operations sequentially.
pub trait ComputationController: Clone + UnstableSealed {
    /// Value returned from a checkpoint when the controller aborts the computation.
    type Abort: Send;

    /// Called at points where the computation may be interrupted.
    ///
    /// The default implementation ignores `_description` and always returns `Ok(())`.
    #[stability::unstable(feature = "enable")]
    fn checkpoint(&self, _description: Arguments) -> Result<(), Self::Abort> {
        Ok(())
    }

    /// Runs `computation`, passing it this controller; `_description` names the computation.
    ///
    /// The default implementation ignores `_description` and returns `computation(self)`.
    #[stability::unstable(feature = "enable")]
    fn run_computation<F, T>(self, _description: Arguments, computation: F) -> T
    where
        F: FnOnce(Self) -> T,
    {
        computation(self)
    }

    /// Reports a progress message. The default implementation does nothing.
    #[stability::unstable(feature = "enable")]
    fn log(&self, _description: Arguments) {}

    /// Runs two independent operations, each with its own copy of the controller, and returns
    /// both results.
    ///
    /// The default implementation runs `oper_a` to completion before starting `oper_b`, on the
    /// current thread.
    #[stability::unstable(feature = "enable")]
    fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
    where
        A: FnOnce(Self) -> RA + Send,
        B: FnOnce(Self) -> RB + Send,
        RA: Send,
        RB: Send,
    {
        // Only one copy is needed: `oper_b` can take ownership of `self` itself.
        let result_a = oper_a(self.clone());
        let result_b = oper_b(self);
        (result_a, result_b)
    }
}
/// Reason why a task of a short-circuiting computation stopped before completing.
pub enum ShortCircuitingComputationAbort<E> {
    /// Some task of the shared computation already recorded a result or an abort, so further
    /// work is unnecessary.
    Finished,
    /// The controller aborted at a checkpoint with the contained value.
    Abort(E),
}
/// Shared state of a computation that is split into tasks and stops as soon as one task records
/// a result or an abort.
///
/// Tasks access this state through a [`ShortCircuitingComputationHandle`]; the recorded outcome
/// is taken out with `finish`.
pub struct ShortCircuitingComputation<T, Controller>
where
    T: Send,
    Controller: ComputationController,
{
    /// Set once a task records a result or an abort; tasks check it to skip remaining work.
    finished: AtomicBool,
    /// Abort value recorded by a task, if any.
    abort: AtomicOptionBox<Controller::Abort>,
    /// Result recorded by a task, if any.
    result: AtomicOptionBox<T>,
}
/// Handle through which a single task participates in a [`ShortCircuitingComputation`].
///
/// It pairs the task's controller with a reference to the shared state, so that checkpoints can
/// report both controller aborts and completion of the computation as a whole.
pub struct ShortCircuitingComputationHandle<'a, T, Controller>
where
    T: Send,
    Controller: ComputationController,
{
    /// Controller used for checkpoints, logging and scheduling of sub-tasks.
    controller: Controller,
    /// Shared state of the computation this handle belongs to.
    executor: &'a ShortCircuitingComputation<T, Controller>,
}
impl<'a, T, Controller> Clone for ShortCircuitingComputationHandle<'a, T, Controller>
where
    T: Send,
    Controller: ComputationController,
{
    /// Clones the controller; the new handle refers to the same shared computation state.
    fn clone(&self) -> Self {
        let controller = self.controller.clone();
        Self {
            controller,
            executor: self.executor,
        }
    }
}
impl<'a, T, Controller> ShortCircuitingComputationHandle<'a, T, Controller>
where T: Send,
Controller: ComputationController
{
#[stability::unstable(feature = "enable")]
pub fn controller(&self) -> &Controller {
&self.controller
}
#[stability::unstable(feature = "enable")]
pub fn checkpoint(&self, description: Arguments) -> Result<(), ShortCircuitingComputationAbort<Controller::Abort>> {
if self.executor.finished.load(Ordering::Relaxed) {
return Err(ShortCircuitingComputationAbort::Finished);
} else if let Err(e) = self.controller.checkpoint(description) {
return Err(ShortCircuitingComputationAbort::Abort(e));
} else {
return Ok(());
}
}
#[stability::unstable(feature = "enable")]
pub fn log(&self, description: Arguments) {
self.controller.log(description)
}
#[stability::unstable(feature = "enable")]
pub fn join_many<V, F>(self, operations: V)
where V: VectorFn<F> + Sync,
F: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
{
fn join_many_internal<'a, T, V, F, Controller>(controller: Controller, executor: &'a ShortCircuitingComputation<T, Controller>, tasks: &V, from: usize, to: usize, batch_tasks: usize)
where T: Send,
Controller: ComputationController,
V: VectorFn<F> + Sync,
F: FnOnce(ShortCircuitingComputationHandle<'a, T, Controller>) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>>
{
if executor.finished.load(Ordering::Relaxed) {
return;
} else if from == to {
return;
} else if from + batch_tasks >= to {
for i in from..to {
match tasks.at(i)(ShortCircuitingComputationHandle {
controller: controller.clone(),
executor: executor
}) {
Ok(Some(result)) => {
executor.finished.store(true, Ordering::Relaxed);
executor.result.store(Some(Box::new(result)), Ordering::AcqRel);
},
Err(ShortCircuitingComputationAbort::Abort(abort)) => {
executor.finished.store(true, Ordering::Relaxed);
executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
},
Err(ShortCircuitingComputationAbort::Finished) | Ok(None) => {}
}
}
} else {
let mid = (from + to) / 2;
controller.join(move |controller| join_many_internal(controller, executor, tasks, from, mid, batch_tasks), move |controller| join_many_internal(controller, executor, tasks, mid, to, batch_tasks));
}
}
join_many_internal(self.controller, self.executor, &operations, 0, operations.len(), 1)
}
#[stability::unstable(feature = "enable")]
pub fn join<A, B>(self, oper_a: A, oper_b: B)
where
A: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send,
B: FnOnce(Self) -> Result<Option<T>, ShortCircuitingComputationAbort<Controller::Abort>> + Send
{
let success_fn = |value: T| {
self.executor.finished.store(true, Ordering::Relaxed);
self.executor.result.store(Some(Box::new(value)), Ordering::AcqRel);
};
let abort_fn = |abort: Controller::Abort| {
self.executor.finished.store(true, Ordering::Relaxed);
self.executor.abort.store(Some(Box::new(abort)), Ordering::AcqRel);
};
self.controller.join(
|controller| {
if self.executor.finished.load(Ordering::Relaxed) {
return;
}
match oper_a(ShortCircuitingComputationHandle {
controller,
executor: self.executor
}) {
Ok(Some(result)) => success_fn(result),
Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
Err(ShortCircuitingComputationAbort::Finished) => {},
Ok(None) => {}
}
},
|controller| {
if self.executor.finished.load(Ordering::Relaxed) {
return;
}
match oper_b(ShortCircuitingComputationHandle {
controller,
executor: self.executor
}) {
Ok(Some(result)) => success_fn(result),
Err(ShortCircuitingComputationAbort::Abort(abort)) => abort_fn(abort),
Err(ShortCircuitingComputationAbort::Finished) => {},
Ok(None) => {}
}
}
);
}
}
impl<T, Controller> ShortCircuitingComputation<T, Controller>
where
    T: Send,
    Controller: ComputationController,
{
    /// Creates a fresh computation: not finished, with neither a result nor an abort recorded.
    #[stability::unstable(feature = "enable")]
    pub fn new() -> Self {
        let finished = AtomicBool::new(false);
        let abort = AtomicOptionBox::none();
        let result = AtomicOptionBox::none();
        Self { finished, abort, result }
    }

    /// Creates a handle that runs tasks of this computation with the given controller.
    #[stability::unstable(feature = "enable")]
    pub fn handle<'a>(&'a self, controller: Controller) -> ShortCircuitingComputationHandle<'a, T, Controller> {
        ShortCircuitingComputationHandle { controller, executor: self }
    }

    /// Consumes the computation and returns its outcome.
    ///
    /// A recorded abort takes precedence and is returned as `Err`. Otherwise the recorded result
    /// is returned as `Ok(Some(_))`, or `Ok(None)` if no task produced one.
    #[stability::unstable(feature = "enable")]
    pub fn finish(self) -> Result<Option<T>, Controller::Abort> {
        match self.abort.swap(None, Ordering::AcqRel) {
            Some(boxed_abort) => Err(*boxed_abort),
            None => Ok(self
                .result
                .swap(None, Ordering::AcqRel)
                .map(|boxed_result| *boxed_result)),
        }
    }
}
/// Calls `.checkpoint(..)` on the given controller or handle and propagates a failure with `?`.
///
/// `checkpoint!(c)` uses an empty description; `checkpoint!(c, "fmt", args...)` formats the
/// description with `format_args!`.
#[macro_export]
macro_rules! checkpoint {
    ($target:expr) => {
        $crate::checkpoint!($target, "")
    };
    ($target:expr, $($fmt:tt)*) => {
        ($target).checkpoint(std::format_args!($($fmt)*))?
    };
}
/// Calls `.log(..)` on the given controller or handle with a `format_args!`-formatted message.
#[macro_export]
macro_rules! log_progress {
    ($target:expr, $($fmt:tt)*) => {
        ($target).log(std::format_args!($($fmt)*))
    };
}
/// Controller that prints progress to standard output and never aborts.
///
/// `checkpoint` prints its description. `run_computation` prints the description before running
/// the computation and a newline afterwards.
#[derive(Clone, Copy)]
pub struct LogProgress;

impl UnstableSealed for LogProgress {}

impl ComputationController for LogProgress {
    type Abort = !;

    /// Writes `description` to standard output and flushes, so partial lines appear immediately.
    ///
    /// Logging is best-effort: write or flush failures (e.g. a closed stdout or a broken pipe)
    /// are ignored instead of panicking in the middle of the computation.
    #[stability::unstable(feature = "enable")]
    fn log(&self, description: Arguments) {
        // Holding one lock across write and flush keeps other threads' output from landing in
        // between.
        let mut stdout = std::io::stdout().lock();
        let _ = stdout.write_fmt(description);
        let _ = stdout.flush();
    }

    /// Logs `description`, runs `computation`, then logs a newline and returns the result.
    #[stability::unstable(feature = "enable")]
    fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
    where
        F: FnOnce(Self) -> T,
    {
        self.log(description);
        let result = computation(self);
        self.log(format_args!("\n"));
        result
    }

    /// Logs `description`; never aborts.
    #[stability::unstable(feature = "enable")]
    fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
        self.log(description);
        Ok(())
    }
}
/// Controller that keeps every default of [`ComputationController`].
///
/// Checkpoints never abort, logging does nothing, and `join` runs both operations sequentially.
#[derive(Copy, Clone)]
pub struct DontObserve;

impl UnstableSealed for DontObserve {}

impl ComputationController for DontObserve {
    // Uninhabited: this controller can never abort.
    type Abort = !;
}
#[cfg(feature = "parallel")]
mod parallel_controller {
    use super::*;

    /// Controller that runs the two operations of every `join` via `rayon::join`.
    ///
    /// Checkpoints, logging and `run_computation` are delegated to the wrapped controller
    /// `rest`.
    #[stability::unstable(feature = "enable")]
    pub struct ExecuteMultithreaded<Rest: ComputationController + Send> {
        rest: Rest,
    }

    impl<Rest: ComputationController + Send + Copy> Copy for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> Clone for ExecuteMultithreaded<Rest> {
        fn clone(&self) -> Self {
            Self { rest: self.rest.clone() }
        }
    }

    impl<Rest: ComputationController + Send> UnstableSealed for ExecuteMultithreaded<Rest> {}

    impl<Rest: ComputationController + Send> ComputationController for ExecuteMultithreaded<Rest> {
        type Abort = Rest::Abort;

        /// Delegates to the wrapped controller's `checkpoint`.
        #[stability::unstable(feature = "enable")]
        fn checkpoint(&self, description: Arguments) -> Result<(), Self::Abort> {
            self.rest.checkpoint(description)
        }

        /// Delegates to the wrapped controller's `run_computation`. The controller it hands back
        /// is re-wrapped, so `computation` still runs multithreaded.
        #[stability::unstable(feature = "enable")]
        fn run_computation<F, T>(self, description: Arguments, computation: F) -> T
        where
            F: FnOnce(Self) -> T,
        {
            self.rest.run_computation(description, |rest| computation(ExecuteMultithreaded { rest }))
        }

        /// Delegates to the wrapped controller's `log`.
        ///
        /// Without this override the trait's no-op default would be used. Messages logged
        /// through e.g. `RunMultithreadedLogProgress` would then be silently dropped, even
        /// though `LogProgress` prints them.
        #[stability::unstable(feature = "enable")]
        fn log(&self, description: Arguments) {
            self.rest.log(description)
        }

        /// Runs `oper_a` and `oper_b` through `rayon::join` (potentially in parallel), each
        /// with its own copy of this controller.
        #[stability::unstable(feature = "enable")]
        fn join<A, B, RA, RB>(self, oper_a: A, oper_b: B) -> (RA, RB)
        where
            A: FnOnce(Self) -> RA + Send,
            B: FnOnce(Self) -> RB + Send,
            RA: Send,
            RB: Send,
        {
            let controller_a = self.clone();
            let controller_b = self;
            rayon::join(|| oper_a(controller_a), || oper_b(controller_b))
        }
    }

    /// Multithreaded controller that logs progress to standard output.
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)] // named like a unit struct, mirroring `LogProgress`
    pub static RunMultithreadedLogProgress: ExecuteMultithreaded<LogProgress> = ExecuteMultithreaded { rest: LogProgress };

    /// Multithreaded controller without any logging or aborting.
    #[stability::unstable(feature = "enable")]
    #[allow(non_upper_case_globals)] // named like a unit struct, mirroring `DontObserve`
    pub static RunMultithreaded: ExecuteMultithreaded<DontObserve> = ExecuteMultithreaded { rest: DontObserve };
}
#[cfg(feature = "parallel")]
pub use parallel_controller::*;