use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use futures::stream::{self, StreamExt};
use crate::runnable::{Runnable, RunnableConfig};
use crate::{CognisError, Result};
/// Tuning knobs for a `BatchProcessor`.
///
/// `Default` yields `max_concurrency == 0` (defer to the per-call
/// `RunnableConfig`), `return_exceptions == false`, and no wave delay.
#[derive(Debug, Clone, Default)]
pub struct BatchOptions {
    // Upper bound on concurrently in-flight invocations.
    // `0` means "use the `max_concurrency` from the call-time `RunnableConfig`".
    pub max_concurrency: usize,
    // When true, `process` routes through `process_with_results` and then
    // folds the per-item results into a single `Result`.
    pub return_exceptions: bool,
    // Delay slept before each individual invocation.
    // NOTE(review): the implementation sleeps once per item, adding a uniform
    // latency to every call rather than staggering "waves" — confirm intent.
    pub wave_delay: Option<Duration>,
}
impl BatchOptions {
    /// Returns a copy with the concurrency cap set; `0` defers to the
    /// call-time `RunnableConfig`.
    pub fn with_max_concurrency(self, n: usize) -> Self {
        Self {
            max_concurrency: n,
            ..self
        }
    }

    /// Returns a copy with per-item error collection toggled.
    pub fn with_return_exceptions(self, on: bool) -> Self {
        Self {
            return_exceptions: on,
            ..self
        }
    }

    /// Returns a copy with a delay installed before each invocation.
    pub fn with_wave_delay(self, d: Duration) -> Self {
        Self {
            wave_delay: Some(d),
            ..self
        }
    }
}
/// Runs a wrapped `Runnable` over many inputs with bounded concurrency.
pub struct BatchProcessor<R, I, O> {
    // The runnable invoked once per input.
    inner: R,
    // Concurrency / error-collection settings for batch calls.
    options: BatchOptions,
    // Marker relating the `I`/`O` type parameters to the struct without
    // storing values of those types.
    _phantom: PhantomData<fn(I) -> O>,
}
impl<R, I, O> BatchProcessor<R, I, O>
where
    R: Runnable<I, O>,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Wraps `inner` with default `BatchOptions`.
    pub fn new(inner: R) -> Self {
        Self::with_options(inner, BatchOptions::default())
    }

    /// Wraps `inner` with explicit options.
    pub fn with_options(inner: R, options: BatchOptions) -> Self {
        Self {
            inner,
            options,
            _phantom: PhantomData,
        }
    }

    /// Borrows the wrapped runnable.
    pub fn inner(&self) -> &R {
        &self.inner
    }

    /// Mutable access to the options, e.g. to retune between batches.
    pub fn options_mut(&mut self) -> &mut BatchOptions {
        &mut self.options
    }

    /// Concurrency actually used for a batch: the explicit option when
    /// non-zero, otherwise the config's value clamped to at least one
    /// (a zero buffer size would make the stream never poll anything).
    fn effective_concurrency(&self, config: &RunnableConfig) -> usize {
        match self.options.max_concurrency {
            0 => config.max_concurrency.max(1),
            n => n,
        }
    }
}
impl<R, I, O> BatchProcessor<R, I, O>
where
    R: Runnable<I, O> + Sync,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Invokes the inner runnable on every input and returns the outputs
    /// **in input order**, or the first (input-order) error if any
    /// invocation failed.
    ///
    /// Fix: the previous `return_exceptions == false` path used
    /// `buffer_unordered(..).collect()` directly, so outputs came back in
    /// *completion* order rather than input order — inconsistent with the
    /// `return_exceptions == true` path and with `process_with_results`.
    /// Both paths already ran every input to completion and folded the
    /// per-item results into one `Result`, so delegating unconditionally
    /// to the order-preserving `process_with_results` fixes the ordering
    /// without otherwise changing observable behavior.
    pub async fn process(&self, inputs: Vec<I>, config: RunnableConfig) -> Result<Vec<O>> {
        self.process_with_results(inputs, config)
            .await
            .into_iter()
            .collect()
    }

    /// Invokes the inner runnable on every input with bounded concurrency
    /// and returns one `Result` per input, aligned with input order.
    pub async fn process_with_results(
        &self,
        inputs: Vec<I>,
        config: RunnableConfig,
    ) -> Vec<Result<O>> {
        let concurrency = self.effective_concurrency(&config);
        let wave = self.options.wave_delay;
        let cfg = Arc::new(config);
        let inner = &self.inner;
        let total = inputs.len();

        // One slot per input; completion order (buffer_unordered) is
        // mapped back to input order via the index tag below.
        let mut slots: Vec<Option<Result<O>>> = Vec::with_capacity(total);
        slots.resize_with(total, || None);

        let mut in_flight = stream::iter(inputs.into_iter().enumerate())
            .map(|(i, input)| {
                let cfg = Arc::clone(&cfg);
                async move {
                    // NOTE(review): this sleeps before *every* invocation,
                    // adding a uniform latency rather than staggering
                    // "waves" — confirm intended semantics of wave_delay.
                    if let Some(d) = wave {
                        tokio::time::sleep(d).await;
                    }
                    let res = inner
                        .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
                        .await;
                    (i, res)
                }
            })
            .buffer_unordered(concurrency);

        while let Some((i, res)) = in_flight.next().await {
            slots[i] = Some(res);
        }

        // Every index is produced exactly once above; the fallback guards
        // the invariant rather than an expected path.
        slots
            .into_iter()
            .map(|slot| {
                slot.unwrap_or_else(|| {
                    Err(CognisError::Internal("batch slot was never filled".into()))
                })
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Doubles its input and counts how many times it was invoked.
    struct Counter {
        seen: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Runnable<usize, usize> for Counter {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            Ok(input * 2)
        }
    }

    /// Succeeds on even inputs, fails on odd ones.
    struct Sometimes;

    #[async_trait]
    impl Runnable<usize, usize> for Sometimes {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            if input % 2 == 0 {
                Ok(input)
            } else {
                Err(CognisError::Internal(format!("odd input {input}")))
            }
        }
    }

    #[tokio::test]
    async fn process_returns_all_doubled_results() {
        // Fix: the old test was named "...in_order" but only asserted
        // membership, which can neither verify ordering nor catch
        // duplicated outputs. Sorting before the equality check pins the
        // exact multiset of outputs regardless of completion order.
        let r = Counter {
            seen: Arc::new(AtomicUsize::new(0)),
        };
        let bp: BatchProcessor<Counter, usize, usize> = BatchProcessor::new(r);
        let mut out = bp
            .process(vec![1, 2, 3], RunnableConfig::default())
            .await
            .unwrap();
        out.sort_unstable();
        assert_eq!(out, vec![2, 4, 6]);
    }

    #[tokio::test]
    async fn process_propagates_errors() {
        // Renamed: `process` does not short-circuit — it runs every input
        // and then folds the collected results into a single `Result`.
        let bp: BatchProcessor<Sometimes, usize, usize> = BatchProcessor::new(Sometimes);
        let res = bp.process(vec![1, 2, 3], RunnableConfig::default()).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn process_with_results_collects_partials() {
        let bp = BatchProcessor::with_options(
            Sometimes,
            BatchOptions::default()
                .with_return_exceptions(true)
                .with_max_concurrency(4),
        );
        let res = bp
            .process_with_results(vec![1, 2, 3, 4], RunnableConfig::default())
            .await;
        // `process_with_results` preserves input order via index tagging,
        // so positions map 1:1 onto inputs. Also pin the Ok payloads.
        assert_eq!(res.len(), 4);
        assert!(res[0].is_err());
        assert!(res[2].is_err());
        assert_eq!(*res[1].as_ref().unwrap(), 2);
        assert_eq!(*res[3].as_ref().unwrap(), 4);
    }

    #[tokio::test]
    async fn return_exceptions_process_succeeds_when_all_ok() {
        // Renamed: the old name claimed short-circuiting "only when all
        // fail", which the implementation never did. This pins the
        // all-success path, which goes through the order-preserving
        // `process_with_results`, so the equality check is deterministic.
        let bp = BatchProcessor::with_options(
            Sometimes,
            BatchOptions::default().with_return_exceptions(true),
        );
        let res = bp.process(vec![2, 4], RunnableConfig::default()).await;
        assert_eq!(res.unwrap(), vec![2, 4]);
    }

    #[tokio::test]
    async fn explicit_max_concurrency_overrides_config() {
        // NOTE(review): this only verifies that every input was invoked
        // exactly once; it does not observe that at most 2 invocations
        // ran concurrently.
        let counter = Arc::new(AtomicUsize::new(0));
        let r = Counter {
            seen: counter.clone(),
        };
        let bp: BatchProcessor<Counter, usize, usize> =
            BatchProcessor::with_options(r, BatchOptions::default().with_max_concurrency(2));
        let _ = bp
            .process(vec![1, 2, 3, 4, 5], RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}