use std::sync::{
Arc,
Mutex,
};
use qubit_function::{
Callable,
Runnable,
};
use crate::{
BatchExecutionError,
BatchExecutionResult,
};
use super::BatchCallResult;
/// Executes a batch of tasks and reports the aggregated outcome.
///
/// Implementors provide [`BatchExecutor::execute`]. The [`BatchExecutor::call`] and
/// [`BatchExecutor::for_each`] helpers are built on top of it by adapting their inputs
/// into [`Runnable`] tasks.
pub trait BatchExecutor: Send + Sync {
    /// Runs every task yielded by `tasks`.
    ///
    /// `count` is the number of tasks the caller declares `tasks` will yield. How a
    /// mismatch between `count` and the real number of tasks is reported is up to the
    /// implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`BatchExecutionError`] when the implementation rejects the batch as a
    /// whole.
    fn execute<T, E, I>(
        &self,
        tasks: I,
        count: usize,
    ) -> Result<BatchExecutionResult<E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = T>,
        T: Runnable<E> + Send,
        E: Send;

    /// Runs every callable in `tasks` and collects the produced values by position.
    ///
    /// The returned [`BatchCallResult`] pairs the result of [`BatchExecutor::execute`]
    /// with one `Option<R>` per declared slot (`count` slots in total). A slot holds
    /// `Some(value)` only if the callable at that position ran and returned `Ok`.
    /// Otherwise it stays `None`.
    ///
    /// # Errors
    ///
    /// Propagates any [`BatchExecutionError`] returned by [`BatchExecutor::execute`].
    ///
    /// # Panics
    ///
    /// Panics in these cases:
    ///
    /// - A task at a position `>= count` is run, because there is no slot for it.
    /// - A task is run more than once.
    /// - The executor is still holding on to tasks after `execute` returns `Ok`.
    fn call<C, R, E, I>(
        &self,
        tasks: I,
        count: usize,
    ) -> Result<BatchCallResult<R, E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = C>,
        C: Callable<R, E> + Send,
        R: Send,
        E: Send,
    {
        // One empty slot per declared task; each task writes only to its own index.
        let mut slot_list: Vec<Mutex<Option<R>>> = Vec::with_capacity(count);
        slot_list.resize_with(count, || Mutex::new(None));
        let shared_slots = Arc::new(slot_list);

        // The adapter closure owns a separate handle, so once `execute` has consumed
        // (and dropped) the iterator, `shared_slots` is the last owner again.
        let closure_slots = Arc::clone(&shared_slots);
        let wrapped_tasks = tasks
            .into_iter()
            .enumerate()
            .map(move |(position, callable)| {
                CallableTask::new(callable, position, Arc::clone(&closure_slots))
            });

        let summary = self.execute(wrapped_tasks, count)?;
        let collected = collect_call_outputs(shared_slots);
        Ok(BatchCallResult::new(summary, collected))
    }

    /// Applies `action` to every item in `items` as a batch of tasks.
    ///
    /// Each item becomes one task, and all tasks share the same `action` through an
    /// [`Arc`].
    ///
    /// # Errors
    ///
    /// Propagates any [`BatchExecutionError`] returned by [`BatchExecutor::execute`].
    fn for_each<Item, E, I, F>(
        &self,
        items: I,
        count: usize,
        action: F,
    ) -> Result<BatchExecutionResult<E>, BatchExecutionError<E>>
    where
        I: IntoIterator<Item = Item>,
        Item: Send,
        F: Fn(Item) -> Result<(), E> + Send + Sync,
        E: Send,
    {
        let shared_action = Arc::new(action);
        let per_item_tasks = items.into_iter().map(move |item| {
            let handle = Arc::clone(&shared_action);
            ForEachTask::new(item, handle)
        });
        self.execute(per_item_tasks, count)
    }
}
/// Adapter that runs a [`Callable`] as a [`Runnable`] and stores its output.
///
/// The produced value is written into `outputs[index]`. Each adapter is meant to be run
/// exactly once.
struct CallableTask<C, R> {
    /// The wrapped callable; taken (left as `None`) on the first run.
    callable: Option<C>,
    /// Position of this task in the batch, and therefore of its output slot.
    index: usize,
    /// Output slots shared by every task of the same batch.
    outputs: Arc<Vec<Mutex<Option<R>>>>,
}

impl<C, R> CallableTask<C, R> {
    /// Wraps `callable` so that its result is stored in `outputs[index]`.
    fn new(callable: C, index: usize, outputs: Arc<Vec<Mutex<Option<R>>>>) -> Self {
        Self {
            callable: Some(callable),
            index,
            outputs,
        }
    }
}

impl<C, R, E> Runnable<E> for CallableTask<C, R>
where
    C: Callable<R, E>,
{
    /// Invokes the callable and records its value in this task's slot.
    ///
    /// If the callable fails, the error is returned and the slot is left untouched.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    ///
    /// - The task is run a second time.
    /// - `index` is outside `outputs`, i.e. the batch yielded more tasks than declared.
    fn run(&mut self) -> Result<(), E> {
        let mut callable = self
            .callable
            .take()
            .expect("callable task may only run once");
        let value = callable.call()?;
        let mut slot = self
            .outputs
            .get(self.index)
            .expect("callable index must be within the declared count")
            .lock()
            // A poisoned slot only means another holder panicked; the slot itself is
            // still a plain `Option`, so overwrite it anyway.
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = Some(value);
        Ok(())
    }
}

/// Takes the output slots back from the finished batch, one `Option<R>` per slot in
/// order.
///
/// # Panics
///
/// Panics if another handle to `outputs` is still alive. For example, this happens
/// when an executor keeps a task around after returning.
fn collect_call_outputs<R>(outputs: Arc<Vec<Mutex<Option<R>>>>) -> Vec<Option<R>> {
    // `expect` would require `R: Debug` for the `Err(Arc<..>)` payload, so panic manually.
    let slots = Arc::try_unwrap(outputs).unwrap_or_else(|_| {
        panic!("callable output slots should have a single owner after execution")
    });
    slots
        .into_iter()
        .map(|slot| {
            slot.into_inner()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        })
        .collect()
}

/// Adapter that applies a shared action to a single item as a [`Runnable`].
///
/// The error type is not a parameter of the struct itself. It is bound only by the
/// `Runnable<E>` impl below, mirroring [`CallableTask`]. A struct type parameter used
/// only in a `where` clause is rejected by the compiler (E0392).
struct ForEachTask<Item, F> {
    /// The item to process; taken (left as `None`) on the first run.
    item: Option<Item>,
    /// The action, shared with every other task of the same batch.
    action: Arc<F>,
}

impl<Item, F> ForEachTask<Item, F> {
    /// Creates a task that will call `action` with `item`.
    fn new(item: Item, action: Arc<F>) -> Self {
        Self {
            item: Some(item),
            action,
        }
    }
}

impl<Item, E, F> Runnable<E> for ForEachTask<Item, F>
where
    F: Fn(Item) -> Result<(), E>,
{
    /// Calls the action with the stored item and returns its result.
    ///
    /// # Panics
    ///
    /// Panics if the task is run a second time.
    fn run(&mut self) -> Result<(), E> {
        let item = self.item.take().expect("for_each task may only run once");
        (self.action)(item)
    }
}