// floxide_core/batch.rs

1use crate::context::Context;
2use crate::error::FloxideError;
3use crate::node::Node;
4use crate::transition::Transition;
5use async_trait::async_trait;
6use futures::future::join_all;
7use std::marker::PhantomData;
8use std::vec::Vec;
9use tokio::task;
10use tracing;
11
12/// A node adapter that runs an inner node on a batch of inputs, collecting outputs in parallel
13#[derive(Clone, Debug)]
14pub struct BatchNode<C: Context, N: Node<C>> {
15    pub node: N,
16    pub batch_size: usize,
17    _phantom: PhantomData<C>,
18}
19
20impl<C: Context, N: Node<C>> BatchNode<C, N> {
21    /// Wraps an existing node into a batch adapter with a batch size
22    pub fn new(node: N, batch_size: usize) -> Self {
23        BatchNode {
24            node,
25            batch_size,
26            _phantom: PhantomData,
27        }
28    }
29}
30
31impl<C: Context, N: Node<C>> BatchNode<C, N> {
32    /// Process a batch of inputs, where the associated Input/Output types are Vecs
33    pub async fn process_batch(
34        &self,
35        // take ownership of the context so it can be cloned into blocking tasks
36        ctx: C,
37        inputs: <Self as Node<C>>::Input,
38    ) -> Result<<Self as Node<C>>::Output, FloxideError>
39    where
40        C: Context + 'static,
41        N: Node<C> + Clone + Send + Sync + 'static,
42        <N as Node<C>>::Input: Clone + Send + 'static,
43        <N as Node<C>>::Output: Send + 'static,
44    {
45        use tracing::{debug, error};
46        debug!(
47            batch_size = self.batch_size,
48            num_inputs = inputs.len(),
49            "Starting batch processing"
50        );
51        let mut outputs = Vec::new();
52        let node = self.node.clone();
53        let ctx_clone = ctx.clone();
54        let mut tasks = Vec::new();
55
56        for input in inputs.into_iter() {
57            let node = node.clone();
58            // clone the context reference (does not require C: Clone)
59            let ctx = ctx_clone.clone();
60            let task = task::spawn_blocking(move || {
61                tokio::runtime::Handle::current()
62                    .block_on(async move { node.process(&ctx, input).await })
63            });
64            tasks.push(task);
65
66            if tasks.len() >= self.batch_size {
67                debug!(current_batch = tasks.len(), "Processing batch");
68                let results = join_all(tasks).await;
69                tasks = Vec::new();
70                for res in results {
71                    match res {
72                        Ok(Ok(Transition::Next(o))) => outputs.push(o),
73                        Ok(Ok(Transition::NextAll(os))) => outputs.extend(os),
74                        Ok(Ok(Transition::Hold)) => {}
75                        Ok(Ok(Transition::Abort(e))) => {
76                            error!(?e, "Node aborted during batch");
77                            return Err(e);
78                        }
79                        Ok(Err(e)) => {
80                            error!(?e, "Node errored during batch");
81                            return Err(e);
82                        }
83                        Err(e) => {
84                            error!(?e, "Join error during batch");
85                            return Err(FloxideError::Generic(format!("Join error: {e}")));
86                        }
87                    }
88                }
89            }
90        }
91
92        if !tasks.is_empty() {
93            debug!(final_batch = tasks.len(), "Processing final batch");
94            let results = join_all(tasks).await;
95            for res in results {
96                match res {
97                    Ok(Ok(Transition::Next(o))) => outputs.push(o),
98                    Ok(Ok(Transition::NextAll(os))) => outputs.extend(os),
99                    Ok(Ok(Transition::Hold)) => {}
100                    Ok(Ok(Transition::Abort(e))) => {
101                        error!(?e, "Node aborted during final batch");
102                        return Err(e);
103                    }
104                    Ok(Err(e)) => {
105                        error!(?e, "Node errored during final batch");
106                        return Err(e);
107                    }
108                    Err(e) => {
109                        error!(?e, "Join error during final batch");
110                        return Err(FloxideError::Generic(format!("Join error: {e}")));
111                    }
112                }
113            }
114        }
115
116        debug!(num_outputs = outputs.len(), "Batch processing complete");
117        Ok(outputs)
118    }
119}
120
121#[async_trait]
122impl<C, N> Node<C> for BatchNode<C, N>
123where
124    C: Context + 'static,
125    N: Node<C> + Clone + Send + Sync + 'static,
126    <N as Node<C>>::Input: Clone + Send + 'static,
127    <N as Node<C>>::Output: Send + 'static,
128{
129    type Input = Vec<<N as Node<C>>::Input>;
130    type Output = Vec<<N as Node<C>>::Output>;
131
132    async fn process(
133        &self,
134        ctx: &C,
135        inputs: Self::Input,
136    ) -> Result<Transition<Self::Output>, FloxideError> {
137        let outputs = self.process_batch((*ctx).clone(), inputs).await?;
138        Ok(Transition::Next(outputs))
139    }
140}