1use std::fmt::Debug;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::Semaphore;
7use tracing::{debug, info};
8use uuid::Uuid;
9
10use crate::action::ActionType;
11use crate::error::FloxideError;
12use crate::node::{Node, NodeId, NodeOutcome};
13use crate::workflow::{Workflow, WorkflowError};
14
15pub trait BatchContext<T>
17where
18    T: Clone + Send + Sync + 'static,
19{
20    fn get_batch_items(&self) -> Result<Vec<T>, FloxideError>;
22
23    fn create_item_context(&self, item: T) -> Result<Self, FloxideError>
25    where
26        Self: Sized;
27
28    fn update_with_results(
30        &mut self,
31        results: &[Result<T, FloxideError>],
32    ) -> Result<(), FloxideError>;
33}
34
35pub struct BatchNode<Context, ItemType, A = crate::action::DefaultAction>
37where
38    Context: BatchContext<ItemType> + Send + Sync + 'static,
39    ItemType: Clone + Send + Sync + 'static,
40    A: ActionType + Clone + Send + Sync + 'static,
41{
42    id: NodeId,
43    item_workflow: Workflow<Context, A>,
44    parallelism: usize,
45    _phantom: PhantomData<(Context, ItemType, A)>,
46}
47
48impl<Context, ItemType, A> BatchNode<Context, ItemType, A>
49where
50    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
51    ItemType: Clone + Send + Sync + 'static,
52    A: ActionType + Clone + Send + Sync + 'static,
53{
54    pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
56        Self {
57            id: Uuid::new_v4().to_string(),
58            item_workflow,
59            parallelism,
60            _phantom: PhantomData,
61        }
62    }
63}
64
65impl<Context, ItemType, A> Debug for BatchNode<Context, ItemType, A>
66where
67    Context: BatchContext<ItemType> + Send + Sync + 'static,
68    ItemType: Clone + Send + Sync + 'static,
69    A: ActionType + Send + Sync + 'static,
70{
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("BatchNode")
73            .field("id", &self.id)
74            .field("parallelism", &self.parallelism)
75            .finish()
76    }
77}
78
79#[async_trait]
80impl<Context, ItemType, A> Node<Context, A> for BatchNode<Context, ItemType, A>
81where
82    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
83    ItemType: Clone + Send + Sync + 'static,
84    A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
85{
86    type Output = Vec<Result<ItemType, FloxideError>>;
87
88    fn id(&self) -> NodeId {
89        self.id.clone()
90    }
91
92    async fn process(
93        &self,
94        ctx: &mut Context,
95    ) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
96        debug!(node_id = %self.id, "Getting batch items to process");
98        let items = ctx.get_batch_items()?;
99        info!(node_id = %self.id, item_count = items.len(), "Processing batch items");
100
101        let mut results = Vec::with_capacity(items.len());
102
103        let semaphore = Arc::new(Semaphore::new(self.parallelism));
105        let mut handles = Vec::with_capacity(items.len());
106
107        for item in items {
109            let semaphore = semaphore.clone();
110            let workflow = self.item_workflow.clone();
111            let ctx_clone = ctx.clone();
112
113            let item_clone = item.clone();
115
116            let handle = tokio::spawn(async move {
118                let _permit = semaphore.acquire().await.unwrap();
120
121                match ctx_clone.create_item_context(item_clone) {
122                    Ok(mut item_ctx) => match workflow.execute(&mut item_ctx).await {
123                        Ok(_) => Ok(item),
124                        Err(e) => Err(FloxideError::batch_processing(
125                            "Failed to process item",
126                            Box::new(e),
127                        )),
128                    },
129                    Err(e) => Err(e),
130                }
131            });
132
133            handles.push(handle);
134        }
135
136        for handle in handles {
138            match handle.await {
139                Ok(result) => results.push(result),
140                Err(e) => results.push(Err(FloxideError::JoinError(e.to_string()))),
141            }
142        }
143
144        ctx.update_with_results(&results)?;
146
147        Ok(NodeOutcome::Success(results))
149    }
150}
151
152pub struct BatchFlow<Context, ItemType, A = crate::action::DefaultAction>
154where
155    Context: BatchContext<ItemType> + Send + Sync + 'static,
156    ItemType: Clone + Send + Sync + 'static,
157    A: ActionType + Clone + Send + Sync + 'static,
158{
159    id: NodeId,
160    batch_node: BatchNode<Context, ItemType, A>,
161}
162
163impl<Context, ItemType, A> BatchFlow<Context, ItemType, A>
164where
165    Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
166    ItemType: Clone + Send + Sync + 'static,
167    A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
168{
169    pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
171        Self {
172            id: Uuid::new_v4().to_string(),
173            batch_node: BatchNode::new(item_workflow, parallelism),
174        }
175    }
176
177    pub async fn execute(
179        &self,
180        ctx: &mut Context,
181    ) -> Result<Vec<Result<ItemType, FloxideError>>, WorkflowError> {
182        match self.batch_node.process(ctx).await {
183            Ok(NodeOutcome::Success(results)) => Ok(results),
184            Ok(_) => Err(WorkflowError::NodeExecution(
185                FloxideError::unexpected_outcome("Expected Success outcome from BatchNode"),
186            )),
187            Err(e) => Err(WorkflowError::NodeExecution(e)),
188        }
189    }
190}
191
192impl<Context, ItemType, A> Debug for BatchFlow<Context, ItemType, A>
193where
194    Context: BatchContext<ItemType> + Send + Sync + 'static,
195    ItemType: Clone + Send + Sync + 'static,
196    A: ActionType + Clone + Send + Sync + 'static,
197{
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("BatchFlow")
200            .field("id", &self.id)
201            .field("batch_node", &self.batch_node)
202            .finish()
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;
    use crate::node::closure::node;

    /// Minimal batch context: a list of integers to process and the results
    /// reported back by the batch node.
    #[derive(Debug, Clone)]
    struct TestBatchContext {
        items: Vec<i32>,
        results: Vec<Result<i32, FloxideError>>,
    }

    impl BatchContext<i32> for TestBatchContext {
        fn get_batch_items(&self) -> Result<Vec<i32>, FloxideError> {
            Ok(self.items.clone())
        }

        // The item context carries just the one item and no results.
        fn create_item_context(&self, item: i32) -> Result<Self, FloxideError> {
            Ok(TestBatchContext {
                items: vec![item],
                results: Vec::new(),
            })
        }

        fn update_with_results(
            &mut self,
            results: &[Result<i32, FloxideError>],
        ) -> Result<(), FloxideError> {
            self.results = results.to_vec();
            Ok(())
        }
    }

    /// Context holding the items 1 through 5 and no results yet.
    fn sample_context() -> TestBatchContext {
        TestBatchContext {
            items: vec![1, 2, 3, 4, 5],
            results: Vec::new(),
        }
    }

    /// Unwraps every per-item result, failing the test if any item errored.
    fn unwrap_all(results: Vec<Result<i32, FloxideError>>) -> Vec<i32> {
        results
            .into_iter()
            .map(|r| r.expect("every item should be processed successfully"))
            .collect()
    }

    #[tokio::test]
    async fn test_batch_node_processing() {
        // The workflow doubles the item inside its own item context only.
        let item_workflow = Workflow::new(node(|mut ctx: TestBatchContext| async move {
            let item = ctx.items[0] * 2;
            ctx.items = vec![item];
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        }));

        let batch_node = BatchNode::new(item_workflow, 4);
        let mut ctx = sample_context();

        let outcome = batch_node.process(&mut ctx).await.unwrap();
        let results = match outcome {
            NodeOutcome::Success(results) => results,
            _ => panic!("Expected Success outcome"),
        };

        // The node reports the original items, in input order.
        assert_eq!(unwrap_all(results), vec![1, 2, 3, 4, 5]);

        // The results must also have been handed back to the batch context.
        assert_eq!(ctx.results.len(), 5);
        assert!(ctx.results.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_batch_flow_execution() {
        let item_workflow = Workflow::new(node(|mut ctx: TestBatchContext| async move {
            let item = ctx.items[0] * 2;
            ctx.items = vec![item];
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        }));

        let batch_flow = BatchFlow::new(item_workflow, 4);
        let mut ctx = sample_context();

        let results = batch_flow.execute(&mut ctx).await.unwrap();

        assert_eq!(unwrap_all(results), vec![1, 2, 3, 4, 5]);
        assert_eq!(ctx.results.len(), 5);
        assert!(ctx.results.iter().all(Result::is_ok));
    }
}