1use crate::context::Context;
2use crate::error::FloxideError;
3use crate::node::Node;
4use crate::transition::Transition;
5use async_trait::async_trait;
6use futures::future::join_all;
7use std::marker::PhantomData;
8use std::vec::Vec;
9use tokio::task;
10use tracing;
11
12#[derive(Clone, Debug)]
14pub struct BatchNode<C: Context, N: Node<C>> {
15 pub node: N,
16 pub batch_size: usize,
17 _phantom: PhantomData<C>,
18}
19
20impl<C: Context, N: Node<C>> BatchNode<C, N> {
21 pub fn new(node: N, batch_size: usize) -> Self {
23 BatchNode {
24 node,
25 batch_size,
26 _phantom: PhantomData,
27 }
28 }
29}
30
31impl<C: Context, N: Node<C>> BatchNode<C, N> {
32 pub async fn process_batch(
34 &self,
35 ctx: C,
37 inputs: <Self as Node<C>>::Input,
38 ) -> Result<<Self as Node<C>>::Output, FloxideError>
39 where
40 C: Context + 'static,
41 N: Node<C> + Clone + Send + Sync + 'static,
42 <N as Node<C>>::Input: Clone + Send + 'static,
43 <N as Node<C>>::Output: Send + 'static,
44 {
45 use tracing::{debug, error};
46 debug!(
47 batch_size = self.batch_size,
48 num_inputs = inputs.len(),
49 "Starting batch processing"
50 );
51 let mut outputs = Vec::new();
52 let node = self.node.clone();
53 let ctx_clone = ctx.clone();
54 let mut tasks = Vec::new();
55
56 for input in inputs.into_iter() {
57 let node = node.clone();
58 let ctx = ctx_clone.clone();
60 let task = task::spawn_blocking(move || {
61 tokio::runtime::Handle::current()
62 .block_on(async move { node.process(&ctx, input).await })
63 });
64 tasks.push(task);
65
66 if tasks.len() >= self.batch_size {
67 debug!(current_batch = tasks.len(), "Processing batch");
68 let results = join_all(tasks).await;
69 tasks = Vec::new();
70 for res in results {
71 match res {
72 Ok(Ok(Transition::Next(o))) => outputs.push(o),
73 Ok(Ok(Transition::NextAll(os))) => outputs.extend(os),
74 Ok(Ok(Transition::Hold)) => {}
75 Ok(Ok(Transition::Abort(e))) => {
76 error!(?e, "Node aborted during batch");
77 return Err(e);
78 }
79 Ok(Err(e)) => {
80 error!(?e, "Node errored during batch");
81 return Err(e);
82 }
83 Err(e) => {
84 error!(?e, "Join error during batch");
85 return Err(FloxideError::Generic(format!("Join error: {e}")));
86 }
87 }
88 }
89 }
90 }
91
92 if !tasks.is_empty() {
93 debug!(final_batch = tasks.len(), "Processing final batch");
94 let results = join_all(tasks).await;
95 for res in results {
96 match res {
97 Ok(Ok(Transition::Next(o))) => outputs.push(o),
98 Ok(Ok(Transition::NextAll(os))) => outputs.extend(os),
99 Ok(Ok(Transition::Hold)) => {}
100 Ok(Ok(Transition::Abort(e))) => {
101 error!(?e, "Node aborted during final batch");
102 return Err(e);
103 }
104 Ok(Err(e)) => {
105 error!(?e, "Node errored during final batch");
106 return Err(e);
107 }
108 Err(e) => {
109 error!(?e, "Join error during final batch");
110 return Err(FloxideError::Generic(format!("Join error: {e}")));
111 }
112 }
113 }
114 }
115
116 debug!(num_outputs = outputs.len(), "Batch processing complete");
117 Ok(outputs)
118 }
119}
120
121#[async_trait]
122impl<C, N> Node<C> for BatchNode<C, N>
123where
124 C: Context + 'static,
125 N: Node<C> + Clone + Send + Sync + 'static,
126 <N as Node<C>>::Input: Clone + Send + 'static,
127 <N as Node<C>>::Output: Send + 'static,
128{
129 type Input = Vec<<N as Node<C>>::Input>;
130 type Output = Vec<<N as Node<C>>::Output>;
131
132 async fn process(
133 &self,
134 ctx: &C,
135 inputs: Self::Input,
136 ) -> Result<Transition<Self::Output>, FloxideError> {
137 let outputs = self.process_batch((*ctx).clone(), inputs).await?;
138 Ok(Transition::Next(outputs))
139 }
140}