use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Semaphore;
use tracing::{debug, info};
use uuid::Uuid;
use crate::action::ActionType;
use crate::error::FlowrsError;
use crate::node::{Node, NodeId, NodeOutcome};
use crate::workflow::{Workflow, WorkflowError};
/// Context type that can drive a [`BatchNode`].
///
/// An implementor supplies the items of a batch, derives a dedicated context for
/// each individual item, and finally absorbs the per-item results collected by the
/// batch node.
pub trait BatchContext<T: Clone + Send + Sync + 'static> {
    /// Returns the items that make up the current batch.
    fn get_batch_items(&self) -> Result<Vec<T>, FlowrsError>;

    /// Builds the context in which the item workflow runs for `item`.
    ///
    /// When [`BatchNode`] gets an `Err` from this method, it records that error as the
    /// item's result instead of aborting the whole batch.
    fn create_item_context(&self, item: T) -> Result<Self, FlowrsError>
    where
        Self: Sized;

    /// Receives the per-item results, in the same order as the items returned by
    /// [`BatchContext::get_batch_items`].
    fn update_with_results(
        &mut self,
        results: &[Result<T, FlowrsError>],
    ) -> Result<(), FlowrsError>;
}
/// A node that fans a batch out over an item-level [`Workflow`].
///
/// The batch items come from the context via [`BatchContext::get_batch_items`]. Each
/// item runs through `item_workflow` inside a context produced by
/// [`BatchContext::create_item_context`], with at most `parallelism` items in flight.
pub struct BatchNode<Context, ItemType, A = crate::action::DefaultAction>
where
    Context: BatchContext<ItemType> + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Clone + Send + Sync + 'static,
{
    /// Node identifier, generated from a random v4 UUID at construction time.
    id: NodeId,
    /// Workflow executed once for every batch item.
    item_workflow: Workflow<Context, A>,
    /// Maximum number of items processed concurrently (used as the semaphore
    /// permit count in `process`).
    parallelism: usize,
    /// Ties the otherwise unused `ItemType` parameter to the struct.
    _phantom: PhantomData<(Context, ItemType, A)>,
}
impl<Context, ItemType, A> BatchNode<Context, ItemType, A>
where
    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Clone + Send + Sync + 'static,
{
    /// Creates a batch node that runs `item_workflow` once per batch item.
    ///
    /// `parallelism` caps how many items are processed at the same time. A value of
    /// `0` is raised to `1`. `process` uses this value as a semaphore permit count,
    /// and zero permits would leave every item waiting forever.
    pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            item_workflow,
            parallelism: parallelism.max(1),
            _phantom: PhantomData,
        }
    }
}
impl<Context, ItemType, A> Debug for BatchNode<Context, ItemType, A>
where
    Context: BatchContext<ItemType> + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    // `Clone` mirrors the bounds on the `BatchNode` struct itself (and on the
    // `BatchFlow` Debug impl), so every `BatchNode` type gets this impl.
    A: ActionType + Clone + Send + Sync + 'static,
{
    /// Formats the node's `id` and `parallelism`.
    ///
    /// The item workflow is omitted because no `Debug` bound is placed on it here.
    /// `finish_non_exhaustive` marks that omission with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatchNode")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .finish_non_exhaustive()
    }
}
#[async_trait]
impl<Context, ItemType, A> Node<Context, A> for BatchNode<Context, ItemType, A>
where
    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
{
    /// One entry per batch item, in the order returned by `get_batch_items`.
    type Output = Vec<Result<ItemType, FlowrsError>>;

    fn id(&self) -> NodeId {
        self.id.clone()
    }

    /// Runs the item workflow for every batch item, with at most `parallelism`
    /// items in flight at once.
    ///
    /// Each item is processed on its own spawned tokio task:
    /// 1. An item context is built with [`BatchContext::create_item_context`].
    /// 2. The item workflow is executed against that context.
    ///
    /// Results are collected in input order. They are passed to
    /// [`BatchContext::update_with_results`] and then returned as
    /// `NodeOutcome::Success`.
    ///
    /// A successful entry is `Ok(item)` holding the item as originally handed out.
    /// Whatever the workflow wrote into the item context is discarded.
    ///
    /// # Errors
    /// Returns `Err` only if `get_batch_items` or `update_with_results` fails.
    /// Per-item failures appear inside the returned vector instead. These include a
    /// failed context creation, a failed workflow execution, and a task that panicked
    /// or was cancelled.
    async fn process(
        &self,
        ctx: &mut Context,
    ) -> Result<NodeOutcome<Self::Output, A>, FlowrsError> {
        debug!(node_id = %self.id, "Getting batch items to process");
        let items = ctx.get_batch_items()?;
        info!(node_id = %self.id, item_count = items.len(), "Processing batch items");

        // `Semaphore::new(0)` would make every task wait forever; always admit at least one.
        let semaphore = Arc::new(Semaphore::new(self.parallelism.max(1)));
        // Snapshot the parent context once and share it between tasks, instead of
        // cloning it per item (the context may itself hold the entire batch).
        let parent_ctx = Arc::new(ctx.clone());

        let handles: Vec<_> = items
            .into_iter()
            .map(|item| {
                let semaphore = Arc::clone(&semaphore);
                let parent_ctx = Arc::clone(&parent_ctx);
                let workflow = self.item_workflow.clone();
                tokio::spawn(async move {
                    // The semaphore is local to this call and never closed, so
                    // `acquire` cannot return an error.
                    let _permit = semaphore
                        .acquire()
                        .await
                        .expect("batch semaphore is never closed");
                    match parent_ctx.create_item_context(item.clone()) {
                        Ok(mut item_ctx) => match workflow.execute(&mut item_ctx).await {
                            Ok(_) => Ok(item),
                            Err(e) => Err(FlowrsError::batch_processing(
                                "Failed to process item",
                                Box::new(e),
                            )),
                        },
                        Err(e) => Err(e),
                    }
                })
            })
            .collect();

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            // A task that panicked or was cancelled still yields an entry, so
            // `results` stays index-aligned with the input items.
            let result = handle
                .await
                .unwrap_or_else(|e| Err(FlowrsError::JoinError(e.to_string())));
            results.push(result);
        }

        ctx.update_with_results(&results)?;
        Ok(NodeOutcome::Success(results))
    }
}
/// Standalone runner for a [`BatchNode`].
///
/// It executes the wrapped node directly against a context and reports failures as
/// [`WorkflowError`] values.
pub struct BatchFlow<Context, ItemType, A = crate::action::DefaultAction>
where
    Context: BatchContext<ItemType> + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Clone + Send + Sync + 'static,
{
    /// Flow identifier, generated from a random v4 UUID at construction time.
    id: NodeId,
    /// The batch node that does the actual per-item work.
    batch_node: BatchNode<Context, ItemType, A>,
}
impl<Context, ItemType, A> BatchFlow<Context, ItemType, A>
where
    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
{
    /// Builds a flow around a fresh [`BatchNode`] configured with `item_workflow`
    /// and `parallelism`.
    pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
        let id = Uuid::new_v4().to_string();
        let batch_node = BatchNode::new(item_workflow, parallelism);
        Self { id, batch_node }
    }

    /// Runs the wrapped batch node against `ctx` and returns its per-item results.
    ///
    /// # Errors
    /// An error from the node is wrapped in [`WorkflowError::NodeExecution`]. Any
    /// outcome other than `NodeOutcome::Success` is also reported as
    /// `NodeExecution`, carrying an unexpected-outcome error.
    pub async fn execute(
        &self,
        ctx: &mut Context,
    ) -> Result<Vec<Result<ItemType, FlowrsError>>, WorkflowError> {
        let outcome = self
            .batch_node
            .process(ctx)
            .await
            .map_err(WorkflowError::NodeExecution)?;

        if let NodeOutcome::Success(results) = outcome {
            Ok(results)
        } else {
            Err(WorkflowError::NodeExecution(FlowrsError::unexpected_outcome(
                "Expected Success outcome from BatchNode",
            )))
        }
    }
}
impl<Context, ItemType, A> Debug for BatchFlow<Context, ItemType, A>
where
    Context: BatchContext<ItemType> + Send + Sync + 'static,
    ItemType: Clone + Send + Sync + 'static,
    A: ActionType + Clone + Send + Sync + 'static,
{
    /// Formats the flow's id together with the nested batch node's own `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BatchFlow");
        builder.field("id", &self.id);
        builder.field("batch_node", &self.batch_node);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;
    use crate::node::closure::node;

    /// Minimal batch context.
    ///
    /// The batch is `items`. `update_with_results` stores the aggregated per-item
    /// results in `results`, so tests can check what the node reported.
    #[derive(Debug, Clone)]
    struct TestBatchContext {
        items: Vec<i32>,
        results: Vec<Result<i32, FlowrsError>>,
    }

    impl TestBatchContext {
        fn with_items(items: Vec<i32>) -> Self {
            Self {
                items,
                results: Vec::new(),
            }
        }
    }

    impl BatchContext<i32> for TestBatchContext {
        fn get_batch_items(&self) -> Result<Vec<i32>, FlowrsError> {
            Ok(self.items.clone())
        }

        /// Builds a single-item context. Negative items are rejected so the
        /// per-item failure path can be exercised.
        fn create_item_context(&self, item: i32) -> Result<Self, FlowrsError> {
            if item < 0 {
                return Err(FlowrsError::unexpected_outcome(
                    "negative items are rejected",
                ));
            }
            Ok(Self::with_items(vec![item]))
        }

        fn update_with_results(
            &mut self,
            results: &[Result<i32, FlowrsError>],
        ) -> Result<(), FlowrsError> {
            self.results = results.to_vec();
            Ok(())
        }
    }

    /// Item workflow that doubles the single item held by the item context.
    fn doubling_workflow() -> Workflow<TestBatchContext, DefaultAction> {
        Workflow::new(node(|mut ctx: TestBatchContext| async move {
            let item = ctx.items[0] * 2;
            ctx.items = vec![item];
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        }))
    }

    #[tokio::test]
    async fn test_batch_node_processing() {
        let batch_node = BatchNode::new(doubling_workflow(), 4);
        let mut ctx = TestBatchContext::with_items(vec![1, 2, 3, 4, 5]);
        let result = batch_node.process(&mut ctx).await.unwrap();
        match result {
            NodeOutcome::Success(results) => {
                assert_eq!(results.len(), 5);
                assert!(results.iter().all(|r| r.is_ok()));
            }
            _ => panic!("Expected Success outcome"),
        }
        // The results must also have been handed back to the parent context.
        assert_eq!(ctx.results.len(), 5);
    }

    #[tokio::test]
    async fn test_batch_flow_execution() {
        let batch_flow = BatchFlow::new(doubling_workflow(), 4);
        let mut ctx = TestBatchContext::with_items(vec![1, 2, 3, 4, 5]);
        let results = batch_flow.execute(&mut ctx).await.unwrap();
        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|r| r.is_ok()));
        assert_eq!(ctx.results.len(), 5);
    }

    #[tokio::test]
    async fn test_batch_node_keeps_item_failures_in_results() {
        let batch_node = BatchNode::new(doubling_workflow(), 2);
        let mut ctx = TestBatchContext::with_items(vec![1, -2, 3]);
        // A failing item must not abort the batch...
        let result = batch_node.process(&mut ctx).await.unwrap();
        match result {
            NodeOutcome::Success(results) => {
                // ...and its error must stay aligned with its input position.
                assert_eq!(results.len(), 3);
                assert!(results[0].is_ok());
                assert!(results[1].is_err());
                assert!(results[2].is_ok());
            }
            _ => panic!("Expected Success outcome"),
        }
        assert_eq!(ctx.results.len(), 3);
        assert!(ctx.results[1].is_err());
    }
}