1use std::fmt::Debug;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::Semaphore;
7use tracing::{debug, info};
8use uuid::Uuid;
9
10use crate::action::ActionType;
11use crate::error::FloxideError;
12use crate::node::{Node, NodeId, NodeOutcome};
13use crate::workflow::{Workflow, WorkflowError};
14
15pub trait BatchContext<T>
17where
18 T: Clone + Send + Sync + 'static,
19{
20 fn get_batch_items(&self) -> Result<Vec<T>, FloxideError>;
22
23 fn create_item_context(&self, item: T) -> Result<Self, FloxideError>
25 where
26 Self: Sized;
27
28 fn update_with_results(
30 &mut self,
31 results: &[Result<T, FloxideError>],
32 ) -> Result<(), FloxideError>;
33}
34
35pub struct BatchNode<Context, ItemType, A = crate::action::DefaultAction>
37where
38 Context: BatchContext<ItemType> + Send + Sync + 'static,
39 ItemType: Clone + Send + Sync + 'static,
40 A: ActionType + Clone + Send + Sync + 'static,
41{
42 id: NodeId,
43 item_workflow: Workflow<Context, A>,
44 parallelism: usize,
45 _phantom: PhantomData<(Context, ItemType, A)>,
46}
47
48impl<Context, ItemType, A> BatchNode<Context, ItemType, A>
49where
50 Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
51 ItemType: Clone + Send + Sync + 'static,
52 A: ActionType + Clone + Send + Sync + 'static,
53{
54 pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
56 Self {
57 id: Uuid::new_v4().to_string(),
58 item_workflow,
59 parallelism,
60 _phantom: PhantomData,
61 }
62 }
63}
64
65impl<Context, ItemType, A> Debug for BatchNode<Context, ItemType, A>
66where
67 Context: BatchContext<ItemType> + Send + Sync + 'static,
68 ItemType: Clone + Send + Sync + 'static,
69 A: ActionType + Send + Sync + 'static,
70{
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("BatchNode")
73 .field("id", &self.id)
74 .field("parallelism", &self.parallelism)
75 .finish()
76 }
77}
78
79#[async_trait]
80impl<Context, ItemType, A> Node<Context, A> for BatchNode<Context, ItemType, A>
81where
82 Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
83 ItemType: Clone + Send + Sync + 'static,
84 A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
85{
86 type Output = Vec<Result<ItemType, FloxideError>>;
87
88 fn id(&self) -> NodeId {
89 self.id.clone()
90 }
91
92 async fn process(
93 &self,
94 ctx: &mut Context,
95 ) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
96 debug!(node_id = %self.id, "Getting batch items to process");
98 let items = ctx.get_batch_items()?;
99 info!(node_id = %self.id, item_count = items.len(), "Processing batch items");
100
101 let mut results = Vec::with_capacity(items.len());
102
103 let semaphore = Arc::new(Semaphore::new(self.parallelism));
105 let mut handles = Vec::with_capacity(items.len());
106
107 for item in items {
109 let semaphore = semaphore.clone();
110 let workflow = self.item_workflow.clone();
111 let ctx_clone = ctx.clone();
112
113 let item_clone = item.clone();
115
116 let handle = tokio::spawn(async move {
118 let _permit = semaphore.acquire().await.unwrap();
120
121 match ctx_clone.create_item_context(item_clone) {
122 Ok(mut item_ctx) => match workflow.execute(&mut item_ctx).await {
123 Ok(_) => Ok(item),
124 Err(e) => Err(FloxideError::batch_processing(
125 "Failed to process item",
126 Box::new(e),
127 )),
128 },
129 Err(e) => Err(e),
130 }
131 });
132
133 handles.push(handle);
134 }
135
136 for handle in handles {
138 match handle.await {
139 Ok(result) => results.push(result),
140 Err(e) => results.push(Err(FloxideError::JoinError(e.to_string()))),
141 }
142 }
143
144 ctx.update_with_results(&results)?;
146
147 Ok(NodeOutcome::Success(results))
149 }
150}
151
152pub struct BatchFlow<Context, ItemType, A = crate::action::DefaultAction>
154where
155 Context: BatchContext<ItemType> + Send + Sync + 'static,
156 ItemType: Clone + Send + Sync + 'static,
157 A: ActionType + Clone + Send + Sync + 'static,
158{
159 id: NodeId,
160 batch_node: BatchNode<Context, ItemType, A>,
161}
162
163impl<Context, ItemType, A> BatchFlow<Context, ItemType, A>
164where
165 Context: BatchContext<ItemType> + Clone + Send + Sync + 'static,
166 ItemType: Clone + Send + Sync + 'static,
167 A: ActionType + Default + Debug + Clone + Send + Sync + 'static,
168{
169 pub fn new(item_workflow: Workflow<Context, A>, parallelism: usize) -> Self {
171 Self {
172 id: Uuid::new_v4().to_string(),
173 batch_node: BatchNode::new(item_workflow, parallelism),
174 }
175 }
176
177 pub async fn execute(
179 &self,
180 ctx: &mut Context,
181 ) -> Result<Vec<Result<ItemType, FloxideError>>, WorkflowError> {
182 match self.batch_node.process(ctx).await {
183 Ok(NodeOutcome::Success(results)) => Ok(results),
184 Ok(_) => Err(WorkflowError::NodeExecution(
185 FloxideError::unexpected_outcome("Expected Success outcome from BatchNode"),
186 )),
187 Err(e) => Err(WorkflowError::NodeExecution(e)),
188 }
189 }
190}
191
192impl<Context, ItemType, A> Debug for BatchFlow<Context, ItemType, A>
193where
194 Context: BatchContext<ItemType> + Send + Sync + 'static,
195 ItemType: Clone + Send + Sync + 'static,
196 A: ActionType + Clone + Send + Sync + 'static,
197{
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("BatchFlow")
200 .field("id", &self.id)
201 .field("batch_node", &self.batch_node)
202 .finish()
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::DefaultAction;
    use crate::node::closure::node;

    /// Minimal batch context: `items` holds the whole batch (or the single item
    /// of a per-item context) and `results` records what `update_with_results`
    /// received.
    #[derive(Debug, Clone)]
    struct TestBatchContext {
        items: Vec<i32>,
        results: Vec<Result<i32, FloxideError>>,
    }

    impl TestBatchContext {
        fn with_items(items: Vec<i32>) -> Self {
            Self {
                items,
                results: Vec::new(),
            }
        }
    }

    impl BatchContext<i32> for TestBatchContext {
        fn get_batch_items(&self) -> Result<Vec<i32>, FloxideError> {
            Ok(self.items.clone())
        }

        fn create_item_context(&self, item: i32) -> Result<Self, FloxideError> {
            Ok(TestBatchContext::with_items(vec![item]))
        }

        fn update_with_results(
            &mut self,
            results: &[Result<i32, FloxideError>],
        ) -> Result<(), FloxideError> {
            self.results = results.to_vec();
            Ok(())
        }
    }

    /// Item workflow that doubles the single item stored in its context.
    fn doubling_workflow() -> Workflow<TestBatchContext, DefaultAction> {
        Workflow::new(node(|mut ctx: TestBatchContext| async move {
            let item = ctx.items[0] * 2;
            ctx.items = vec![item];
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        }))
    }

    /// Unwraps every per-item result, panicking on the first failure.
    fn ok_values(results: &[Result<i32, FloxideError>]) -> Vec<i32> {
        results
            .iter()
            .map(|r| *r.as_ref().expect("item should succeed"))
            .collect()
    }

    #[tokio::test]
    async fn test_batch_node_processing() {
        let batch_node = BatchNode::new(doubling_workflow(), 4);
        let mut ctx = TestBatchContext::with_items(vec![1, 2, 3, 4, 5]);

        let result = batch_node.process(&mut ctx).await.unwrap();

        match result {
            NodeOutcome::Success(results) => {
                // The node reports the original items, in input order.
                assert_eq!(ok_values(&results), vec![1, 2, 3, 4, 5]);
            }
            _ => panic!("Expected Success outcome"),
        }
        // The parent context was handed the same per-item results.
        assert_eq!(ok_values(&ctx.results), vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn test_batch_node_empty_batch() {
        let batch_node: BatchNode<TestBatchContext, i32> =
            BatchNode::new(doubling_workflow(), 4);
        let mut ctx = TestBatchContext::with_items(Vec::new());

        match batch_node.process(&mut ctx).await.unwrap() {
            NodeOutcome::Success(results) => assert!(results.is_empty()),
            _ => panic!("Expected Success outcome"),
        }
        assert!(ctx.results.is_empty());
    }

    #[tokio::test]
    async fn test_batch_flow_execution() {
        let batch_flow = BatchFlow::new(doubling_workflow(), 4);
        let mut ctx = TestBatchContext::with_items(vec![1, 2, 3, 4, 5]);

        let results = batch_flow.execute(&mut ctx).await.unwrap();

        assert_eq!(ok_values(&results), vec![1, 2, 3, 4, 5]);
        assert_eq!(ctx.results.len(), 5);
    }
}