use crate::error::CanoError;
use crate::resource::Resources;
use crate::task::{RetryMode, TaskConfig, TaskResult};
use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
/// Type-erased batch item: a boxed `'static` value that is both `Send` and `Sync`.
///
/// This is the `Item` type fixed by the `DynBatchTask` trait-object alias.
pub type DefaultBatchItem = Box<dyn std::any::Any + Sync + Send>;
/// Type-erased per-item result: a boxed `'static` value that is `Send`.
///
/// This is the `ItemOutput` type fixed by the `DynBatchTask` trait-object alias.
pub type DefaultBatchItemOutput = Box<dyn std::any::Any + Send>;
/// A task that loads a list of items, processes each item independently, and then
/// folds the per-item results into a single [`TaskResult`].
///
/// [`run_batch`] drives the three phases in order: [`load`](BatchTask::load), then
/// [`process_item`](BatchTask::process_item) for every loaded item (bounded by
/// [`concurrency`](BatchTask::concurrency) and retried per
/// [`item_retry`](BatchTask::item_retry)), then [`finish`](BatchTask::finish).
///
/// `TState` is the state type carried in the returned [`TaskResult`];
/// `TResourceKey` is the key type of the [`Resources`] passed to `load` and `finish`
/// (defaults to `Cow<'static, str>`).
#[crate::task::batch]
pub trait BatchTask<TState, TResourceKey = Cow<'static, str>>: Send + Sync
where
TState: Clone + fmt::Debug + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
/// Input item produced by `load` and handed by reference to `process_item`.
type Item: Send + Sync + 'static;
/// Result of processing one item; all of them are collected and passed to `finish`.
type ItemOutput: Send + 'static;
/// Maximum number of `process_item` calls [`run_batch`] keeps in flight at once.
///
/// Defaults to `1`, which processes items one at a time. [`run_batch`] treats `0` as `1`.
fn concurrency(&self) -> usize {
1
}
/// Retry policy applied independently to each item's `process_item` call.
///
/// Defaults to [`RetryMode::None`].
fn item_retry(&self) -> RetryMode {
RetryMode::None
}
/// Task-level configuration; defaults to `TaskConfig::default()`.
// NOTE(review): how this config is consumed is not visible in this file — presumably by
// the code generated by `#[crate::task::batch]`; confirm.
fn config(&self) -> TaskConfig {
TaskConfig::default()
}
/// Name of this task; defaults to `std::any::type_name` of the implementing type.
fn name(&self) -> Cow<'static, str> {
Cow::Borrowed(std::any::type_name::<Self>())
}
/// Produces the items to process.
///
/// [`run_batch`] returns an `Err` from here immediately. In that case no item is
/// processed and `finish` is not called.
async fn load(&self, res: &Resources<TResourceKey>) -> Result<Vec<Self::Item>, CanoError>;
/// Processes a single item.
///
/// This may run concurrently with other items when `concurrency()` is greater than 1.
/// Failures are retried per `item_retry()`. The final error does not abort the batch;
/// it is delivered to `finish` in this item's slot.
async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError>;
/// Combines the per-item results into the task's result.
///
/// `outputs` has exactly one entry per loaded item. The entries are in the same order
/// that `load` returned the items, regardless of completion order. The value returned
/// here is what [`run_batch`] returns.
async fn finish(
&self,
res: &Resources<TResourceKey>,
outputs: Vec<Result<Self::ItemOutput, CanoError>>,
) -> Result<TaskResult<TState>, CanoError>;
}
/// Runs a [`BatchTask`] to completion: `load`, then `process_item` for every loaded
/// item, then `finish`.
///
/// At most `b.concurrency()` item calls are in flight at once; a value of `0` is raised
/// to `1`. Each item call is retried according to `b.item_retry()`. Before the results
/// are handed to `finish`, they are put back into the order `load` returned the items.
///
/// # Errors
///
/// Returns the `load` error unchanged; in that case no item is processed and `finish`
/// is not called. Otherwise returns whatever `finish` returns. A failing item never
/// aborts the batch: its final error is passed to `finish` in that item's slot.
pub async fn run_batch<B, S, K>(b: &B, res: &Resources<K>) -> Result<TaskResult<S>, CanoError>
where
    B: BatchTask<S, K> + ?Sized,
    S: Clone + fmt::Debug + Send + Sync + 'static,
    K: Hash + Eq + Send + Sync + 'static,
{
    use futures_util::StreamExt as _;
    use std::sync::Arc;

    let shared_items = Arc::new(b.load(res).await?);
    let policy = b.item_retry();
    let max_in_flight = b.concurrency().max(1);
    let item_count = shared_items.len();

    // The stream yields plain indices, and each per-item future owns its own `Arc`
    // handle and retry policy.
    // NOTE(review): presumably this avoids borrowing through the stream's item type so
    // the overall future stays `Send`; confirm before switching to borrowed iteration.
    let per_item = futures_util::stream::iter(0..item_count).map(|idx| {
        let items = Arc::clone(&shared_items);
        let policy = policy.clone();
        async move { (idx, run_item_with_retry(b, &items[idx], &policy).await) }
    });

    // `buffer_unordered` yields results in completion order; restore input order.
    let mut completed: Vec<(usize, Result<B::ItemOutput, CanoError>)> =
        per_item.buffer_unordered(max_in_flight).collect().await;
    completed.sort_unstable_by_key(|&(idx, _)| idx);

    let outputs: Vec<Result<B::ItemOutput, CanoError>> =
        completed.into_iter().map(|(_, outcome)| outcome).collect();
    b.finish(res, outputs).await
}
/// Calls [`BatchTask::process_item`] for one item, retrying failures per `retry_mode`.
///
/// The first attempt always runs. After each failure the attempt counter goes up by
/// one. Once it reaches `retry_mode.max_attempts()`, that failure's error is returned.
/// Otherwise the call sleeps for `retry_mode.delay_for_attempt(attempt - 1)`, if that
/// yields a non-zero delay, and tries again.
async fn run_item_with_retry<B, S, K>(
    b: &B,
    item: &B::Item,
    retry_mode: &RetryMode,
) -> Result<B::ItemOutput, CanoError>
where
    B: BatchTask<S, K> + ?Sized,
    S: Clone + fmt::Debug + Send + Sync + 'static,
    K: Hash + Eq + Send + Sync + 'static,
{
    let max_attempts = retry_mode.max_attempts();
    let mut attempt = 0usize;
    loop {
        match b.process_item(item).await {
            Ok(output) => return Ok(output),
            Err(e) => {
                attempt += 1;
                if attempt >= max_attempts {
                    return Err(e);
                }
                // Test the full `Duration`, not `as_millis()`: a sub-millisecond backoff
                // (e.g. 500µs) truncates to 0 ms and would otherwise be skipped entirely.
                if let Some(delay) = retry_mode.delay_for_attempt(attempt - 1)
                    && !delay.is_zero()
                {
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
}
/// Trait-object form of [`BatchTask`] whose items and outputs are the type-erased
/// [`DefaultBatchItem`] / [`DefaultBatchItemOutput`] boxes.
pub type DynBatchTask<TState, TResourceKey = Cow<'static, str>> = dyn BatchTask<
        TState,
        TResourceKey,
        Item = DefaultBatchItem,
        ItemOutput = DefaultBatchItemOutput,
    > + Sync
    + Send;
/// Shared, reference-counted handle to a [`DynBatchTask`].
pub type BatchTaskObject<TState, TResourceKey = Cow<'static, str>> =
    std::sync::Arc<DynBatchTask<TState, TResourceKey>>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resource::Resources;
    use crate::task;
    use crate::task::Task;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

    /// States used by every test batch below.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum Step {
        Process,
        Done,
    }

    // Echoes `0..n` sequentially; `finish` checks the outputs arrive in input order.
    struct IndexedBatch {
        n: usize,
    }

    #[task::batch]
    impl BatchTask<Step> for IndexedBatch {
        type Item = usize;
        type ItemOutput = usize;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok((0..self.n).collect())
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            let got: Vec<usize> = outputs.into_iter().map(|r| r.unwrap()).collect();
            let expected: Vec<usize> = (0..self.n).collect();
            assert_eq!(got, expected, "finish must receive items in input order");
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_sequential_items_in_input_order() {
        let task = IndexedBatch { n: 5 };
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }

    // Later items sleep for less time, so with concurrency 4 they complete first.
    // `finish` must still see the outputs in input order.
    struct ConcurrentBatch {
        n: usize,
    }

    #[task::batch]
    impl BatchTask<Step> for ConcurrentBatch {
        type Item = usize;
        type ItemOutput = usize;

        fn concurrency(&self) -> usize {
            4
        }

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok((0..self.n).collect())
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            let delay = self.n.saturating_sub(*item) as u64;
            if delay > 0 {
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            let got: Vec<usize> = outputs.into_iter().map(|r| r.unwrap()).collect();
            let expected: Vec<usize> = (0..self.n).collect();
            assert_eq!(
                got, expected,
                "concurrent batch must preserve input order in finish"
            );
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_concurrent_items_preserve_input_order() {
        let task = ConcurrentBatch { n: 4 };
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }

    // Item 1 fails. The failure must appear in its slot of `outputs`, not abort the batch.
    struct PartialFailBatch;

    #[task::batch]
    impl BatchTask<Step> for PartialFailBatch {
        type Item = usize;
        type ItemOutput = usize;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(vec![0, 1, 2])
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            if *item == 1 {
                Err(CanoError::task_execution("item 1 failed"))
            } else {
                Ok(*item)
            }
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            assert_eq!(outputs.len(), 3);
            assert!(outputs[0].is_ok(), "item 0 should be Ok");
            assert!(outputs[1].is_err(), "item 1 should be Err");
            assert!(outputs[2].is_ok(), "item 2 should be Ok");
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_failing_item_lands_in_outputs_not_batch_error() {
        let task = PartialFailBatch;
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }

    // Fails the first two calls. `RetryMode::fixed(2, ..)` must allow the third call,
    // which succeeds.
    struct FlakySingleBatch {
        call_count: AtomicU32,
    }

    #[task::batch]
    impl BatchTask<Step> for FlakySingleBatch {
        type Item = u32;
        type ItemOutput = u32;

        fn item_retry(&self) -> RetryMode {
            RetryMode::fixed(2, Duration::from_millis(1))
        }

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(vec![42])
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            let n = self.call_count.fetch_add(1, Ordering::Relaxed);
            if n < 2 {
                Err(CanoError::task_execution("transient failure"))
            } else {
                Ok(*item)
            }
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            assert_eq!(outputs.len(), 1);
            assert_eq!(
                outputs[0].as_ref().unwrap(),
                &42,
                "flaky item should succeed after retries"
            );
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_item_retry_recovers_flaky_item() {
        let task = FlakySingleBatch {
            call_count: AtomicU32::new(0),
        };
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
        assert_eq!(task.call_count.load(Ordering::Relaxed), 3);
    }

    // `load` fails: the error must propagate and `finish` must never run.
    struct LoadFailsBatch {
        finish_called: Arc<AtomicBool>,
    }

    #[task::batch]
    impl BatchTask<Step> for LoadFailsBatch {
        type Item = u32;
        type ItemOutput = u32;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Err(CanoError::task_execution("load failed"))
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            _outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            self.finish_called.store(true, Ordering::Relaxed);
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_load_error_propagates_and_finish_not_called() {
        let finish_called = Arc::new(AtomicBool::new(false));
        let task = LoadFailsBatch {
            finish_called: Arc::clone(&finish_called),
        };
        let res = Resources::new();
        let err = Task::run(&task, &res).await.unwrap_err();
        assert!(
            matches!(err, CanoError::TaskExecution(_)),
            "load Err should propagate as TaskExecution, got: {err:?}"
        );
        assert!(
            !finish_called.load(Ordering::Relaxed),
            "finish must not be called when load fails"
        );
    }

    // `finish` fails: its error is the batch's error.
    struct FinishFailsBatch;

    #[task::batch]
    impl BatchTask<Step> for FinishFailsBatch {
        type Item = u32;
        type ItemOutput = u32;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(vec![1, 2])
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            _outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            Err(CanoError::task_execution("finish failed"))
        }
    }

    #[tokio::test]
    async fn test_finish_error_propagates() {
        let task = FinishFailsBatch;
        let res = Resources::new();
        let err = Task::run(&task, &res).await.unwrap_err();
        assert!(
            matches!(err, CanoError::TaskExecution(_)),
            "finish Err should propagate, got: {err:?}"
        );
    }

    // `concurrency() == 0` must not stall the batch (`run_batch` clamps it to 1).
    struct ZeroConcurrencyBatch;

    #[task::batch]
    impl BatchTask<Step> for ZeroConcurrencyBatch {
        type Item = u32;
        type ItemOutput = u32;

        fn concurrency(&self) -> usize {
            0
        }

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(vec![1, 2, 3])
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            assert_eq!(outputs.len(), 3);
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_zero_concurrency_does_not_deadlock() {
        let task = ZeroConcurrencyBatch;
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }

    // `load` returns no items: `finish` must still run, with an empty `outputs`.
    struct EmptyBatch {
        finish_called: Arc<AtomicBool>,
    }

    #[task::batch]
    impl BatchTask<Step> for EmptyBatch {
        type Item = u32;
        type ItemOutput = u32;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(Vec::new())
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            assert!(outputs.is_empty(), "empty batch must yield no outputs");
            self.finish_called.store(true, Ordering::Relaxed);
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_empty_load_still_calls_finish() {
        let finish_called = Arc::new(AtomicBool::new(false));
        let task = EmptyBatch {
            finish_called: Arc::clone(&finish_called),
        };
        let res = Resources::new();
        let result = Task::run(&task, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
        assert!(
            finish_called.load(Ordering::Relaxed),
            "finish must be called even when load returns no items"
        );
    }

    // Single-item batch shared by the dispatch/integration tests below.
    struct SimpleBatch;

    #[task::batch]
    impl BatchTask<Step> for SimpleBatch {
        type Item = u32;
        type ItemOutput = u32;

        async fn load(&self, _res: &Resources) -> Result<Vec<Self::Item>, CanoError> {
            Ok(vec![7])
        }

        async fn process_item(&self, item: &Self::Item) -> Result<Self::ItemOutput, CanoError> {
            Ok(*item)
        }

        async fn finish(
            &self,
            _res: &Resources,
            outputs: Vec<Result<Self::ItemOutput, CanoError>>,
        ) -> Result<TaskResult<Step>, CanoError> {
            let val = outputs[0].as_ref().unwrap();
            assert_eq!(*val, 7);
            Ok(TaskResult::Single(Step::Done))
        }
    }

    #[tokio::test]
    async fn test_batch_task_as_dyn_task() {
        let task: Arc<dyn Task<Step>> = Arc::new(SimpleBatch);
        let res = Resources::new();
        let result = Task::run(task.as_ref(), &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }

    #[tokio::test]
    async fn test_batch_task_in_workflow() {
        use crate::workflow::Workflow;
        let workflow = Workflow::bare()
            .register(Step::Process, IndexedBatch { n: 3 })
            .add_exit_state(Step::Done);
        let result = workflow.orchestrate(Step::Process).await.unwrap();
        assert_eq!(result, Step::Done);
    }

    // Calls `run_batch` directly, statically dispatched on `SimpleBatch`, rather than
    // going through `Task::run`.
    #[tokio::test]
    async fn test_run_batch_direct_call() {
        let task = SimpleBatch;
        let res = Resources::new();
        let result = run_batch::<SimpleBatch, Step, _>(&task, &res)
            .await
            .unwrap();
        assert_eq!(result, TaskResult::Single(Step::Done));
    }
}