recoco_utils/
batching.rs

1// ReCoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
2// Original code from CocoIndex is copyrighted by CocoIndex
3// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
4// SPDX-FileContributor: CocoIndex Contributors
5//
6// All modifications from the upstream for ReCoco are copyrighted by Knitli Inc.
7// SPDX-FileCopyrightText: 2026 Knitli Inc. (ReCoco)
8// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
9//
10// Both the upstream CocoIndex code and the ReCoco modifications are licensed under the Apache-2.0 License.
11// SPDX-License-Identifier: Apache-2.0
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use std::sync::{Arc, Mutex};
16use tokio::sync::{oneshot, watch};
17use tokio_util::task::AbortOnDropHandle;
18use tracing::error;
19
20use crate::{
21    error::{Error, ResidualError, Result},
22    internal_bail,
23};
/// Executes a group of inputs in a single call on behalf of a [`Batcher`].
///
/// The `i`-th output yielded by [`Runner::run`] is delivered to the caller that submitted the
/// `i`-th input, so implementations must preserve input order and yield exactly one output per
/// input. A length mismatch is reported to every caller of the batch as an internal error.
#[async_trait]
pub trait Runner: Send + Sync {
    /// Value submitted by each caller of [`Batcher::run`].
    type Input: Send;
    /// Value handed back to each caller of [`Batcher::run`].
    type Output: Send;

    /// Processes all `inputs` together.
    ///
    /// # Errors
    /// An error returned here is delivered to every caller in the batch: one caller receives the
    /// original error and the rest receive a `ResidualError` derived from it.
    async fn run(
        &self,
        inputs: Vec<Self::Input>,
    ) -> Result<impl ExactSizeIterator<Item = Self::Output>>;
}
34
35struct Batch<I, O> {
36    inputs: Vec<I>,
37    output_txs: Vec<oneshot::Sender<Result<O>>>,
38    num_cancelled_tx: watch::Sender<usize>,
39    num_cancelled_rx: watch::Receiver<usize>,
40}
41
42impl<I, O> Default for Batch<I, O> {
43    fn default() -> Self {
44        let (num_cancelled_tx, num_cancelled_rx) = watch::channel(0);
45        Self {
46            inputs: Vec::new(),
47            output_txs: Vec::new(),
48            num_cancelled_tx,
49            num_cancelled_rx,
50        }
51    }
52}
53
54#[derive(Default)]
55enum BatcherState<I, O> {
56    #[default]
57    Idle,
58    Busy {
59        pending_batch: Option<Batch<I, O>>,
60        ongoing_count: usize,
61    },
62}
63
64struct BatcherData<R: Runner + 'static> {
65    runner: R,
66    state: Mutex<BatcherState<R::Input, R::Output>>,
67}
68
69impl<R: Runner + 'static> BatcherData<R> {
70    async fn run_batch(self: &Arc<Self>, batch: Batch<R::Input, R::Output>) {
71        let _kick_off_next = BatchKickOffNext { batcher_data: self };
72        let num_inputs = batch.inputs.len();
73
74        let mut num_cancelled_rx = batch.num_cancelled_rx;
75        let outputs = tokio::select! {
76            outputs = self.runner.run(batch.inputs) => {
77                outputs
78            }
79            _ = num_cancelled_rx.wait_for(|v| *v == num_inputs) => {
80                return;
81            }
82        };
83
84        match outputs {
85            Ok(outputs) => {
86                if outputs.len() != batch.output_txs.len() {
87                    let message = format!(
88                        "Batched output length mismatch: expected {} outputs, got {}",
89                        batch.output_txs.len(),
90                        outputs.len()
91                    );
92                    error!("{message}");
93                    for sender in batch.output_txs {
94                        sender.send(Err(Error::internal_msg(&message))).ok();
95                    }
96                    return;
97                }
98                for (output, sender) in outputs.zip(batch.output_txs) {
99                    sender.send(Ok(output)).ok();
100                }
101            }
102            Err(err) => {
103                let mut senders_iter = batch.output_txs.into_iter();
104                if let Some(sender) = senders_iter.next() {
105                    if senders_iter.len() > 0 {
106                        let residual_err = ResidualError::new(&err);
107                        for sender in senders_iter {
108                            sender.send(Err(residual_err.clone().into())).ok();
109                        }
110                    }
111                    sender.send(Err(err)).ok();
112                }
113            }
114        }
115    }
116}
117
118pub struct Batcher<R: Runner + 'static> {
119    data: Arc<BatcherData<R>>,
120    options: BatchingOptions,
121}
122
123enum BatchExecutionAction<R: Runner + 'static> {
124    Inline {
125        input: R::Input,
126    },
127    Batched {
128        output_rx: oneshot::Receiver<Result<R::Output>>,
129        num_cancelled_tx: watch::Sender<usize>,
130    },
131}
132
133#[derive(Default, Clone, Serialize, Deserialize)]
134pub struct BatchingOptions {
135    pub max_batch_size: Option<usize>,
136}
137impl<R: Runner + 'static> Batcher<R> {
138    pub fn new(runner: R, options: BatchingOptions) -> Self {
139        Self {
140            data: Arc::new(BatcherData {
141                runner,
142                state: Mutex::new(BatcherState::Idle),
143            }),
144            options,
145        }
146    }
147    pub async fn run(&self, input: R::Input) -> Result<R::Output> {
148        let batch_exec_action: BatchExecutionAction<R> = {
149            let mut state = self.data.state.lock().unwrap();
150            match &mut *state {
151                state @ BatcherState::Idle => {
152                    *state = BatcherState::Busy {
153                        pending_batch: None,
154                        ongoing_count: 1,
155                    };
156                    BatchExecutionAction::Inline { input }
157                }
158                BatcherState::Busy {
159                    pending_batch,
160                    ongoing_count,
161                } => {
162                    let batch = pending_batch.get_or_insert_default();
163                    batch.inputs.push(input);
164
165                    let (output_tx, output_rx) = oneshot::channel();
166                    batch.output_txs.push(output_tx);
167
168                    let num_cancelled_tx = batch.num_cancelled_tx.clone();
169
170                    // Check if we've reached max_batch_size and need to flush immediately
171                    let should_flush = self
172                        .options
173                        .max_batch_size
174                        .map(|max_size| batch.inputs.len() >= max_size)
175                        .unwrap_or(false);
176
177                    if should_flush {
178                        // Take the batch and trigger execution
179                        let batch_to_run = pending_batch.take().unwrap();
180                        *ongoing_count += 1;
181                        let data = self.data.clone();
182                        tokio::spawn(async move { data.run_batch(batch_to_run).await });
183                    }
184
185                    BatchExecutionAction::Batched {
186                        output_rx,
187                        num_cancelled_tx,
188                    }
189                }
190            }
191        };
192        match batch_exec_action {
193            BatchExecutionAction::Inline { input } => {
194                let _kick_off_next = BatchKickOffNext {
195                    batcher_data: &self.data,
196                };
197
198                let data = self.data.clone();
199                let handle = AbortOnDropHandle::new(tokio::spawn(async move {
200                    let mut outputs = data.runner.run(vec![input]).await?;
201                    if outputs.len() != 1 {
202                        internal_bail!("Expected 1 output, got {}", outputs.len());
203                    }
204                    Ok(outputs.next().unwrap())
205                }));
206                Ok(handle.await??)
207            }
208            BatchExecutionAction::Batched {
209                output_rx,
210                num_cancelled_tx,
211            } => {
212                let mut guard = BatchRecvCancellationGuard::new(Some(num_cancelled_tx));
213                let output = output_rx.await?;
214                guard.done();
215                output
216            }
217        }
218    }
219}
220
221struct BatchKickOffNext<'a, R: Runner + 'static> {
222    batcher_data: &'a Arc<BatcherData<R>>,
223}
224
225impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> {
226    fn drop(&mut self) {
227        let mut state = self.batcher_data.state.lock().unwrap();
228
229        match &mut *state {
230            BatcherState::Idle => {
231                // Nothing to do, already idle
232            }
233            BatcherState::Busy {
234                pending_batch,
235                ongoing_count,
236            } => {
237                // Decrement the ongoing count first
238                *ongoing_count -= 1;
239
240                if *ongoing_count == 0 {
241                    // All batches done, check if there's a pending batch
242                    if let Some(batch) = pending_batch.take() {
243                        // Kick off the pending batch and set ongoing_count to 1
244                        *ongoing_count = 1;
245                        let data = self.batcher_data.clone();
246                        tokio::spawn(async move { data.run_batch(batch).await });
247                    } else {
248                        // No pending batch, transition to Idle
249                        *state = BatcherState::Idle;
250                    }
251                }
252            }
253        }
254    }
255}
256
257struct BatchRecvCancellationGuard {
258    num_cancelled_tx: Option<watch::Sender<usize>>,
259}
260
261impl Drop for BatchRecvCancellationGuard {
262    fn drop(&mut self) {
263        if let Some(num_cancelled_tx) = self.num_cancelled_tx.take() {
264            num_cancelled_tx.send_modify(|v| *v += 1);
265        }
266    }
267}
268
269impl BatchRecvCancellationGuard {
270    pub fn new(num_cancelled_tx: Option<watch::Sender<usize>>) -> Self {
271        Self { num_cancelled_tx }
272    }
273
274    pub fn done(&mut self) {
275        self.num_cancelled_tx = None;
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::sync::oneshot;
    use tokio::time::{Duration, sleep};

    /// Test runner that records each invocation and blocks until every input's signal fires.
    struct TestRunner {
        // One entry per call, in call order; each entry holds that call's input values, sorted
        recorded_calls: Arc<Mutex<Vec<Vec<i64>>>>,
    }

    #[async_trait]
    impl Runner for TestRunner {
        type Input = (i64, oneshot::Receiver<()>);
        type Output = i64;

        async fn run(
            &self,
            inputs: Vec<Self::Input>,
        ) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
            // Split into values and release signals (send-before-wait safe: oneshot buffers)
            let (vals, rxs): (Vec<i64>, Vec<oneshot::Receiver<()>>) = inputs.into_iter().unzip();

            // Record the values for this invocation, sorted so assertions are order-agnostic
            let mut recorded = vals.clone();
            recorded.sort();
            self.recorded_calls.lock().unwrap().push(recorded);

            // Block until every input's signal is fired (or its sender is dropped)
            for rx in rxs {
                let _ = rx.await;
            }

            // Return outputs mapping v -> v * 2, in input order
            let outputs: Vec<i64> = vals.into_iter().map(|v| v * 2).collect();
            Ok(outputs.into_iter())
        }
    }

    /// Polls `recorded` until it holds exactly `expected_len` calls, panicking after ~2s.
    async fn wait_until_len(recorded: &Arc<Mutex<Vec<Vec<i64>>>>, expected_len: usize) {
        for _ in 0..200 {
            // up to ~2s
            if recorded.lock().unwrap().len() == expected_len {
                return;
            }
            sleep(Duration::from_millis(10)).await;
        }
        panic!("timed out waiting for recorded_calls length {expected_len}");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn batches_after_first_inline_call() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(runner, BatchingOptions::default()));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();

        // Submit first call; it should execute inline and block on n1
        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });

        // Wait until the runner has recorded the first inline call
        wait_until_len(&recorded_calls, 1).await;

        // Submit the next two calls; they should be batched together and not run yet
        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });

        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });

        // Ensure no new batch has started yet
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 1,
                "second invocation should not have started before unblocking first"
            );
        }

        // Unblock the first call; this should trigger the next batch of [2,3]
        let _ = n1_tx.send(());

        // Wait for the batch call to be recorded
        wait_until_len(&recorded_calls, 2).await;

        // First result should now be available
        let v1 = f1.await??;
        assert_eq!(v1, 2);

        // The batched call is waiting on n2 and n3; now unblock both and collect results
        let _ = n2_tx.send(());
        let _ = n3_tx.send(());

        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        // Validate the call recording: first [1], then [2, 3]
        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);

        Ok(())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn respects_max_batch_size() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(
            runner,
            BatchingOptions {
                max_batch_size: Some(2),
            },
        ));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();
        let (n4_tx, n4_rx) = oneshot::channel::<()>();

        // Submit first call; it should execute inline and block on n1
        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });

        // Wait until the runner has recorded the first inline call
        wait_until_len(&recorded_calls, 1).await;

        // Submit second call; it should be batched
        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });

        // Submit third call; this should trigger a flush because max_batch_size=2
        // The batch [2, 3] should be executed immediately
        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });

        // Wait for the second batch to be recorded
        wait_until_len(&recorded_calls, 2).await;

        // Verify that the second batch was triggered by max_batch_size
        {
            let calls = recorded_calls.lock().unwrap();
            assert_eq!(calls.len(), 2, "second batch should have started");
            assert_eq!(calls[1], vec![2, 3], "second batch should contain [2, 3]");
        }

        // Submit fourth call; it should wait because there are still ongoing batches
        let b4 = batcher.clone();
        let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });

        // Give it a moment to ensure no new batch starts
        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 2,
                "third batch should not start until all ongoing batches complete"
            );
        }

        // Unblock the first inline call
        let _ = n1_tx.send(());

        // Wait for first result
        let v1 = f1.await??;
        assert_eq!(v1, 2);

        // Batch [2,3] is still running, so batch [4] shouldn't start yet
        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 2,
                "third batch should not start until all ongoing batches complete"
            );
        }

        // Unblock batch [2,3] - this should trigger batch [4] to start
        let _ = n2_tx.send(());
        let _ = n3_tx.send(());

        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        // Now batch [4] should start since all previous batches are done
        wait_until_len(&recorded_calls, 3).await;

        // Unblock batch [4]
        let _ = n4_tx.send(());
        let v4 = f4.await??;
        assert_eq!(v4, 8);

        // Validate the call recording: [1], [2, 3] (flushed by max_batch_size), [4]
        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 3);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);
        assert_eq!(calls[2], vec![4]);

        Ok(())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tracks_multiple_concurrent_batches() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(
            runner,
            BatchingOptions {
                max_batch_size: Some(2),
            },
        ));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();
        let (n4_tx, n4_rx) = oneshot::channel::<()>();
        let (n5_tx, n5_rx) = oneshot::channel::<()>();
        let (n6_tx, n6_rx) = oneshot::channel::<()>();

        // Submit first call - executes inline
        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });
        wait_until_len(&recorded_calls, 1).await;

        // Submit calls 2-3 - should batch and flush at max_batch_size
        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });
        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });
        wait_until_len(&recorded_calls, 2).await;

        // Submit calls 4-5 - should batch and flush at max_batch_size
        let b4 = batcher.clone();
        let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });
        let b5 = batcher.clone();
        let f5 = tokio::spawn(async move { b5.run((5_i64, n5_rx)).await });
        wait_until_len(&recorded_calls, 3).await;

        // Submit call 6 - should be batched but not flushed yet
        let b6 = batcher.clone();
        let f6 = tokio::spawn(async move { b6.run((6_i64, n6_rx)).await });

        // Give it a moment to ensure no new batch starts
        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "fourth batch should not start with ongoing batches"
            );
        }

        // Unblock batch [2, 3] - should not cause [6] to execute yet (batch 1 still ongoing)
        let _ = n2_tx.send(());
        let _ = n3_tx.send(());
        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "batch [6] should still not start (batch 1 and batch [4,5] still ongoing)"
            );
        }

        // Unblock batch [4, 5] - should not cause [6] to execute yet (batch 1 still ongoing)
        let _ = n4_tx.send(());
        let _ = n5_tx.send(());
        let v4 = f4.await??;
        let v5 = f5.await??;
        assert_eq!(v4, 8);
        assert_eq!(v5, 10);

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "batch [6] should still not start (batch 1 still ongoing)"
            );
        }

        // Unblock batch 1 - NOW batch [6] should start
        let _ = n1_tx.send(());
        let v1 = f1.await??;
        assert_eq!(v1, 2);

        wait_until_len(&recorded_calls, 4).await;

        // Unblock batch [6]
        let _ = n6_tx.send(());
        let v6 = f6.await??;
        assert_eq!(v6, 12);

        // Validate the call recording
        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 4);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);
        assert_eq!(calls[2], vec![4, 5]);
        assert_eq!(calls[3], vec![6]);

        Ok(())
    }
}