// batch_aint_one/worker.rs

use std::{collections::HashMap, fmt::Debug};

use tokio::{
    sync::{mpsc, oneshot},
    task::JoinHandle,
};
use tracing::{Span, debug, info};

use crate::{
    BatchError,
    batch::BatchItem,
    batch_inner::Generation,
    batch_queue::BatchQueue,
    limits::Limits,
    policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
    processor::Processor,
};

19pub(crate) struct Worker<P: Processor> {
20    batcher_name: String,
21
22    /// Used to receive new batch items.
23    item_rx: mpsc::Receiver<BatchItem<P>>,
24    /// The callback to process a batch of inputs.
25    processor: P,
26
27    /// Used to signal that a batch for key `K` should be processed.
28    msg_tx: mpsc::Sender<Message<P>>,
29    /// Receives signals to process a batch for key `K`.
30    msg_rx: mpsc::Receiver<Message<P>>,
31
32    /// Used to send messages to the worker related to shutdown.
33    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,
34
35    /// Used to signal to listeners that the worker has shut down.
36    shutdown_notifiers: Vec<oneshot::Sender<()>>,
37
38    shutting_down: bool,
39
40    limits: Limits,
41    /// Controls when to start processing a batch.
42    batching_policy: BatchingPolicy,
43
44    /// Unprocessed batches, grouped by key `K`.
45    batch_queues: HashMap<P::Key, BatchQueue<P>>,
46}
47
48/// Events which drive the worker.
49///
50/// Spawned tasks (resource acquisition, timeouts) report their outcomes to the worker by
51/// sending a message, and the worker performs the resulting batch state transitions when
52/// handling them, so it observes all events in message order.
53pub(crate) enum Message<P: Processor> {
54    TimedOut(P::Key, Generation),
55    ResourcesAcquired(P::Key, Generation, P::Resources, Span),
56    ResourceAcquisitionFailed(P::Key, Generation, BatchError<P::Error>),
57    Finished(P::Key),
58}
59
60impl<P: Processor> Debug for Message<P> {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match self {
63            Message::TimedOut(key, generation) => f
64                .debug_tuple("TimedOut")
65                .field(key)
66                .field(generation)
67                .finish(),
68            Message::ResourcesAcquired(key, generation, _, _) => f
69                .debug_tuple("ResourcesAcquired")
70                .field(key)
71                .field(generation)
72                .field(&"<Resources>")
73                .finish(),
74            Message::ResourceAcquisitionFailed(key, generation, err) => f
75                .debug_tuple("ResourceAcquisitionFailed")
76                .field(key)
77                .field(generation)
78                .field(err)
79                .finish(),
80            Message::Finished(key) => f.debug_tuple("Finished").field(key).finish(),
81        }
82    }
83}
84
85pub(crate) enum ShutdownMessage {
86    Register(ShutdownNotifier),
87    ShutDown,
88}
89
90pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
91
92/// A handle to the worker task.
93///
94/// Used for shutting down the worker and waiting for it to finish.
95#[derive(Debug, Clone)]
96pub struct WorkerHandle {
97    shutdown_tx: mpsc::Sender<ShutdownMessage>,
98}
99
100/// Aborts the worker task when dropped.
101#[derive(Debug)]
102pub(crate) struct WorkerDropGuard {
103    handle: JoinHandle<()>,
104}
105
106impl<P: Processor> Worker<P> {
107    pub fn spawn(
108        batcher_name: String,
109        processor: P,
110        limits: Limits,
111        batching_policy: BatchingPolicy,
112    ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
113        // These channel sizes are somewhat arbitrary - they just need to be big enough to avoid
114        // backpressure in normal operation.
115        let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
116        let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
117
118        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
119
120        let mut worker = Worker {
121            batcher_name,
122
123            item_rx,
124            processor,
125
126            msg_tx,
127            msg_rx,
128
129            shutdown_notifier_rx: shutdown_rx,
130            shutdown_notifiers: Vec::new(),
131
132            shutting_down: false,
133
134            limits,
135            batching_policy,
136
137            batch_queues: HashMap::new(),
138        };
139
140        let handle = tokio::spawn(async move {
141            worker.run().await;
142        });
143
144        (
145            WorkerHandle { shutdown_tx },
146            WorkerDropGuard { handle },
147            item_tx,
148        )
149    }
150
151    /// Add an item to the batch.
152    fn add(&mut self, item: BatchItem<P>) {
153        let key = item.key.clone();
154
155        let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
156            BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
157        });
158
159        match self.batching_policy.on_add(batch_queue) {
160            OnAdd::AddAndProcess => {
161                batch_queue.push(item);
162
163                self.process_next_batch(&key);
164            }
165            OnAdd::AddAndAcquireResources => {
166                batch_queue.push(item);
167
168                batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
169            }
170            OnAdd::AddAndProcessAfter(duration) => {
171                batch_queue.push(item);
172
173                batch_queue.process_after(duration, self.msg_tx.clone());
174            }
175            OnAdd::Add => {
176                batch_queue.push(item);
177            }
178            OnAdd::Reject(reason) => {
179                if item
180                    .tx
181                    .send((Err(BatchError::Rejected(reason)), None))
182                    .is_err()
183                {
184                    // Whatever was waiting for the output must have shut down. Presumably it
185                    // doesn't care anymore, but we log here anyway. There's not much else we can do.
186                    debug!(
187                        "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
188                        self.batcher_name
189                    );
190                }
191            }
192        }
193    }
194
195    /// Get the batch queue for the given key, which should always exist when handling an event
196    /// for that key.
197    fn queue_mut<'q>(
198        batch_queues: &'q mut HashMap<P::Key, BatchQueue<P>>,
199        key: &P::Key,
200    ) -> &'q mut BatchQueue<P> {
201        batch_queues.get_mut(key).expect("batch queue should exist")
202    }
203
204    fn process_generation(&mut self, key: P::Key, generation: Generation) {
205        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
206
207        batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
208    }
209
210    fn process_next_ready_batch(&mut self, key: &P::Key) {
211        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
212
213        batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
214    }
215
216    fn process_next_batch(&mut self, key: &P::Key) {
217        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
218
219        batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
220    }
221
222    fn on_timeout(&mut self, key: P::Key, generation: Generation) {
223        // Unlike the other message handlers, the queue may have been removed: timers are not
224        // tracked by the in-flight counters, so a TimedOut message can outlive its queue.
225        let Some(batch_queue) = self.batch_queues.get_mut(&key) else {
226            debug!("Timeout for a batch queue which no longer exists. Ignoring.");
227            return;
228        };
229
230        match self.batching_policy.on_timeout(generation, batch_queue) {
231            OnGenerationEvent::Process => {
232                self.process_generation(key, generation);
233            }
234            OnGenerationEvent::DoNothing => {}
235        }
236    }
237
238    fn on_resource_acquired(
239        &mut self,
240        key: P::Key,
241        generation: Generation,
242        resources: P::Resources,
243        span: Span,
244    ) {
245        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
246
247        batch_queue.resources_acquired(generation, resources, span);
248
249        match self
250            .batching_policy
251            .on_resources_acquired(generation, batch_queue)
252        {
253            OnGenerationEvent::Process => {
254                self.process_generation(key, generation);
255            }
256            OnGenerationEvent::DoNothing => {}
257        }
258    }
259
260    fn on_resource_acquisition_failed(
261        &mut self,
262        key: P::Key,
263        generation: Generation,
264        err: BatchError<P::Error>,
265    ) {
266        let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
267
268        batch_queue.fail_generation(generation, err);
269
270        self.process_next_and_clean_up(&key);
271    }
272
273    fn on_batch_finished(&mut self, key: &P::Key) {
274        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
275
276        batch_queue.mark_processed();
277
278        self.process_next_and_clean_up(key);
279    }
280
281    /// After a batch has left the queue (either processed or having failed to acquire resources),
282    /// apply the batching policy's finish action and drop the queue if the key is now idle.
283    fn process_next_and_clean_up(&mut self, key: &P::Key) {
284        let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
285
286        match self.batching_policy.on_finish(batch_queue) {
287            OnFinish::ProcessNextReady => {
288                self.process_next_ready_batch(key);
289            }
290            OnFinish::ProcessNext => {
291                self.process_next_batch(key);
292            }
293            OnFinish::DoNothing => {}
294        }
295
296        // Remove the queue for idle keys, otherwise the map grows unboundedly as new keys are
297        // seen. A key can only become idle once a batch leaves the queue, so these handlers are
298        // the only place we need to do this.
299        if Self::queue_mut(&mut self.batch_queues, key).is_idle() {
300            self.batch_queues.remove(key);
301        }
302    }
303
304    fn ready_to_shut_down(&self) -> bool {
305        self.shutting_down
306            && self.batch_queues.values().all(|q| q.is_empty())
307            && !self.batch_queues.values().any(|q| q.is_processing())
308    }
309
310    /// Start running the worker event loop.
311    async fn run(&mut self) {
312        loop {
313            tokio::select! {
314                Some(msg) = self.shutdown_notifier_rx.recv() => {
315                    match msg {
316                        ShutdownMessage::Register(notifier) => {
317                           self.shutdown_notifiers.push(notifier.0);
318                        }
319                        ShutdownMessage::ShutDown => {
320                            self.shutting_down = true;
321                        }
322                    }
323                }
324
325                Some(item) = self.item_rx.recv() => {
326                    self.add(item);
327                }
328
329                Some(msg) = self.msg_rx.recv() => {
330                    match msg {
331                        Message::ResourcesAcquired(key, generation, resources, span) => {
332                            self.on_resource_acquired(key, generation, resources, span);
333                        }
334                        Message::ResourceAcquisitionFailed(key, generation, err) => {
335                            self.on_resource_acquisition_failed(key, generation, err);
336                        }
337                        Message::TimedOut(key, generation) => {
338                            self.on_timeout(key, generation);
339                        }
340                        Message::Finished(key) => {
341                            self.on_batch_finished(&key);
342                        }
343                    }
344                }
345            }
346
347            if self.ready_to_shut_down() {
348                info!("Batch worker '{}' is shutting down", &self.batcher_name);
349                return;
350            }
351        }
352    }
353}
354
355impl WorkerHandle {
356    /// Signal the worker to shut down after processing any in-flight batches.
357    ///
358    /// New items are still accepted while shutting down, and the worker only shuts down once all
359    /// keys are idle. This means shutdown may never complete if:
360    ///
361    /// - new items keep being added, or
362    /// - a batch never meets its policy's processing condition, e.g. when using the
363    ///   [`Size`](crate::BatchingPolicy::Size) policy, a final partial batch may wait
364    ///   indefinitely for more items.
365    ///
366    /// Stopping the flow of new items is expected to be handled by the caller, e.g. by shutting
367    /// down the message handlers which add items before shutting down the batcher.
368    pub async fn shut_down(&self) {
369        info!("Sending shut down signal to batch worker");
370        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
371        let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
372    }
373
374    /// Wait for the worker to finish.
375    pub async fn wait_for_shutdown(&self) {
376        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
377        let (notifier_tx, notifier_rx) = oneshot::channel();
378        let _ = self
379            .shutdown_tx
380            .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
381            .await;
382        // Wait for the notifier to be dropped.
383        let _ = notifier_rx.await;
384    }
385}
386
387impl Drop for WorkerDropGuard {
388    fn drop(&mut self) {
389        info!("Aborting batch worker");
390        self.handle.abort();
391    }
392}
393
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Minimal processor which needs no resources and appends " processed" to each input.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            Ok(inputs.map(|s| s + " processed").collect())
        }
    }

    /// Build a worker by hand, without spawning its run loop, so that a test can drive the
    /// handlers itself and look at the worker's state afterwards.
    fn new_worker() -> Worker<SimpleBatchProcessor> {
        let (_item_tx, item_rx) = mpsc::channel(1);
        let (msg_tx, msg_rx) = mpsc::channel(1);
        let (_shutdown_tx, shutdown_rx) = mpsc::channel(1);

        Worker {
            batcher_name: "test".to_string(),
            item_rx,
            processor: SimpleBatchProcessor,
            msg_tx,
            msg_rx,
            shutdown_notifier_rx: shutdown_rx,
            shutdown_notifiers: Vec::new(),
            shutting_down: false,
            limits: Limits::builder().max_batch_size(1).build(),
            batching_policy: BatchingPolicy::Size,
            batch_queues: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn removes_batch_queue_when_key_becomes_idle() {
        let mut worker = new_worker();

        let (tx, rx) = oneshot::channel();
        let item = BatchItem {
            key: "K1".to_string(),
            input: "I1".to_string(),
            submitted_at: tokio::time::Instant::now(),
            tx,
            requesting_span: Span::none(),
        };
        worker.add(item);

        // With max_batch_size of 1 the batch is full straight away, so it is processed
        // immediately.
        let output = rx.await.unwrap().0.unwrap();
        assert_eq!(output, "I1 processed");

        // Play the part of the run loop: take the Finished message and handle it.
        let key = match worker.msg_rx.recv().await.unwrap() {
            Message::Finished(key) => key,
            msg => panic!("expected Finished message, got {:?}", msg),
        };
        worker.on_batch_finished(&key);

        assert!(
            worker.batch_queues.is_empty(),
            "the batch queue for an idle key should be removed"
        );
    }

    #[tokio::test]
    async fn ignores_timeout_for_removed_batch_queue() {
        // A timer may fire and enqueue a TimedOut message, and the batch may then be processed
        // anyway (e.g. because it filled up), after which the queue is removed as the key is
        // idle. Handling the stale TimedOut message must be a no-op rather than a panic.
        let mut worker = new_worker();

        worker.on_timeout("K1".to_string(), Generation::default());
    }

    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        // Submit both items for the same key before waiting on either output.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    submitted_at: tokio::time::Instant::now(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(outputs, ["I1 processed".to_string(), "I2 processed".to_string()]);
    }
}