batch_aint_one/
worker.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4};
5
6use tokio::{
7    sync::{mpsc, oneshot},
8    task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13    BatchError,
14    batch::BatchItem,
15    batch_inner::Generation,
16    batch_queue::BatchQueue,
17    limits::Limits,
18    policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
19    processor::Processor,
20};
21
22pub(crate) struct Worker<P: Processor> {
23    batcher_name: String,
24
25    /// Used to receive new batch items.
26    item_rx: mpsc::Receiver<BatchItem<P>>,
27    /// The callback to process a batch of inputs.
28    processor: P,
29
30    /// Used to signal that a batch for key `K` should be processed.
31    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
32    /// Receives signals to process a batch for key `K`.
33    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,
34
35    /// Used to send messages to the worker related to shutdown.
36    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,
37
38    /// Used to signal to listeners that the worker has shut down.
39    shutdown_notifiers: Vec<oneshot::Sender<()>>,
40
41    shutting_down: bool,
42
43    limits: Limits,
44    /// Controls when to start processing a batch.
45    batching_policy: BatchingPolicy,
46
47    /// Unprocessed batches, grouped by key `K`.
48    batch_queues: HashMap<P::Key, BatchQueue<P>>,
49}
50
51#[derive(Debug)]
52pub(crate) enum Message<K, E: Display + Debug> {
53    TimedOut(K, Generation),
54    ResourcesAcquired(K, Generation),
55    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
56    Finished(K, BatchTerminalState),
57}
58
/// How a batch ended, as reported in [`Message::Finished`].
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
    /// The batch was handed to the processor and completed.
    Processed,
    /// The batch ended because resource acquisition failed.
    FailedAcquiring,
}
64
65pub(crate) enum ShutdownMessage {
66    Register(ShutdownNotifier),
67    ShutDown,
68}
69
70pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
71
72/// A handle to the worker task.
73///
74/// Used for shutting down the worker and waiting for it to finish.
75#[derive(Debug, Clone)]
76pub struct WorkerHandle {
77    shutdown_tx: mpsc::Sender<ShutdownMessage>,
78}
79
80/// Aborts the worker task when dropped.
81#[derive(Debug)]
82pub(crate) struct WorkerDropGuard {
83    handle: JoinHandle<()>,
84}
85
86impl<P: Processor> Worker<P> {
87    pub fn spawn(
88        batcher_name: String,
89        processor: P,
90        limits: Limits,
91        batching_policy: BatchingPolicy,
92    ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
93        // These channel sizes are somewhat arbitrary - they just need to be big enough to avoid
94        // backpressure in normal operation.
95        let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
96        let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
97
98        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
99
100        let mut worker = Worker {
101            batcher_name,
102
103            item_rx,
104            processor,
105
106            msg_tx,
107            msg_rx,
108
109            shutdown_notifier_rx: shutdown_rx,
110            shutdown_notifiers: Vec::new(),
111
112            shutting_down: false,
113
114            limits,
115            batching_policy,
116
117            batch_queues: HashMap::new(),
118        };
119
120        let handle = tokio::spawn(async move {
121            worker.run().await;
122        });
123
124        (
125            WorkerHandle { shutdown_tx },
126            WorkerDropGuard { handle },
127            item_tx,
128        )
129    }
130
131    /// Add an item to the batch.
132    fn add(&mut self, item: BatchItem<P>) {
133        let key = item.key.clone();
134
135        let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
136            BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
137        });
138
139        match self.batching_policy.on_add(batch_queue) {
140            OnAdd::AddAndProcess => {
141                batch_queue.push(item);
142
143                self.process_next_batch(&key);
144            }
145            OnAdd::AddAndAcquireResources => {
146                batch_queue.push(item);
147
148                batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
149            }
150            OnAdd::AddAndProcessAfter(duration) => {
151                batch_queue.push(item);
152
153                batch_queue.process_after(duration, self.msg_tx.clone());
154            }
155            OnAdd::Add => {
156                batch_queue.push(item);
157            }
158            OnAdd::Reject(reason) => {
159                if item
160                    .tx
161                    .send((Err(BatchError::Rejected(reason)), None))
162                    .is_err()
163                {
164                    // Whatever was waiting for the output must have shut down. Presumably it
165                    // doesn't care anymore, but we log here anyway. There's not much else we can do.
166                    debug!(
167                        "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
168                        self.batcher_name
169                    );
170                }
171            }
172        }
173    }
174
175    fn process_generation(&mut self, key: P::Key, generation: Generation) {
176        let batch_queue = self.batch_queues.get_mut(&key).expect("batch should exist");
177
178        batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
179    }
180
181    fn process_next_ready_batch(&mut self, key: &P::Key) {
182        let batch_queue = self
183            .batch_queues
184            .get_mut(key)
185            .expect("batch queue should exist");
186
187        batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
188    }
189
190    fn process_next_batch(&mut self, key: &P::Key) {
191        let batch_queue = self
192            .batch_queues
193            .get_mut(key)
194            .expect("batch queue should exist");
195
196        batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
197    }
198
199    fn on_timeout(&mut self, key: P::Key, generation: Generation) {
200        let batch_queue = self
201            .batch_queues
202            .get_mut(&key)
203            .expect("batch queue should exist");
204
205        match self.batching_policy.on_timeout(generation, batch_queue) {
206            OnGenerationEvent::Process => {
207                self.process_generation(key, generation);
208            }
209            OnGenerationEvent::DoNothing => {}
210        }
211    }
212
213    fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
214        let batch_queue = self
215            .batch_queues
216            .get_mut(&key)
217            .expect("batch queue should exist");
218
219        batch_queue.mark_resource_acquisition_finished();
220
221        match self
222            .batching_policy
223            .on_resources_acquired(generation, batch_queue)
224        {
225            OnGenerationEvent::Process => {
226                self.process_generation(key, generation);
227            }
228            OnGenerationEvent::DoNothing => {}
229        }
230    }
231
232    fn on_resource_acquisition_failed(
233        &mut self,
234        key: P::Key,
235        generation: Generation,
236        err: BatchError<P::Error>,
237    ) {
238        let batch_queue = self
239            .batch_queues
240            .get_mut(&key)
241            .expect("batch queue should exist");
242
243        batch_queue.fail_generation(generation, err.clone(), self.msg_tx.clone());
244    }
245
246    fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
247        let batch_queue = self
248            .batch_queues
249            .get_mut(key)
250            .expect("batch queue should exist");
251
252        match terminal_state {
253            BatchTerminalState::Processed => {
254                batch_queue.mark_processed();
255            }
256            BatchTerminalState::FailedAcquiring => {
257                batch_queue.mark_resource_acquisition_finished();
258            }
259        }
260
261        match self.batching_policy.on_finish(batch_queue) {
262            OnFinish::ProcessNextReady => {
263                self.process_next_ready_batch(key);
264            }
265            OnFinish::ProcessNext => {
266                self.process_next_batch(key);
267            }
268            OnFinish::DoNothing => {}
269        }
270    }
271
272    fn ready_to_shut_down(&self) -> bool {
273        self.shutting_down
274            && self.batch_queues.values().all(|q| q.is_empty())
275            && !self.batch_queues.values().any(|q| q.is_processing())
276    }
277
278    /// Start running the worker event loop.
279    async fn run(&mut self) {
280        loop {
281            tokio::select! {
282                Some(msg) = self.shutdown_notifier_rx.recv() => {
283                    match msg {
284                        ShutdownMessage::Register(notifier) => {
285                           self.shutdown_notifiers.push(notifier.0);
286                        }
287                        ShutdownMessage::ShutDown => {
288                            self.shutting_down = true;
289                        }
290                    }
291                }
292
293                Some(item) = self.item_rx.recv() => {
294                    self.add(item);
295                }
296
297                Some(msg) = self.msg_rx.recv() => {
298                    match msg {
299                        Message::ResourcesAcquired(key, generation) => {
300                            self.on_resource_acquired(key, generation);
301                        }
302                        Message::ResourceAcquisitionFailed(key, generation, err) => {
303                            self.on_resource_acquisition_failed(key, generation, err);
304                        }
305                        Message::TimedOut(key, generation) => {
306                            self.on_timeout(key, generation);
307                        }
308                        Message::Finished(key, terminal_state) => {
309                            self.on_batch_finished(&key, terminal_state);
310                        }
311                    }
312                }
313            }
314
315            if self.ready_to_shut_down() {
316                info!("Batch worker '{}' is shutting down", &self.batcher_name);
317                return;
318            }
319        }
320    }
321}
322
323impl WorkerHandle {
324    /// Signal the worker to shut down after processing any in-flight batches.
325    ///
326    /// Note that when using the Size policy this may wait indefinitely if no new items are added.
327    pub async fn shut_down(&self) {
328        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
329        let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
330    }
331
332    /// Wait for the worker to finish.
333    pub async fn wait_for_shutdown(&self) {
334        // We ignore errors here - if the receiver has gone away, the worker is already shut down.
335        let (notifier_tx, notifier_rx) = oneshot::channel();
336        let _ = self
337            .shutdown_tx
338            .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
339            .await;
340        // Wait for the notifier to be dropped.
341        let _ = notifier_rx.await;
342    }
343}
344
345impl Drop for WorkerDropGuard {
346    fn drop(&mut self) {
347        self.handle.abort();
348    }
349}
350
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Processor that needs no resources and appends " processed" to every input.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs = inputs.map(|input| format!("{input} processed")).collect();
            Ok(outputs)
        }
    }

    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        // Two items for the same key; with a batch size of two the second completes the batch.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            let item = BatchItem {
                key: "K1".to_string(),
                input: input.to_string(),
                submitted_at: tokio::time::Instant::now(),
                tx,
                requesting_span: Span::none(),
            };
            item_tx.send(item).await.unwrap();
            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(
            outputs,
            vec!["I1 processed".to_string(), "I2 processed".to_string()]
        );
    }
}