Skip to main content

batch_aint_one/
worker.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4};
5
6use tokio::{
7    sync::{mpsc, oneshot},
8    task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13    BatchError,
14    batch::BatchItem,
15    batch_inner::Generation,
16    batch_queue::BatchQueue,
17    limits::Limits,
18    policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
19    processor::Processor,
20};
21
/// State owned by a batcher's background task.
///
/// Receives new items and internal events over channels, and drives each key's
/// [`BatchQueue`] according to the [`BatchingPolicy`].
pub(crate) struct Worker<P: Processor> {
    /// Name of the batcher, used in log messages and passed to each new batch queue.
    batcher_name: String,

    /// Used to receive new batch items.
    item_rx: mpsc::Receiver<BatchItem<P>>,
    /// The callback to process a batch of inputs.
    processor: P,

    /// Cloned and handed to batch queue operations so they can send [`Message`]s (timeouts,
    /// resource acquisition results, finished batches) back to the worker.
    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
    /// Receives the [`Message`]s sent via `msg_tx`.
    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,

    /// Receives shutdown-related messages from [`WorkerHandle`]s: notifier registrations and
    /// shutdown requests.
    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,

    /// Registered shutdown waiters. Nothing is ever sent on these; waiters are released when
    /// the senders are dropped along with the worker.
    shutdown_notifiers: Vec<oneshot::Sender<()>>,

    /// Set once a shutdown has been requested.
    shutting_down: bool,

    /// Limits passed to each newly created batch queue.
    limits: Limits,
    /// Controls when to start processing a batch.
    batching_policy: BatchingPolicy,

    /// Unprocessed batches, grouped by key `K`. A key's entry is removed once it becomes idle.
    batch_queues: HashMap<P::Key, BatchQueue<P>>,
}
50
/// Events delivered to the worker over its internal message channel, each for a given key.
#[derive(Debug)]
pub(crate) enum Message<K, E: Display + Debug> {
    /// A timer for the given generation fired. Timers are not tracked by the queue, so this
    /// may arrive after the key's queue has already been removed.
    TimedOut(K, Generation),
    /// Resources were acquired for the given generation.
    ResourcesAcquired(K, Generation),
    /// Acquiring resources for the given generation failed. The worker fails that generation
    /// with the carried error.
    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
    /// A batch for the key reached a terminal state.
    Finished(K, BatchTerminalState),
}
58
/// How a batch ended, as reported in [`Message::Finished`].
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
    /// The batch was processed. The worker marks it as processed on its queue.
    Processed,
    /// The batch ended because acquiring resources failed. The worker marks the resource
    /// acquisition as finished on its queue.
    FailedAcquiring,
}
64
/// Shutdown-related requests sent from a [`WorkerHandle`] to the worker.
pub(crate) enum ShutdownMessage {
    /// Register a notifier. The worker keeps it until the worker itself is dropped.
    Register(ShutdownNotifier),
    /// Ask the worker to shut down once it has no queued or in-flight work.
    ShutDown,
}
69
/// The sending half of a oneshot channel registered by a shutdown waiter. The worker never
/// sends on it; the waiter observes shutdown when the sender is dropped.
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
71
/// A handle to the worker task.
///
/// Used for shutting down the worker and waiting for it to finish. All clones talk to the
/// same worker.
#[derive(Debug, Clone)]
pub struct WorkerHandle {
    /// Sends shutdown requests and shutdown-notifier registrations to the worker.
    shutdown_tx: mpsc::Sender<ShutdownMessage>,
}
79
/// Aborts the worker task when dropped.
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
    /// Handle to the spawned worker task; aborted in this guard's `Drop` impl.
    handle: JoinHandle<()>,
}
85
86impl<P: Processor> Worker<P> {
87    pub fn spawn(
88        batcher_name: String,
89        processor: P,
90        limits: Limits,
91        batching_policy: BatchingPolicy,
92    ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
93        // These channel sizes are somewhat arbitrary - they just need to be big enough to avoid
94        // backpressure in normal operation.
95        let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
96        let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
97
98        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
99
100        let mut worker = Worker {
101            batcher_name,
102
103            item_rx,
104            processor,
105
106            msg_tx,
107            msg_rx,
108
109            shutdown_notifier_rx: shutdown_rx,
110            shutdown_notifiers: Vec::new(),
111
112            shutting_down: false,
113
114            limits,
115            batching_policy,
116
117            batch_queues: HashMap::new(),
118        };
119
120        let handle = tokio::spawn(async move {
121            worker.run().await;
122        });
123
124        (
125            WorkerHandle { shutdown_tx },
126            WorkerDropGuard { handle },
127            item_tx,
128        )
129    }
130
131    /// Add an item to the batch.
132    fn add(&mut self, item: BatchItem<P>) {
133        let key = item.key.clone();
134
135        let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
136            BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
137        });
138
139        match self.batching_policy.on_add(batch_queue) {
140            OnAdd::AddAndProcess => {
141                batch_queue.push(item);
142
143                self.process_next_batch(&key);
144            }
145            OnAdd::AddAndAcquireResources => {
146                batch_queue.push(item);
147
148                batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
149            }
150            OnAdd::AddAndProcessAfter(duration) => {
151                batch_queue.push(item);
152
153                batch_queue.process_after(duration, self.msg_tx.clone());
154            }
155            OnAdd::Add => {
156                batch_queue.push(item);
157            }
158            OnAdd::Reject(reason) => {
159                if item
160                    .tx
161                    .send((Err(BatchError::Rejected(reason)), None))
162                    .is_err()
163                {
164                    // Whatever was waiting for the output must have shut down. Presumably it
165                    // doesn't care anymore, but we log here anyway. There's not much else we can do.
166                    debug!(
167                        "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
168                        self.batcher_name
169                    );
170                }
171            }
172        }
173    }
174
175    /// Get the batch queue for the given key, which should always exist when handling an event
176    /// for that key.
177    fn queue_mut<'q>(
178        batch_queues: &'q mut HashMap<P::Key, BatchQueue<P>>,
179        key: &P::Key,
180    ) -> &'q mut BatchQueue<P> {
181        batch_queues.get_mut(key).expect("batch queue should exist")
182    }
183
184    fn process_generation(&mut self, key: P::Key, generation: Generation) {
185        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
186
187        batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
188    }
189
190    fn process_next_ready_batch(&mut self, key: &P::Key) {
191        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
192
193        batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
194    }
195
196    fn process_next_batch(&mut self, key: &P::Key) {
197        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
198
199        batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
200    }
201
202    fn on_timeout(&mut self, key: P::Key, generation: Generation) {
203        // Unlike the other message handlers, the queue may have been removed: timers are not
204        // tracked by the in-flight counters, so a TimedOut message can outlive its queue.
205        let Some(batch_queue) = self.batch_queues.get_mut(&key) else {
206            debug!("Timeout for a batch queue which no longer exists. Ignoring.");
207            return;
208        };
209
210        match self.batching_policy.on_timeout(generation, batch_queue) {
211            OnGenerationEvent::Process => {
212                self.process_generation(key, generation);
213            }
214            OnGenerationEvent::DoNothing => {}
215        }
216    }
217
218    fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
219        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
220
221        batch_queue.mark_resource_acquisition_finished();
222
223        match self
224            .batching_policy
225            .on_resources_acquired(generation, batch_queue)
226        {
227            OnGenerationEvent::Process => {
228                self.process_generation(key, generation);
229            }
230            OnGenerationEvent::DoNothing => {}
231        }
232    }
233
234    fn on_resource_acquisition_failed(
235        &mut self,
236        key: P::Key,
237        generation: Generation,
238        err: BatchError<P::Error>,
239    ) {
240        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
241
242        batch_queue.fail_generation(generation, err.clone(), self.msg_tx.clone());
243    }
244
245    fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
246        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
247
248        match terminal_state {
249            BatchTerminalState::Processed => {
250                batch_queue.mark_processed();
251            }
252            BatchTerminalState::FailedAcquiring => {
253                batch_queue.mark_resource_acquisition_finished();
254            }
255        }
256
257        match self.batching_policy.on_finish(batch_queue) {
258            OnFinish::ProcessNextReady => {
259                self.process_next_ready_batch(key);
260            }
261            OnFinish::ProcessNext => {
262                self.process_next_batch(key);
263            }
264            OnFinish::DoNothing => {}
265        }
266
267        // Remove the queue for idle keys, otherwise the map grows unboundedly as new keys are
268        // seen. A key can only become idle when a batch finishes, so this is the only place we
269        // need to do this.
270        if Self::queue_mut(&mut self.batch_queues, key).is_idle() {
271            self.batch_queues.remove(key);
272        }
273    }
274
275    fn ready_to_shut_down(&self) -> bool {
276        self.shutting_down
277            && self.batch_queues.values().all(|q| q.is_empty())
278            && !self.batch_queues.values().any(|q| q.is_processing())
279    }
280
281    /// Start running the worker event loop.
282    async fn run(&mut self) {
283        loop {
284            tokio::select! {
285                Some(msg) = self.shutdown_notifier_rx.recv() => {
286                    match msg {
287                        ShutdownMessage::Register(notifier) => {
288                           self.shutdown_notifiers.push(notifier.0);
289                        }
290                        ShutdownMessage::ShutDown => {
291                            self.shutting_down = true;
292                        }
293                    }
294                }
295
296                Some(item) = self.item_rx.recv() => {
297                    self.add(item);
298                }
299
300                Some(msg) = self.msg_rx.recv() => {
301                    match msg {
302                        Message::ResourcesAcquired(key, generation) => {
303                            self.on_resource_acquired(key, generation);
304                        }
305                        Message::ResourceAcquisitionFailed(key, generation, err) => {
306                            self.on_resource_acquisition_failed(key, generation, err);
307                        }
308                        Message::TimedOut(key, generation) => {
309                            self.on_timeout(key, generation);
310                        }
311                        Message::Finished(key, terminal_state) => {
312                            self.on_batch_finished(&key, terminal_state);
313                        }
314                    }
315                }
316            }
317
318            if self.ready_to_shut_down() {
319                info!("Batch worker '{}' is shutting down", &self.batcher_name);
320                return;
321            }
322        }
323    }
324}
325
326impl WorkerHandle {
327    /// Signal the worker to shut down after processing any in-flight batches.
328    ///
329    /// New items are still accepted while shutting down, and the worker only shuts down once all
330    /// keys are idle. This means shutdown may never complete if:
331    ///
332    /// - new items keep being added, or
333    /// - a batch never meets its policy's processing condition, e.g. when using the
334    ///   [`Size`](crate::BatchingPolicy::Size) policy, a final partial batch may wait
335    ///   indefinitely for more items.
336    ///
337    /// Stopping the flow of new items is expected to be handled by the caller, e.g. by shutting
338    /// down the message handlers which add items before shutting down the batcher.
339    pub async fn shut_down(&self) {
340        info!("Sending shut down signal to batch worker");
341        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
342        let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
343    }
344
345    /// Wait for the worker to finish.
346    pub async fn wait_for_shutdown(&self) {
347        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
348        let (notifier_tx, notifier_rx) = oneshot::channel();
349        let _ = self
350            .shutdown_tx
351            .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
352            .await;
353        // Wait for the notifier to be dropped.
354        let _ = notifier_rx.await;
355    }
356}
357
358impl Drop for WorkerDropGuard {
359    fn drop(&mut self) {
360        info!("Aborting batch worker");
361        self.handle.abort();
362    }
363}
364
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// A processor that needs no resources and tags every input with " processed".
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs = inputs.map(|input| format!("{input} processed")).collect();
            Ok(outputs)
        }
    }

    /// Build a worker without spawning its run loop, so a test can drive it by hand and inspect
    /// its state. The batch size limit is 1, so each item fills a batch on its own.
    fn new_worker() -> Worker<SimpleBatchProcessor> {
        let (_item_tx, item_rx) = mpsc::channel(1);
        let (msg_tx, msg_rx) = mpsc::channel(1);
        let (_shutdown_tx, shutdown_rx) = mpsc::channel(1);
        let limits = Limits::builder().max_batch_size(1).build();

        Worker {
            batcher_name: "test".to_string(),
            item_rx,
            processor: SimpleBatchProcessor,
            msg_tx,
            msg_rx,
            shutdown_notifier_rx: shutdown_rx,
            shutdown_notifiers: Vec::new(),
            shutting_down: false,
            limits,
            batching_policy: BatchingPolicy::Size,
            batch_queues: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn removes_batch_queue_when_key_becomes_idle() {
        let mut worker = new_worker();

        let (tx, rx) = oneshot::channel();
        let item = BatchItem::<SimpleBatchProcessor> {
            key: "K1".to_string(),
            input: "I1".to_string(),
            submitted_at: tokio::time::Instant::now(),
            tx,
            requesting_span: Span::none(),
        };
        worker.add(item);

        // With a max batch size of 1 the item fills its batch, so it is processed straight away.
        let (result, _) = rx.await.unwrap();
        assert_eq!(result.unwrap(), "I1 processed");

        // Feed the Finished message back in, just as the run loop would.
        let msg = worker.msg_rx.recv().await.unwrap();
        match msg {
            Message::Finished(key, terminal_state) => {
                worker.on_batch_finished(&key, terminal_state);
            }
            other => panic!("expected Finished message, got {:?}", other),
        }

        assert!(
            worker.batch_queues.is_empty(),
            "the batch queue for an idle key should be removed"
        );
    }

    #[tokio::test]
    async fn ignores_timeout_for_removed_batch_queue() {
        // A timer can fire and enqueue a TimedOut message, and the batch can then be processed
        // anyway (e.g. because it filled up), after which the idle key's queue is removed.
        // Handling the stale TimedOut must then be a no-op rather than a panic.
        let mut worker = new_worker();
        let stale_generation = Generation::default();

        worker.on_timeout("K1".to_string(), stale_generation);
    }

    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        // Submit both items for the same key; together they fill one batch of size 2.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    submitted_at: tokio::time::Instant::now(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();
            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(outputs, ["I1 processed", "I2 processed"]);
    }
}