batch_aint_one/
batcher.rs

1use std::{fmt::Display, future::Future, hash::Hash, sync::Arc};
2
3use tokio::sync::{mpsc, oneshot};
4use tracing::{span, Level, Span};
5
6use crate::{
7    batch::BatchItem,
8    error::BatchResult,
9    policies::{BatchingPolicy, Limits},
10    worker::{Worker, WorkerHandle},
11};
12
13/// Groups items to be processed in batches.
14///
15/// Takes inputs (`I`) grouped by a key (`K`) and processes multiple together in a batch. An output
16/// (`O`) is produced for each input.
17///
18/// Errors (`E`) can be returned from a batch.
19///
20/// Cheap to clone.
21#[derive(Debug)]
22pub struct Batcher<K, I, O = (), E = String>
23where
24    K: 'static + Send + Eq + Hash + Clone,
25    I: 'static + Send,
26    O: 'static + Send,
27    E: 'static + Send + Clone + Display,
28{
29    worker: Arc<WorkerHandle>,
30    item_tx: mpsc::Sender<BatchItem<K, I, O, E>>,
31}
32
/// Processes a batch of inputs that all share the same key.
///
/// An implementor is given the key and an iterator over the batched inputs. It resolves either to
/// one output per input or to a single error for the whole batch.
pub trait Processor<K, I, O = (), E: Display = String> {
    /// Process one batch.
    ///
    /// On success, the outputs in the returned `Vec` must appear in the same order in which
    /// `inputs` yielded the corresponding inputs.
    fn process(
        &self,
        key: K,
        inputs: impl Iterator<Item = I> + Send,
    ) -> impl Future<Output = std::result::Result<Vec<O>, E>> + Send;
}
48
49impl<K, I, O, E> Batcher<K, I, O, E>
50where
51    K: 'static + Send + Eq + Hash + Clone,
52    I: 'static + Send,
53    O: 'static + Send,
54    E: 'static + Send + Clone + Display,
55{
56    /// Create a new batcher.
57    pub fn new<F>(processor: F, limits: Limits, batching_policy: BatchingPolicy) -> Self
58    where
59        F: 'static + Send + Clone + Processor<K, I, O, E>,
60    {
61        let (handle, item_tx) = Worker::spawn(processor, limits, batching_policy);
62
63        Self {
64            worker: Arc::new(handle),
65            item_tx,
66        }
67    }
68
69    /// Add an item to the batch and await the result.
70    pub async fn add(&self, key: K, input: I) -> BatchResult<O, E> {
71        // Record the span ID so we can link the shared processing span.
72        let requesting_span = Span::current().clone();
73
74        let (tx, rx) = oneshot::channel();
75        self.item_tx
76            .send(BatchItem {
77                key,
78                input,
79                tx,
80                requesting_span,
81            })
82            .await?;
83
84        let (output, batch_span) = rx.await?;
85
86        {
87            let link_back_span = span!(Level::INFO, "batch finished");
88            if let Some(span) = batch_span {
89                // WARNING: It's very important that we don't drop the span until _after_
90                // follows_from().
91                //
92                // If we did e.g. `.follows_from(span)` then the span would get converted into an ID
93                // and dropped. Any attempt to look up the span by ID _inside_ follows_from() would
94                // then panic, because the span will have been closed and no longer exist.
95                //
96                // Don't ask me how long this took me to debug.
97                link_back_span.follows_from(&span);
98                link_back_span.in_scope(|| {
99                    // Do nothing. This span is just here to work around a Honeycomb limitation:
100                    //
101                    // If the batch span is linked to a parent span like so:
102                    //
103                    // parent_span_1 <-link- batch_span
104                    //
105                    // then in Honeycomb, the link is only shown on the batch span. It it not possible
106                    // to click through to the batch span from the parent.
107                    //
108                    // So, here we link back to the batch to make this easier.
109                });
110            }
111        }
112        output
113    }
114}
115
116impl<K, I, O, E> Clone for Batcher<K, I, O, E>
117where
118    K: 'static + Send + Eq + Hash + Clone,
119    I: 'static + Send,
120    O: 'static + Send,
121    E: 'static + Send + Clone + Display,
122{
123    fn clone(&self) -> Self {
124        Self {
125            worker: self.worker.clone(),
126            item_tx: self.item_tx.clone(),
127        }
128    }
129}