batch_aint_one/
batcher.rs

1use std::{fmt::Debug, sync::Arc};
2
3use bon::bon;
4use tokio::sync::{mpsc, oneshot};
5use tracing::{Level, Span, span};
6
7use crate::{
8    batch::BatchItem,
9    error::BatchResult,
10    policies::{BatchingPolicy, Limits},
11    processor::Processor,
12    worker::{Worker, WorkerDropGuard, WorkerHandle},
13};
14
/// Groups items to be processed in batches.
///
/// Takes inputs one at a time and sends them to a background worker task which groups them into
/// batches according to the specified [`BatchingPolicy`] and [`Limits`], and processes them using
/// the provided [`Processor`].
///
/// Cheap to clone. Cloned instances share the same background worker task.
///
/// ## Drop
///
/// When the last instance of a `Batcher` is dropped, the worker task will be aborted (ungracefully
/// shut down).
///
/// If you want to shut down the worker gracefully, call [`WorkerHandle::shut_down()`].
#[derive(Debug)]
pub struct Batcher<P: Processor> {
    /// Name given when the batcher was built; the same name is passed to the worker on spawn.
    name: String,
    /// Shared handle to the background worker, handed out by `worker_handle()`.
    worker: Arc<WorkerHandle>,
    /// Drop guard for the worker, shared between all clones of this batcher.
    ///
    /// NOTE(review): presumably dropping the last reference is what aborts the worker task, as
    /// described in the type-level docs -- confirm against the `worker` module.
    worker_guard: Arc<WorkerDropGuard>,
    /// Sending half of the channel used to hand items to the worker.
    item_tx: mpsc::Sender<BatchItem<P>>,
}
36
37#[bon]
38impl<P: Processor> Batcher<P> {
39    /// Create a new batcher.
40    #[builder]
41    pub fn new(
42        name: impl Into<String>,
43        processor: P,
44        limits: Limits,
45        batching_policy: BatchingPolicy,
46    ) -> Self {
47        let name = name.into();
48
49        let (handle, worker_guard, item_tx) =
50            Worker::spawn(name.clone(), processor, limits, batching_policy);
51
52        Self {
53            name,
54            worker: Arc::new(handle),
55            worker_guard: Arc::new(worker_guard),
56            item_tx,
57        }
58    }
59
60    /// Add an item to be batched and processed, and await the result.
61    pub async fn add(&self, key: P::Key, input: P::Input) -> BatchResult<P::Output, P::Error> {
62        // Record the span ID so we can link the shared processing span.
63        let requesting_span = Span::current().clone();
64
65        let (tx, rx) = oneshot::channel();
66        self.item_tx
67            .send(BatchItem {
68                key,
69                input,
70                tx,
71                requesting_span,
72            })
73            .await?;
74
75        let (output, batch_span) = rx.await?;
76
77        {
78            let link_back_span = span!(Level::INFO, "batch finished");
79            if let Some(span) = batch_span {
80                // WARNING: It's very important that we don't drop the span until _after_
81                // follows_from().
82                //
83                // If we did e.g. `.follows_from(span)` then the span would get converted into an ID
84                // and dropped. Any attempt to look up the span by ID _inside_ follows_from() would
85                // then panic, because the span will have been closed and no longer exist.
86                //
87                // Don't ask me how long this took me to debug.
88                link_back_span.follows_from(&span);
89                link_back_span.in_scope(|| {
90                    // Do nothing. This span is just here to work around a Honeycomb limitation:
91                    //
92                    // If the batch span is linked to a parent span like so:
93                    //
94                    // parent_span_1 <-link- batch_span
95                    //
96                    // then in Honeycomb, the link is only shown on the batch span. It it not possible
97                    // to click through to the batch span from the parent.
98                    //
99                    // So, here we link back to the batch to make this easier.
100                });
101            }
102        }
103        output
104    }
105
106    /// Get a handle to the worker.
107    pub fn worker_handle(&self) -> Arc<WorkerHandle> {
108        Arc::clone(&self.worker)
109    }
110}
111
112impl<P: Processor> Clone for Batcher<P> {
113    fn clone(&self) -> Self {
114        Self {
115            name: self.name.clone(),
116            worker: self.worker.clone(),
117            worker_guard: self.worker_guard.clone(),
118            item_tx: self.item_tx.clone(),
119        }
120    }
121}