// batch_aint_one/batcher.rs
1use std::{fmt::Display, hash::Hash, sync::Arc};
2
3use async_trait::async_trait;
4use tokio::sync::{mpsc, oneshot};
5use tracing::{span, Level, Span};
6
7use crate::{
8 batch::BatchItem,
9 error::BatchResult,
10 policies::{BatchingPolicy, Limits},
11 worker::{Worker, WorkerHandle},
12};
13
/// Groups items to be processed in batches.
///
/// Takes inputs (`I`) grouped by a key (`K`) and processes multiple together in a batch. An output
/// (`O`) is produced for each input.
///
/// Errors (`E`) can be returned from a batch.
///
/// Cheap to clone.
#[derive(Debug)]
pub struct Batcher<K, I, O = (), E = String>
where
    K: 'static + Send + Eq + Hash + Clone,
    I: 'static + Send,
    O: 'static + Send,
    E: 'static + Send + Clone + Display,
{
    /// Handle to the background worker task. Wrapped in `Arc` so that all
    /// clones of this `Batcher` share the same worker.
    worker: Arc<WorkerHandle>,
    /// Channel used to submit items to the worker for batching.
    item_tx: mpsc::Sender<BatchItem<K, I, O, E>>,
}
33
/// Process a batch of inputs for a given key.
#[async_trait]
pub trait Processor<K, I, O = (), E = String>
where
    E: Display,
{
    /// Process the batch.
    ///
    /// The order of the outputs in the returned `Vec` must be the same as the order of the inputs
    /// in the given iterator.
    ///
    /// # Errors
    ///
    /// Returning `Err(E)` fails the whole batch.
    // NOTE(review): `E: Clone` is required at the `Batcher` level, which
    // suggests a single error is fanned out to every caller waiting on an
    // item in the batch — confirm against the `worker` module.
    async fn process(
        &self,
        key: K,
        inputs: impl Iterator<Item = I> + Send,
    ) -> std::result::Result<Vec<O>, E>;
}
50
51impl<K, I, O, E> Batcher<K, I, O, E>
52where
53 K: 'static + Send + Eq + Hash + Clone,
54 I: 'static + Send,
55 O: 'static + Send,
56 E: 'static + Send + Clone + Display,
57{
58 /// Create a new batcher.
59 pub fn new<F>(processor: F, limits: Limits, batching_policy: BatchingPolicy) -> Self
60 where
61 F: 'static + Send + Clone + Processor<K, I, O, E>,
62 {
63 let (handle, item_tx) = Worker::spawn(processor, limits, batching_policy);
64
65 Self {
66 worker: Arc::new(handle),
67 item_tx,
68 }
69 }
70
71 /// Add an item to the batch and await the result.
72 pub async fn add(&self, key: K, input: I) -> BatchResult<O, E> {
73 // Record the span ID so we can link the shared processing span.
74 let requesting_span = Span::current().clone();
75
76 let (tx, rx) = oneshot::channel();
77 self.item_tx
78 .send(BatchItem {
79 key,
80 input,
81 tx,
82 requesting_span,
83 })
84 .await?;
85
86 let (output, batch_span) = rx.await?;
87
88 {
89 let link_back_span = span!(Level::INFO, "batch finished");
90 if let Some(span) = batch_span {
91 // WARNING: It's very important that we don't drop the span until _after_
92 // follows_from().
93 //
94 // If we did e.g. `.follows_from(span)` then the span would get converted into an ID
95 // and dropped. Any attempt to look up the span by ID _inside_ follows_from() would
96 // then panic, because the span will have been closed and no longer exist.
97 //
98 // Don't ask me how long this took me to debug.
99 link_back_span.follows_from(&span);
100 link_back_span.in_scope(|| {
101 // Do nothing. This span is just here to work around a Honeycomb limitation:
102 //
103 // If the batch span is linked to a parent span like so:
104 //
105 // parent_span_1 <-link- batch_span
106 //
107 // then in Honeycomb, the link is only shown on the batch span. It it not possible
108 // to click through to the batch span from the parent.
109 //
110 // So, here we link back to the batch to make this easier.
111 });
112 }
113 }
114 output
115 }
116}
117
118impl<K, I, O, E> Clone for Batcher<K, I, O, E>
119where
120 K: 'static + Send + Eq + Hash + Clone,
121 I: 'static + Send,
122 O: 'static + Send,
123 E: 'static + Send + Clone + Display,
124{
125 fn clone(&self) -> Self {
126 Self {
127 worker: self.worker.clone(),
128 item_tx: self.item_tx.clone(),
129 }
130 }
131}