1use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use std::sync::{Arc, Mutex};
16use tokio::sync::{oneshot, watch};
17use tokio_util::task::AbortOnDropHandle;
18use tracing::error;
19
20use crate::{
21 error::{Error, ResidualError, Result},
22 internal_bail,
23};
/// Executes batches of inputs on behalf of a [`Batcher`].
#[async_trait]
pub trait Runner: Send + Sync {
    /// One unit of work submitted through [`Batcher::run`].
    type Input: Send;
    /// The result produced for one input.
    type Output: Send;

    /// Processes `inputs` as a single batch.
    ///
    /// Must yield exactly one output per input, in the same order as `inputs`: the batcher
    /// checks the iterator's length against the number of waiting callers and pairs outputs
    /// with callers positionally. Returning `Err` makes every caller in the batch receive an
    /// error.
    async fn run(
        &self,
        inputs: Vec<Self::Input>,
    ) -> Result<impl ExactSizeIterator<Item = Self::Output>>;
}
34
35struct Batch<I, O> {
36 inputs: Vec<I>,
37 output_txs: Vec<oneshot::Sender<Result<O>>>,
38 num_cancelled_tx: watch::Sender<usize>,
39 num_cancelled_rx: watch::Receiver<usize>,
40}
41
42impl<I, O> Default for Batch<I, O> {
43 fn default() -> Self {
44 let (num_cancelled_tx, num_cancelled_rx) = watch::channel(0);
45 Self {
46 inputs: Vec::new(),
47 output_txs: Vec::new(),
48 num_cancelled_tx,
49 num_cancelled_rx,
50 }
51 }
52}
53
/// Scheduling state of a [`Batcher`].
#[derive(Default)]
enum BatcherState<I, O> {
    /// Nothing is in flight; the next call runs inline immediately.
    #[default]
    Idle,
    /// At least one run (inline call or spawned batch) is in flight.
    Busy {
        /// Inputs queued behind the in-flight runs, if any have arrived.
        pending_batch: Option<Batch<I, O>>,
        /// Number of in-flight runs; the pending batch starts when this drops to zero.
        ongoing_count: usize,
    },
}

/// State shared (through an `Arc`) between a [`Batcher`] and the tasks running its batches.
struct BatcherData<R: Runner + 'static> {
    runner: R,
    state: Mutex<BatcherState<R::Input, R::Output>>,
}
68
69impl<R: Runner + 'static> BatcherData<R> {
70 async fn run_batch(self: &Arc<Self>, batch: Batch<R::Input, R::Output>) {
71 let _kick_off_next = BatchKickOffNext { batcher_data: self };
72 let num_inputs = batch.inputs.len();
73
74 let mut num_cancelled_rx = batch.num_cancelled_rx;
75 let outputs = tokio::select! {
76 outputs = self.runner.run(batch.inputs) => {
77 outputs
78 }
79 _ = num_cancelled_rx.wait_for(|v| *v == num_inputs) => {
80 return;
81 }
82 };
83
84 match outputs {
85 Ok(outputs) => {
86 if outputs.len() != batch.output_txs.len() {
87 let message = format!(
88 "Batched output length mismatch: expected {} outputs, got {}",
89 batch.output_txs.len(),
90 outputs.len()
91 );
92 error!("{message}");
93 for sender in batch.output_txs {
94 sender.send(Err(Error::internal_msg(&message))).ok();
95 }
96 return;
97 }
98 for (output, sender) in outputs.zip(batch.output_txs) {
99 sender.send(Ok(output)).ok();
100 }
101 }
102 Err(err) => {
103 let mut senders_iter = batch.output_txs.into_iter();
104 if let Some(sender) = senders_iter.next() {
105 if senders_iter.len() > 0 {
106 let residual_err = ResidualError::new(&err);
107 for sender in senders_iter {
108 sender.send(Err(residual_err.clone().into())).ok();
109 }
110 }
111 sender.send(Err(err)).ok();
112 }
113 }
114 }
115 }
116}
117
/// Groups concurrent [`Batcher::run`] calls into batches executed by a [`Runner`].
///
/// While a run is in flight, further calls are queued into a single pending batch. That batch
/// starts once all in-flight runs finish, or earlier when `max_batch_size` is reached.
pub struct Batcher<R: Runner + 'static> {
    data: Arc<BatcherData<R>>,
    options: BatchingOptions,
}
122
/// How `Batcher::run` should proceed with an input after updating the shared state.
enum BatchExecutionAction<R: Runner + 'static> {
    /// The batcher was idle: run this input by itself right away.
    Inline {
        input: R::Input,
    },
    /// The input was added to a batch: wait for its result.
    Batched {
        /// Receives this input's result once its batch has run.
        output_rx: oneshot::Receiver<Result<R::Output>>,
        /// Used to count this caller as cancelled if it stops waiting.
        num_cancelled_tx: watch::Sender<usize>,
    },
}
132
/// Tuning options for a [`Batcher`].
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct BatchingOptions {
    /// When set, a pending batch is started as soon as it holds this many inputs, even while
    /// other runs are still in flight. `None` means batches are not size-limited.
    pub max_batch_size: Option<usize>,
}
137impl<R: Runner + 'static> Batcher<R> {
138 pub fn new(runner: R, options: BatchingOptions) -> Self {
139 Self {
140 data: Arc::new(BatcherData {
141 runner,
142 state: Mutex::new(BatcherState::Idle),
143 }),
144 options,
145 }
146 }
147 pub async fn run(&self, input: R::Input) -> Result<R::Output> {
148 let batch_exec_action: BatchExecutionAction<R> = {
149 let mut state = self.data.state.lock().unwrap();
150 match &mut *state {
151 state @ BatcherState::Idle => {
152 *state = BatcherState::Busy {
153 pending_batch: None,
154 ongoing_count: 1,
155 };
156 BatchExecutionAction::Inline { input }
157 }
158 BatcherState::Busy {
159 pending_batch,
160 ongoing_count,
161 } => {
162 let batch = pending_batch.get_or_insert_default();
163 batch.inputs.push(input);
164
165 let (output_tx, output_rx) = oneshot::channel();
166 batch.output_txs.push(output_tx);
167
168 let num_cancelled_tx = batch.num_cancelled_tx.clone();
169
170 let should_flush = self
172 .options
173 .max_batch_size
174 .map(|max_size| batch.inputs.len() >= max_size)
175 .unwrap_or(false);
176
177 if should_flush {
178 let batch_to_run = pending_batch.take().unwrap();
180 *ongoing_count += 1;
181 let data = self.data.clone();
182 tokio::spawn(async move { data.run_batch(batch_to_run).await });
183 }
184
185 BatchExecutionAction::Batched {
186 output_rx,
187 num_cancelled_tx,
188 }
189 }
190 }
191 };
192 match batch_exec_action {
193 BatchExecutionAction::Inline { input } => {
194 let _kick_off_next = BatchKickOffNext {
195 batcher_data: &self.data,
196 };
197
198 let data = self.data.clone();
199 let handle = AbortOnDropHandle::new(tokio::spawn(async move {
200 let mut outputs = data.runner.run(vec![input]).await?;
201 if outputs.len() != 1 {
202 internal_bail!("Expected 1 output, got {}", outputs.len());
203 }
204 Ok(outputs.next().unwrap())
205 }));
206 Ok(handle.await??)
207 }
208 BatchExecutionAction::Batched {
209 output_rx,
210 num_cancelled_tx,
211 } => {
212 let mut guard = BatchRecvCancellationGuard::new(Some(num_cancelled_tx));
213 let output = output_rx.await?;
214 guard.done();
215 output
216 }
217 }
218 }
219}
220
221struct BatchKickOffNext<'a, R: Runner + 'static> {
222 batcher_data: &'a Arc<BatcherData<R>>,
223}
224
225impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> {
226 fn drop(&mut self) {
227 let mut state = self.batcher_data.state.lock().unwrap();
228
229 match &mut *state {
230 BatcherState::Idle => {
231 }
233 BatcherState::Busy {
234 pending_batch,
235 ongoing_count,
236 } => {
237 *ongoing_count -= 1;
239
240 if *ongoing_count == 0 {
241 if let Some(batch) = pending_batch.take() {
243 *ongoing_count = 1;
245 let data = self.batcher_data.clone();
246 tokio::spawn(async move { data.run_batch(batch).await });
247 } else {
248 *state = BatcherState::Idle;
250 }
251 }
252 }
253 }
254 }
255}
256
257struct BatchRecvCancellationGuard {
258 num_cancelled_tx: Option<watch::Sender<usize>>,
259}
260
261impl Drop for BatchRecvCancellationGuard {
262 fn drop(&mut self) {
263 if let Some(num_cancelled_tx) = self.num_cancelled_tx.take() {
264 num_cancelled_tx.send_modify(|v| *v += 1);
265 }
266 }
267}
268
269impl BatchRecvCancellationGuard {
270 pub fn new(num_cancelled_tx: Option<watch::Sender<usize>>) -> Self {
271 Self { num_cancelled_tx }
272 }
273
274 pub fn done(&mut self) {
275 self.num_cancelled_tx = None;
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::sync::oneshot;
    use tokio::time::{Duration, sleep};

    /// Runner that records each batch it receives (values sorted), then holds every input
    /// until the test signals or drops the paired oneshot sender. Each output is the input
    /// value doubled, in input order.
    struct TestRunner {
        recorded_calls: Arc<Mutex<Vec<Vec<i64>>>>,
    }

    #[async_trait]
    impl Runner for TestRunner {
        type Input = (i64, oneshot::Receiver<()>);
        type Output = i64;

        async fn run(
            &self,
            inputs: Vec<Self::Input>,
        ) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
            let (vals, rxs): (Vec<i64>, Vec<oneshot::Receiver<()>>) = inputs.into_iter().unzip();

            // Sorted so assertions do not depend on the order inputs were queued in.
            let mut recorded = vals.clone();
            recorded.sort();
            self.recorded_calls.lock().unwrap().push(recorded);

            for rx in rxs {
                // An error only means the sender was dropped; that also releases the input.
                let _ = rx.await;
            }

            let outputs: Vec<i64> = vals.into_iter().map(|v| v * 2).collect();
            Ok(outputs.into_iter())
        }
    }

    /// Polls every 10ms, up to 200 times, until exactly `expected_len` runner calls have been
    /// recorded; panics otherwise.
    async fn wait_until_len(recorded: &Arc<Mutex<Vec<Vec<i64>>>>, expected_len: usize) {
        for _ in 0..200 {
            if recorded.lock().unwrap().len() == expected_len {
                return;
            }
            sleep(Duration::from_millis(10)).await;
        }
        panic!("timed out waiting for recorded_calls length {expected_len}");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn batches_after_first_inline_call() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(runner, BatchingOptions::default()));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();

        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });

        wait_until_len(&recorded_calls, 1).await;

        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });

        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });

        // On this single-threaded runtime, f2 and f3 have not been polled yet. Yield so they
        // actually enqueue. Otherwise the check below is vacuous, and the `[2, 3]` batch
        // assertion later would depend on task scheduling order.
        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 1,
                "second invocation should not have started before unblocking first"
            );
        }

        let _ = n1_tx.send(());

        wait_until_len(&recorded_calls, 2).await;

        let v1 = f1.await??;
        assert_eq!(v1, 2);

        let _ = n2_tx.send(());
        let _ = n3_tx.send(());

        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);

        Ok(())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn respects_max_batch_size() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(
            runner,
            BatchingOptions {
                max_batch_size: Some(2),
            },
        ));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();
        let (n4_tx, n4_rx) = oneshot::channel::<()>();

        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });

        wait_until_len(&recorded_calls, 1).await;

        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });

        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });

        wait_until_len(&recorded_calls, 2).await;

        {
            let calls = recorded_calls.lock().unwrap();
            assert_eq!(calls.len(), 2, "second batch should have started");
            assert_eq!(calls[1], vec![2, 3], "second batch should contain [2, 3]");
        }

        let b4 = batcher.clone();
        let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 2,
                "third batch should not start until all ongoing batches complete"
            );
        }

        let _ = n1_tx.send(());

        let v1 = f1.await??;
        assert_eq!(v1, 2);

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 2,
                "third batch should not start until all ongoing batches complete"
            );
        }

        let _ = n2_tx.send(());
        let _ = n3_tx.send(());

        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        wait_until_len(&recorded_calls, 3).await;

        let _ = n4_tx.send(());
        let v4 = f4.await??;
        assert_eq!(v4, 8);

        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 3);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);
        assert_eq!(calls[2], vec![4]);

        Ok(())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tracks_multiple_concurrent_batches() -> Result<()> {
        let recorded_calls = Arc::new(Mutex::new(Vec::<Vec<i64>>::new()));
        let runner = TestRunner {
            recorded_calls: recorded_calls.clone(),
        };
        let batcher = Arc::new(Batcher::new(
            runner,
            BatchingOptions {
                max_batch_size: Some(2),
            },
        ));

        let (n1_tx, n1_rx) = oneshot::channel::<()>();
        let (n2_tx, n2_rx) = oneshot::channel::<()>();
        let (n3_tx, n3_rx) = oneshot::channel::<()>();
        let (n4_tx, n4_rx) = oneshot::channel::<()>();
        let (n5_tx, n5_rx) = oneshot::channel::<()>();
        let (n6_tx, n6_rx) = oneshot::channel::<()>();

        let b1 = batcher.clone();
        let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await });
        wait_until_len(&recorded_calls, 1).await;

        let b2 = batcher.clone();
        let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await });
        let b3 = batcher.clone();
        let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await });
        wait_until_len(&recorded_calls, 2).await;

        let b4 = batcher.clone();
        let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await });
        let b5 = batcher.clone();
        let f5 = tokio::spawn(async move { b5.run((5_i64, n5_rx)).await });
        wait_until_len(&recorded_calls, 3).await;

        let b6 = batcher.clone();
        let f6 = tokio::spawn(async move { b6.run((6_i64, n6_rx)).await });

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "fourth batch should not start with ongoing batches"
            );
        }

        let _ = n2_tx.send(());
        let _ = n3_tx.send(());
        let v2 = f2.await??;
        let v3 = f3.await??;
        assert_eq!(v2, 4);
        assert_eq!(v3, 6);

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "batch [6] should still not start (batch 1 and batch [4,5] still ongoing)"
            );
        }

        let _ = n4_tx.send(());
        let _ = n5_tx.send(());
        let v4 = f4.await??;
        let v5 = f5.await??;
        assert_eq!(v4, 8);
        assert_eq!(v5, 10);

        sleep(Duration::from_millis(50)).await;
        {
            let len_now = recorded_calls.lock().unwrap().len();
            assert_eq!(
                len_now, 3,
                "batch [6] should still not start (batch 1 still ongoing)"
            );
        }

        let _ = n1_tx.send(());
        let v1 = f1.await??;
        assert_eq!(v1, 2);

        wait_until_len(&recorded_calls, 4).await;

        let _ = n6_tx.send(());
        let v6 = f6.await??;
        assert_eq!(v6, 12);

        let calls = recorded_calls.lock().unwrap().clone();
        assert_eq!(calls.len(), 4);
        assert_eq!(calls[0], vec![1]);
        assert_eq!(calls[1], vec![2, 3]);
        assert_eq!(calls[2], vec![4, 5]);
        assert_eq!(calls[3], vec![6]);

        Ok(())
    }
}
605}