use crate::Executor;
use std::{borrow::Cow, sync::Arc};
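/// Batches calls to an [`Executor`] by collecting values from concurrent
/// callers on a background Tokio task and executing them together.
///
/// Cloning a [`BatchExecutor`] is cheap: clones share the same background task
/// and request channel.
///
/// A minimal usage sketch, assuming a hypothetical `MyExecutor` type that
/// implements [`Executor`]:
///
/// ```ignore
/// let batch_executor = BatchExecutor::build(MyExecutor::new())
///     .delay_duration(tokio::time::Duration::from_millis(5))
///     .eager_batch_size(Some(50))
///     .label("my-batch-executor")
///     .finish();
///
/// let result = batch_executor.execute(some_value).await?;
/// ```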
pub struct BatchExecutor<E>
where
E: Executor,
{
label: Cow<'static, str>,
_execute_task: Arc<tokio::task::JoinHandle<()>>,
execute_request_tx: tokio::sync::mpsc::Sender<ExecuteRequest<E::Value, E::Result>>,
}
impl<E> BatchExecutor<E>
where
E: Executor + Send + Sync + 'static,
{
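    /// Creates a [`BatchExecutorBuilder`] for `executor` with default settings:
    /// a 10ms delay window, an eager batch size of 100, and the label
    /// `"unlabeled-batch-executor"`.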
pub fn build(executor: E) -> BatchExecutorBuilder<E> {
BatchExecutorBuilder {
executor,
delay_duration: tokio::time::Duration::from_millis(10),
eager_batch_size: Some(100),
label: "unlabeled-batch-executor".into(),
}
}
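    /// Executes a single value as part of a batch, returning its result (or
    /// `None` if the executor produced no result for it).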
#[tracing::instrument(skip_all, fields(batch_executor = %self.label))]
    pub async fn execute(&self, value: E::Value) -> Result<Option<E::Result>, ExecuteError> {
        let mut results = self.execute_values(vec![value]).await?;
        Ok(results.pop())
}
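    /// Executes the given values as part of a batch, returning their results.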
#[tracing::instrument(skip_all, fields(batch_executor = %self.label, num_values = values.len()))]
pub async fn execute_many(
&self,
values: Vec<E::Value>,
) -> Result<Vec<E::Result>, ExecuteError> {
let results = self.execute_values(values).await?;
Ok(results)
}
async fn execute_values(&self, values: Vec<E::Value>) -> Result<Vec<E::Result>, ExecuteError> {
let execute_request_tx = self.execute_request_tx.clone();
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
tracing::debug!(
batch_executor = %self.label,
"sending a batch of values to execute",
);
let execute_request = ExecuteRequest { values, result_tx };
execute_request_tx
.send(execute_request)
.await
.map_err(|_| ExecuteError::SendError)?;
match result_rx.await {
Ok(Ok(results)) => {
                tracing::debug!(batch_executor = %self.label, "batch response returned successfully");
Ok(results)
}
Ok(Err(execute_error)) => {
                tracing::info!(batch_executor = %self.label, "error returned while executing batch: {execute_error}");
Err(ExecuteError::ExecutorError(execute_error))
}
Err(recv_error) => {
panic!(
"Batch result channel for batch executor {} hung up with error: {recv_error}",
self.label,
);
}
}
}
}
impl<E> Clone for BatchExecutor<E>
where
E: Executor,
{
fn clone(&self) -> Self {
BatchExecutor {
_execute_task: self._execute_task.clone(),
execute_request_tx: self.execute_request_tx.clone(),
label: self.label.clone(),
}
}
}
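/// Builder returned by [`BatchExecutor::build`]. Configure the batching
/// behavior, then call [`finish`](BatchExecutorBuilder::finish) to spawn the
/// background task and obtain a [`BatchExecutor`].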
pub struct BatchExecutorBuilder<E>
where
E: Executor + Send + Sync + 'static,
{
executor: E,
delay_duration: tokio::time::Duration,
eager_batch_size: Option<usize>,
label: Cow<'static, str>,
}
impl<E> BatchExecutorBuilder<E>
where
E: Executor + Send + Sync + 'static,
{
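    /// Sets how long the batching task waits for additional values after the
    /// most recent request before executing the pending batch.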
pub fn delay_duration(mut self, delay: tokio::time::Duration) -> Self {
self.delay_duration = delay;
self
}
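    /// Sets the number of pending values that triggers an immediate batch
    /// execution without waiting for the delay to elapse. `None` disables
    /// eager execution, so batches only run once the delay expires.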
pub fn eager_batch_size(mut self, eager_batch_size: Option<usize>) -> Self {
self.eager_batch_size = eager_batch_size;
self
}
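    /// Sets a human-readable label used in tracing output.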
pub fn label(mut self, label: impl Into<Cow<'static, str>>) -> Self {
self.label = label.into();
self
}
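    /// Spawns the background batching task and returns a [`BatchExecutor`]
    /// handle connected to it.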
pub fn finish(self) -> BatchExecutor<E> {
let (execute_request_tx, mut execute_request_rx) =
tokio::sync::mpsc::channel::<ExecuteRequest<E::Value, E::Result>>(1);
let label = self.label.clone();
let execute_task = tokio::spawn({
async move {
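                // Each iteration collects one batch of values from incoming
                // requests and executes it. The task exits once every request
                // sender has been dropped.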
'task: loop {
let mut pending_values = vec![];
let mut result_txs = vec![];
tracing::trace!(batch_executor = %self.label, "waiting for values to execute...");
match execute_request_rx.recv().await {
Some(execute_request) => {
tracing::trace!(batch_executor = %self.label, num_execute_request_values = execute_request.values.len(), "received initial execute request");
let result_start_index = pending_values.len();
pending_values.extend(execute_request.values);
result_txs.push((result_start_index, execute_request.result_tx));
}
None => {
break 'task;
}
};
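                    // Keep collecting additional requests until the batch
                    // reaches `eager_batch_size` or `delay_duration` elapses.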
'wait_for_more_values: loop {
let should_run_batch_now = match self.eager_batch_size {
Some(eager_batch_size) => pending_values.len() >= eager_batch_size,
None => false,
};
if should_run_batch_now {
tracing::trace!(
batch_executor = %self.label,
num_pending_values = pending_values.len(),
eager_batch_size = ?self.eager_batch_size,
"batch filled up, ready to execute now",
);
break 'wait_for_more_values;
}
let delay = tokio::time::sleep(self.delay_duration);
tokio::pin!(delay);
tokio::select! {
execute_request = execute_request_rx.recv() => {
match execute_request {
Some(execute_request) => {
tracing::trace!(batch_executor = %self.label, num_execute_request_values = execute_request.values.len(), "retrieved additional execute request");
let result_start_index = pending_values.len();
pending_values.extend(execute_request.values);
result_txs.push((result_start_index, execute_request.result_tx));
}
None => {
tracing::debug!(batch_executor = %self.label, num_pending_values = pending_values.len(), "execute channel closed");
break 'wait_for_more_values;
}
}
}
_ = &mut delay => {
                                tracing::trace!(
                                    batch_executor = %self.label,
                                    num_pending_values = pending_values.len(),
                                    "delay reached while waiting for more values to execute"
                                );
break 'wait_for_more_values;
}
};
}
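                    // Execute the accumulated batch, converting any executor
                    // error into a `String` so it can be cloned per requester.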
                    tracing::trace!(batch_executor = %self.label, num_pending_values = pending_values.len(), num_pending_channels = result_txs.len(), "executing batch of values");
let mut result = self
.executor
.execute(pending_values)
.await
.map_err(|error| error.to_string());
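                    // Hand each requester its slice of the results. Iterating in
                    // reverse lets `split_off` carve each request's results off
                    // the end of the shared buffer by start index.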
                    for (result_start_index, result_tx) in result_txs.into_iter().rev() {
                        let result = match &mut result {
                            Ok(results) => {
                                if result_start_index <= results.len() {
                                    Ok(results.split_off(result_start_index))
                                } else {
                                    Ok(vec![])
                                }
                            }
                            Err(error) => Err(error.clone()),
                        };
let _ = result_tx.send(result);
}
}
}
});
BatchExecutor {
label,
_execute_task: Arc::new(execute_task),
execute_request_tx,
}
}
}
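/// A batch of values sent to the background task, along with a channel for
/// returning the corresponding results.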
struct ExecuteRequest<V, R> {
values: Vec<V>,
result_tx: tokio::sync::oneshot::Sender<Result<Vec<R>, String>>,
}
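/// Errors returned by [`BatchExecutor::execute`] and
/// [`BatchExecutor::execute_many`].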
#[derive(Debug, thiserror::Error)]
pub enum ExecuteError {
    #[error("error while executing batch: {0}")]
ExecutorError(String),
#[error("error sending execution request")]
SendError,
}