batch_aint_one/
batcher.rs

1use std::{fmt::Display, hash::Hash, sync::Arc};
2
3use async_trait::async_trait;
4use tokio::sync::{mpsc, oneshot};
5use tracing::{span, Level, Span};
6
7use crate::{
8    batch::BatchItem,
9    error::BatchResult,
10    policies::{BatchingPolicy, Limits},
11    worker::{Worker, WorkerHandle},
12};
13
14/// Groups items to be processed in batches.
15///
16/// Takes inputs (`I`) grouped by a key (`K`) and processes multiple together in a batch. An output
17/// (`O`) is produced for each input.
18///
19/// Errors (`E`) can be returned from a batch.
20///
21/// Cheap to clone.
22#[derive(Debug)]
23pub struct Batcher<K, I, O = (), E = String>
24where
25    K: 'static + Send + Eq + Hash + Clone,
26    I: 'static + Send,
27    O: 'static + Send,
28    E: 'static + Send + Clone + Display,
29{
30    worker: Arc<WorkerHandle>,
31    item_tx: mpsc::Sender<BatchItem<K, I, O, E>>,
32}
33
34/// Process a batch of inputs for a given key.
35#[async_trait]
36pub trait Processor<K, I, O = (), E = String>
37where
38    E: Display,
39{
40    /// Process the batch.
41    ///
42    /// The order of the outputs in the returned `Vec` must be the same as the order of the inputs
43    /// in the given iterator.
44    async fn process(
45        &self,
46        key: K,
47        inputs: impl Iterator<Item = I> + Send,
48    ) -> std::result::Result<Vec<O>, E>;
49}
50
51impl<K, I, O, E> Batcher<K, I, O, E>
52where
53    K: 'static + Send + Eq + Hash + Clone,
54    I: 'static + Send,
55    O: 'static + Send,
56    E: 'static + Send + Clone + Display,
57{
58    /// Create a new batcher.
59    pub fn new<F>(processor: F, limits: Limits, batching_policy: BatchingPolicy) -> Self
60    where
61        F: 'static + Send + Clone + Processor<K, I, O, E>,
62    {
63        let (handle, item_tx) = Worker::spawn(processor, limits, batching_policy);
64
65        Self {
66            worker: Arc::new(handle),
67            item_tx,
68        }
69    }
70
71    /// Add an item to the batch and await the result.
72    pub async fn add(&self, key: K, input: I) -> BatchResult<O, E> {
73        // Record the span ID so we can link the shared processing span.
74        let requesting_span = Span::current().clone();
75
76        let (tx, rx) = oneshot::channel();
77        self.item_tx
78            .send(BatchItem {
79                key,
80                input,
81                tx,
82                requesting_span,
83            })
84            .await?;
85
86        let (output, batch_span) = rx.await?;
87
88        {
89            let link_back_span = span!(Level::INFO, "batch finished");
90            if let Some(span) = batch_span {
91                // WARNING: It's very important that we don't drop the span until _after_
92                // follows_from().
93                //
94                // If we did e.g. `.follows_from(span)` then the span would get converted into an ID
95                // and dropped. Any attempt to look up the span by ID _inside_ follows_from() would
96                // then panic, because the span will have been closed and no longer exist.
97                //
98                // Don't ask me how long this took me to debug.
99                link_back_span.follows_from(&span);
100                link_back_span.in_scope(|| {
101                    // Do nothing. This span is just here to work around a Honeycomb limitation:
102                    //
103                    // If the batch span is linked to a parent span like so:
104                    //
105                    // parent_span_1 <-link- batch_span
106                    //
107                    // then in Honeycomb, the link is only shown on the batch span. It it not possible
108                    // to click through to the batch span from the parent.
109                    //
110                    // So, here we link back to the batch to make this easier.
111                });
112            }
113        }
114        output
115    }
116}
117
118impl<K, I, O, E> Clone for Batcher<K, I, O, E>
119where
120    K: 'static + Send + Eq + Hash + Clone,
121    I: 'static + Send,
122    O: 'static + Send,
123    E: 'static + Send + Clone + Display,
124{
125    fn clone(&self) -> Self {
126        Self {
127            worker: self.worker.clone(),
128            item_tx: self.item_tx.clone(),
129        }
130    }
131}