1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4};
5
6use tokio::{
7 sync::{mpsc, oneshot},
8 task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13 BatchError,
14 batch::BatchItem,
15 batch_inner::Generation,
16 batch_queue::BatchQueue,
17 limits::Limits,
18 policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
19 processor::Processor,
20};
21
22pub(crate) struct Worker<P: Processor> {
23 batcher_name: String,
24
25 item_rx: mpsc::Receiver<BatchItem<P>>,
27 processor: P,
29
30 msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
32 msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,
34
35 shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,
37
38 shutdown_notifiers: Vec<oneshot::Sender<()>>,
40
41 shutting_down: bool,
42
43 limits: Limits,
44 batching_policy: BatchingPolicy,
46
47 batch_queues: HashMap<P::Key, BatchQueue<P>>,
49}
50
51#[derive(Debug)]
52pub(crate) enum Message<K, E: Display + Debug> {
53 TimedOut(K, Generation),
54 ResourcesAcquired(K, Generation),
55 ResourceAcquisitionFailed(K, Generation, BatchError<E>),
56 Finished(K, BatchTerminalState),
57}
58
/// How a batch ended, as carried by `Message::Finished`.
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
    /// The worker responds by marking the queue as processed.
    Processed,
    /// The worker responds by marking resource acquisition as finished.
    FailedAcquiring,
}
64
65pub(crate) enum ShutdownMessage {
66 Register(ShutdownNotifier),
67 ShutDown,
68}
69
70pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
71
72#[derive(Debug, Clone)]
76pub struct WorkerHandle {
77 shutdown_tx: mpsc::Sender<ShutdownMessage>,
78}
79
80#[derive(Debug)]
82pub(crate) struct WorkerDropGuard {
83 handle: JoinHandle<()>,
84}
85
86impl<P: Processor> Worker<P> {
87 pub fn spawn(
88 batcher_name: String,
89 processor: P,
90 limits: Limits,
91 batching_policy: BatchingPolicy,
92 ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
93 let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
96 let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
97
98 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
99
100 let mut worker = Worker {
101 batcher_name,
102
103 item_rx,
104 processor,
105
106 msg_tx,
107 msg_rx,
108
109 shutdown_notifier_rx: shutdown_rx,
110 shutdown_notifiers: Vec::new(),
111
112 shutting_down: false,
113
114 limits,
115 batching_policy,
116
117 batch_queues: HashMap::new(),
118 };
119
120 let handle = tokio::spawn(async move {
121 worker.run().await;
122 });
123
124 (
125 WorkerHandle { shutdown_tx },
126 WorkerDropGuard { handle },
127 item_tx,
128 )
129 }
130
131 fn add(&mut self, item: BatchItem<P>) {
133 let key = item.key.clone();
134
135 let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
136 BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
137 });
138
139 match self.batching_policy.on_add(batch_queue) {
140 OnAdd::AddAndProcess => {
141 batch_queue.push(item);
142
143 self.process_next_batch(&key);
144 }
145 OnAdd::AddAndAcquireResources => {
146 batch_queue.push(item);
147
148 batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
149 }
150 OnAdd::AddAndProcessAfter(duration) => {
151 batch_queue.push(item);
152
153 batch_queue.process_after(duration, self.msg_tx.clone());
154 }
155 OnAdd::Add => {
156 batch_queue.push(item);
157 }
158 OnAdd::Reject(reason) => {
159 if item
160 .tx
161 .send((Err(BatchError::Rejected(reason)), None))
162 .is_err()
163 {
164 debug!(
167 "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
168 self.batcher_name
169 );
170 }
171 }
172 }
173 }
174
175 fn queue_mut<'q>(
178 batch_queues: &'q mut HashMap<P::Key, BatchQueue<P>>,
179 key: &P::Key,
180 ) -> &'q mut BatchQueue<P> {
181 batch_queues.get_mut(key).expect("batch queue should exist")
182 }
183
184 fn process_generation(&mut self, key: P::Key, generation: Generation) {
185 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
186
187 batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
188 }
189
190 fn process_next_ready_batch(&mut self, key: &P::Key) {
191 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
192
193 batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
194 }
195
196 fn process_next_batch(&mut self, key: &P::Key) {
197 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
198
199 batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
200 }
201
202 fn on_timeout(&mut self, key: P::Key, generation: Generation) {
203 let Some(batch_queue) = self.batch_queues.get_mut(&key) else {
206 debug!("Timeout for a batch queue which no longer exists. Ignoring.");
207 return;
208 };
209
210 match self.batching_policy.on_timeout(generation, batch_queue) {
211 OnGenerationEvent::Process => {
212 self.process_generation(key, generation);
213 }
214 OnGenerationEvent::DoNothing => {}
215 }
216 }
217
218 fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
219 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
220
221 batch_queue.mark_resource_acquisition_finished();
222
223 match self
224 .batching_policy
225 .on_resources_acquired(generation, batch_queue)
226 {
227 OnGenerationEvent::Process => {
228 self.process_generation(key, generation);
229 }
230 OnGenerationEvent::DoNothing => {}
231 }
232 }
233
234 fn on_resource_acquisition_failed(
235 &mut self,
236 key: P::Key,
237 generation: Generation,
238 err: BatchError<P::Error>,
239 ) {
240 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
241
242 batch_queue.fail_generation(generation, err.clone(), self.msg_tx.clone());
243 }
244
245 fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
246 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
247
248 match terminal_state {
249 BatchTerminalState::Processed => {
250 batch_queue.mark_processed();
251 }
252 BatchTerminalState::FailedAcquiring => {
253 batch_queue.mark_resource_acquisition_finished();
254 }
255 }
256
257 match self.batching_policy.on_finish(batch_queue) {
258 OnFinish::ProcessNextReady => {
259 self.process_next_ready_batch(key);
260 }
261 OnFinish::ProcessNext => {
262 self.process_next_batch(key);
263 }
264 OnFinish::DoNothing => {}
265 }
266
267 if Self::queue_mut(&mut self.batch_queues, key).is_idle() {
271 self.batch_queues.remove(key);
272 }
273 }
274
275 fn ready_to_shut_down(&self) -> bool {
276 self.shutting_down
277 && self.batch_queues.values().all(|q| q.is_empty())
278 && !self.batch_queues.values().any(|q| q.is_processing())
279 }
280
281 async fn run(&mut self) {
283 loop {
284 tokio::select! {
285 Some(msg) = self.shutdown_notifier_rx.recv() => {
286 match msg {
287 ShutdownMessage::Register(notifier) => {
288 self.shutdown_notifiers.push(notifier.0);
289 }
290 ShutdownMessage::ShutDown => {
291 self.shutting_down = true;
292 }
293 }
294 }
295
296 Some(item) = self.item_rx.recv() => {
297 self.add(item);
298 }
299
300 Some(msg) = self.msg_rx.recv() => {
301 match msg {
302 Message::ResourcesAcquired(key, generation) => {
303 self.on_resource_acquired(key, generation);
304 }
305 Message::ResourceAcquisitionFailed(key, generation, err) => {
306 self.on_resource_acquisition_failed(key, generation, err);
307 }
308 Message::TimedOut(key, generation) => {
309 self.on_timeout(key, generation);
310 }
311 Message::Finished(key, terminal_state) => {
312 self.on_batch_finished(&key, terminal_state);
313 }
314 }
315 }
316 }
317
318 if self.ready_to_shut_down() {
319 info!("Batch worker '{}' is shutting down", &self.batcher_name);
320 return;
321 }
322 }
323 }
324}
325
326impl WorkerHandle {
327 pub async fn shut_down(&self) {
340 info!("Sending shut down signal to batch worker");
341 let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
343 }
344
345 pub async fn wait_for_shutdown(&self) {
347 let (notifier_tx, notifier_rx) = oneshot::channel();
349 let _ = self
350 .shutdown_tx
351 .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
352 .await;
353 let _ = notifier_rx.await;
355 }
356}
357
358impl Drop for WorkerDropGuard {
359 fn drop(&mut self) {
360 info!("Aborting batch worker");
361 self.handle.abort();
362 }
363}
364
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Minimal processor for the tests: it needs no resources and appends
    /// `" processed"` to every input.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs: Vec<String> = inputs
                .map(|mut input| {
                    input.push_str(" processed");
                    input
                })
                .collect();

            Ok(outputs)
        }
    }

    /// Builds a worker that is driven by hand rather than through `run`.
    ///
    /// The item and shutdown senders are dropped right away; the tests using
    /// this helper call the worker's methods directly.
    fn new_worker() -> Worker<SimpleBatchProcessor> {
        let (_item_tx, item_rx) = mpsc::channel(1);
        let (msg_tx, msg_rx) = mpsc::channel(1);
        let (_shutdown_tx, shutdown_rx) = mpsc::channel(1);

        let limits = Limits::builder().max_batch_size(1).build();

        Worker {
            batcher_name: String::from("test"),
            item_rx,
            processor: SimpleBatchProcessor,
            msg_tx,
            msg_rx,
            shutdown_notifier_rx: shutdown_rx,
            shutdown_notifiers: Vec::new(),
            shutting_down: false,
            limits,
            batching_policy: BatchingPolicy::Size,
            batch_queues: HashMap::new(),
        }
    }

    /// After its only batch has finished, the key's queue is removed.
    #[tokio::test]
    async fn removes_batch_queue_when_key_becomes_idle() {
        let mut worker = new_worker();

        let (tx, rx) = oneshot::channel();
        let item = BatchItem {
            key: "K1".to_string(),
            input: "I1".to_string(),
            submitted_at: tokio::time::Instant::now(),
            tx,
            requesting_span: Span::none(),
        };
        worker.add(item);

        let response = rx.await.unwrap();
        let output = response.0.unwrap();
        assert_eq!(output, "I1 processed");

        // Deliver the `Finished` message by hand, as the event loop would.
        let (key, terminal_state) = match worker.msg_rx.recv().await.unwrap() {
            Message::Finished(key, terminal_state) => (key, terminal_state),
            other => panic!("expected Finished message, got {:?}", other),
        };
        worker.on_batch_finished(&key, terminal_state);

        assert!(
            worker.batch_queues.is_empty(),
            "the batch queue for an idle key should be removed"
        );
    }

    /// A timeout for a key without a queue must not panic.
    #[tokio::test]
    async fn ignores_timeout_for_removed_batch_queue() {
        let mut worker = new_worker();

        worker.on_timeout(String::from("K1"), Generation::default());
    }

    /// Two items sent through the spawned worker's channel are both processed.
    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        // Submit both items before awaiting any result.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    submitted_at: tokio::time::Instant::now(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(
            outputs,
            vec!["I1 processed".to_string(), "I2 processed".to_string()]
        );
    }
}