batch_aint_one/
worker.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4};
5
6use tokio::{
7    sync::{mpsc, oneshot},
8    task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13    BatchError,
14    batch::BatchItem,
15    batch_inner::Generation,
16    batch_queue::BatchQueue,
17    policies::{BatchingPolicy, Limits, OnAdd, ProcessAction},
18    processor::Processor,
19};
20
21pub(crate) struct Worker<P: Processor> {
22    batcher_name: String,
23
24    /// Used to receive new batch items.
25    item_rx: mpsc::Receiver<BatchItem<P>>,
26    /// The callback to process a batch of inputs.
27    processor: P,
28
29    /// Used to signal that a batch for key `K` should be processed.
30    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
31    /// Receives signals to process a batch for key `K`.
32    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,
33
34    /// Used to send messages to the worker related to shutdown.
35    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,
36
37    /// Used to signal to listeners that the worker has shut down.
38    shutdown_notifiers: Vec<oneshot::Sender<()>>,
39
40    shutting_down: bool,
41
42    limits: Limits,
43    /// Controls when to start processing a batch.
44    batching_policy: BatchingPolicy,
45
46    /// Unprocessed batches, grouped by key `K`.
47    batch_queues: HashMap<P::Key, BatchQueue<P>>,
48}
49
50#[derive(Debug)]
51pub(crate) enum Message<K, E: Display + Debug> {
52    TimedOut(K, Generation),
53    ResourcesAcquired(K, Generation),
54    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
55    Finished(K),
56}
57
58pub(crate) enum ShutdownMessage {
59    Register(ShutdownNotifier),
60    ShutDown,
61}
62
63pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
64
65/// A handle to the worker task.
66///
67/// Used for shutting down the worker and waiting for it to finish.
68#[derive(Debug, Clone)]
69pub struct WorkerHandle {
70    shutdown_tx: mpsc::Sender<ShutdownMessage>,
71}
72
73/// Aborts the worker task when dropped.
74#[derive(Debug)]
75pub(crate) struct WorkerDropGuard {
76    handle: JoinHandle<()>,
77}
78
79impl<P: Processor> Worker<P> {
80    pub fn spawn(
81        batcher_name: String,
82        processor: P,
83        limits: Limits,
84        batching_policy: BatchingPolicy,
85    ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
86        let (item_tx, item_rx) = mpsc::channel(10);
87
88        let (timeout_tx, timeout_rx) = mpsc::channel(10);
89
90        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
91
92        let mut worker = Worker {
93            batcher_name,
94
95            item_rx,
96            processor,
97
98            msg_tx: timeout_tx,
99            msg_rx: timeout_rx,
100
101            shutdown_notifier_rx: shutdown_rx,
102            shutdown_notifiers: Vec::new(),
103
104            shutting_down: false,
105
106            limits,
107            batching_policy,
108
109            batch_queues: HashMap::new(),
110        };
111
112        let handle = tokio::spawn(async move {
113            worker.run().await;
114        });
115
116        (
117            WorkerHandle { shutdown_tx },
118            WorkerDropGuard { handle },
119            item_tx,
120        )
121    }
122
123    /// Add an item to the batch.
124    fn add(&mut self, item: BatchItem<P>) {
125        let key = item.key.clone();
126
127        let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
128            BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
129        });
130
131        match self.batching_policy.on_add(batch_queue) {
132            OnAdd::AddAndProcess => {
133                batch_queue.push(item);
134
135                self.process_next_batch(&key);
136            }
137            OnAdd::AddAndAcquireResources => {
138                batch_queue.push(item);
139
140                batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
141            }
142            OnAdd::AddAndProcessAfter(duration) => {
143                batch_queue.push(item);
144
145                batch_queue.process_after(duration, self.msg_tx.clone());
146            }
147            OnAdd::Add => {
148                batch_queue.push(item);
149            }
150            OnAdd::Reject(reason) => {
151                if item
152                    .tx
153                    .send((Err(BatchError::Rejected(reason)), None))
154                    .is_err()
155                {
156                    // Whatever was waiting for the output must have shut down. Presumably it
157                    // doesn't care anymore, but we log here anyway. There's not much else we can do.
158                    debug!(
159                        "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
160                        self.batcher_name
161                    );
162                }
163            }
164        }
165    }
166
167    fn process_generation(&mut self, key: P::Key, generation: Generation) {
168        let batch_queue = self.batch_queues.get_mut(&key).expect("batch should exist");
169
170        if let Some(batch) = batch_queue.take_generation(generation) {
171            let on_finished = self.msg_tx.clone();
172
173            batch.process(self.processor.clone(), on_finished);
174        }
175    }
176
177    fn process_next_batch(&mut self, key: &P::Key) {
178        let batch_queue = self
179            .batch_queues
180            .get_mut(key)
181            .expect("batch queue should exist");
182
183        if let Some(batch) = batch_queue.take_next_ready_batch() {
184            let on_finished = self.msg_tx.clone();
185
186            batch.process(self.processor.clone(), on_finished);
187
188            debug_assert!(
189                batch_queue.within_processing_capacity(),
190                "processing count should not exceed max key concurrency"
191            );
192        }
193    }
194
195    fn on_timeout(&mut self, key: P::Key, generation: Generation) {
196        let batch_queue = self
197            .batch_queues
198            .get_mut(&key)
199            .expect("batch queue should exist");
200
201        match self.batching_policy.on_timeout(generation, batch_queue) {
202            ProcessAction::Process => {
203                self.process_generation(key, generation);
204            }
205            ProcessAction::DoNothing => {}
206        }
207    }
208
209    fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
210        let batch_queue = self
211            .batch_queues
212            .get_mut(&key)
213            .expect("batch queue should exist");
214
215        match self
216            .batching_policy
217            .on_resources_acquired(generation, batch_queue)
218        {
219            ProcessAction::Process => {
220                self.process_generation(key, generation);
221            }
222            ProcessAction::DoNothing => {}
223        }
224    }
225
226    fn on_batch_finished(&mut self, key: &P::Key) {
227        let batch_queue = self
228            .batch_queues
229            .get_mut(key)
230            .expect("batch queue should exist");
231
232        match self.batching_policy.on_finish(batch_queue) {
233            ProcessAction::Process => {
234                self.process_next_batch(key);
235            }
236            ProcessAction::DoNothing => {}
237        }
238    }
239
240    fn fail_batch(&mut self, key: P::Key, generation: Generation, err: BatchError<P::Error>) {
241        let batch_queue = self
242            .batch_queues
243            .get_mut(&key)
244            .expect("batch queue should exist");
245
246        if let Some(batch) = batch_queue.take_generation(generation) {
247            let on_finished = self.msg_tx.clone();
248            batch.fail(err, on_finished)
249        }
250    }
251
252    fn ready_to_shut_down(&self) -> bool {
253        self.shutting_down
254            && self.batch_queues.values().all(|q| q.is_empty())
255            && !self.batch_queues.values().any(|q| q.is_processing())
256    }
257
258    /// Start running the worker event loop.
259    async fn run(&mut self) {
260        loop {
261            tokio::select! {
262                Some(msg) = self.shutdown_notifier_rx.recv() => {
263                    match msg {
264                        ShutdownMessage::Register(notifier) => {
265                           self.shutdown_notifiers.push(notifier.0);
266                        }
267                        ShutdownMessage::ShutDown => {
268                            self.shutting_down = true;
269                        }
270                    }
271                }
272
273                Some(item) = self.item_rx.recv() => {
274                    self.add(item);
275                }
276
277                Some(msg) = self.msg_rx.recv() => {
278                    match msg {
279                        Message::ResourcesAcquired(key, generation) => {
280                            self.on_resource_acquired(key, generation);
281                        }
282                        Message::ResourceAcquisitionFailed(key, generation, err) => {
283                            self.fail_batch(key, generation, err);
284                        }
285                        Message::TimedOut(key, generation) => {
286                            self.on_timeout(key, generation);
287                        }
288                        Message::Finished(key) => {
289                            self.on_batch_finished(&key);
290                        }
291                    }
292                }
293            }
294
295            if self.ready_to_shut_down() {
296                info!("Batch worker '{}' is shutting down", &self.batcher_name);
297                return;
298            }
299        }
300    }
301}
302
303impl WorkerHandle {
304    /// Signal the worker to shut down after processing any in-flight batches.
305    ///
306    /// Note that when using the Size policy this may wait indefinitely if no new items are added.
307    pub async fn shut_down(&self) {
308        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
309        let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
310    }
311
312    /// Wait for the worker to finish.
313    pub async fn wait_for_shutdown(&self) {
314        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
315        let (notifier_tx, notifier_rx) = oneshot::channel();
316        let _ = self
317            .shutdown_tx
318            .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
319            .await;
320        // Wait for the notifier to be dropped.
321        let _ = notifier_rx.await;
322    }
323}
324
325impl Drop for WorkerDropGuard {
326    fn drop(&mut self) {
327        self.handle.abort();
328    }
329}
330
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Minimal processor that needs no resources and appends `" processed"` to every input.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs = inputs.map(|input| format!("{input} processed")).collect();
            Ok(outputs)
        }
    }

    /// Sends two items for the same key through the item channel and checks both outputs.
    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::default().with_max_batch_size(2),
            BatchingPolicy::Size,
        );

        // Submit both items before awaiting any output; the batch size limit is two.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(outputs, ["I1 processed", "I2 processed"]);
    }
}