1use std::{collections::HashMap, fmt::Debug};
2
3use tokio::{
4 sync::{mpsc, oneshot},
5 task::JoinHandle,
6};
7use tracing::{Span, debug, info};
8
9use crate::{
10 BatchError,
11 batch::BatchItem,
12 batch_inner::Generation,
13 batch_queue::BatchQueue,
14 limits::Limits,
15 policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
16 processor::Processor,
17};
18
/// Event loop state that owns every per-key `BatchQueue` for one batcher.
///
/// `Worker::spawn` moves the worker into a single spawned task that runs
/// `Worker::run`, so all queue state is mutated from that one task; the rest of
/// the system talks to it only through the channels below.
///
/// NOTE: struct fields are dropped in declaration order, so when the worker is
/// dropped `processor` is dropped before `shutdown_notifiers`. Nothing in this
/// file sends on the registered notifiers; waiters in
/// `WorkerHandle::wait_for_shutdown` are released when those senders are dropped.
/// NOTE(review): unclear whether callers rely on this ordering — check before
/// reordering fields.
pub(crate) struct Worker<P: Processor> {
    /// Batcher name; used in log messages and passed to each new `BatchQueue`.
    batcher_name: String,

    /// Receives items submitted through the sender returned by `Worker::spawn`.
    item_rx: mpsc::Receiver<BatchItem<P>>,
    /// The processor; a clone is handed to each batch queue operation.
    processor: P,

    /// Cloned into batch queue operations so they can report back through `msg_rx`.
    /// Because the worker itself holds this sender, `msg_rx` cannot observe a
    /// closed channel while the worker is alive.
    msg_tx: mpsc::Sender<Message<P>>,
    /// Timeouts, resource-acquisition results and batch completions.
    msg_rx: mpsc::Receiver<Message<P>>,

    /// Shutdown requests and notifier registrations coming from `WorkerHandle`s.
    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,

    /// Senders registered through `ShutdownMessage::Register`; kept until the worker is dropped.
    shutdown_notifiers: Vec<oneshot::Sender<()>>,

    /// Set once `ShutdownMessage::ShutDown` has been received.
    shutting_down: bool,

    limits: Limits,
    batching_policy: BatchingPolicy,

    /// One queue per key, created lazily in `add` when the first item for a key
    /// arrives; `process_next_and_clean_up` removes a queue once it is idle.
    batch_queues: HashMap<P::Key, BatchQueue<P>>,
}
47
48pub(crate) enum Message<P: Processor> {
54 TimedOut(P::Key, Generation),
55 ResourcesAcquired(P::Key, Generation, P::Resources, Span),
56 ResourceAcquisitionFailed(P::Key, Generation, BatchError<P::Error>),
57 Finished(P::Key),
58}
59
60impl<P: Processor> Debug for Message<P> {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Message::TimedOut(key, generation) => f
64 .debug_tuple("TimedOut")
65 .field(key)
66 .field(generation)
67 .finish(),
68 Message::ResourcesAcquired(key, generation, _, _) => f
69 .debug_tuple("ResourcesAcquired")
70 .field(key)
71 .field(generation)
72 .field(&"<Resources>")
73 .finish(),
74 Message::ResourceAcquisitionFailed(key, generation, err) => f
75 .debug_tuple("ResourceAcquisitionFailed")
76 .field(key)
77 .field(generation)
78 .field(err)
79 .finish(),
80 Message::Finished(key) => f.debug_tuple("Finished").field(key).finish(),
81 }
82 }
83}
84
85pub(crate) enum ShutdownMessage {
86 Register(ShutdownNotifier),
87 ShutDown,
88}
89
/// Oneshot sender carried by `ShutdownMessage::Register`; the worker stores it in
/// `shutdown_notifiers` until the worker itself is dropped.
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
91
/// Cloneable handle for asking a batch worker to shut down and for waiting until it has.
#[derive(Debug, Clone)]
pub struct WorkerHandle {
    /// Delivers shutdown requests and notifier registrations to the worker.
    shutdown_tx: mpsc::Sender<ShutdownMessage>,
}
99
/// Guard that aborts the spawned worker task when dropped (see its `Drop` impl).
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
    /// Join handle of the task spawned by `Worker::spawn`.
    handle: JoinHandle<()>,
}
105
106impl<P: Processor> Worker<P> {
107 pub fn spawn(
108 batcher_name: String,
109 processor: P,
110 limits: Limits,
111 batching_policy: BatchingPolicy,
112 ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
113 let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
116 let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
117
118 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
119
120 let mut worker = Worker {
121 batcher_name,
122
123 item_rx,
124 processor,
125
126 msg_tx,
127 msg_rx,
128
129 shutdown_notifier_rx: shutdown_rx,
130 shutdown_notifiers: Vec::new(),
131
132 shutting_down: false,
133
134 limits,
135 batching_policy,
136
137 batch_queues: HashMap::new(),
138 };
139
140 let handle = tokio::spawn(async move {
141 worker.run().await;
142 });
143
144 (
145 WorkerHandle { shutdown_tx },
146 WorkerDropGuard { handle },
147 item_tx,
148 )
149 }
150
151 fn add(&mut self, item: BatchItem<P>) {
153 let key = item.key.clone();
154
155 let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
156 BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
157 });
158
159 match self.batching_policy.on_add(batch_queue) {
160 OnAdd::AddAndProcess => {
161 batch_queue.push(item);
162
163 self.process_next_batch(&key);
164 }
165 OnAdd::AddAndAcquireResources => {
166 batch_queue.push(item);
167
168 batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
169 }
170 OnAdd::AddAndProcessAfter(duration) => {
171 batch_queue.push(item);
172
173 batch_queue.process_after(duration, self.msg_tx.clone());
174 }
175 OnAdd::Add => {
176 batch_queue.push(item);
177 }
178 OnAdd::Reject(reason) => {
179 if item
180 .tx
181 .send((Err(BatchError::Rejected(reason)), None))
182 .is_err()
183 {
184 debug!(
187 "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
188 self.batcher_name
189 );
190 }
191 }
192 }
193 }
194
195 fn queue_mut<'q>(
198 batch_queues: &'q mut HashMap<P::Key, BatchQueue<P>>,
199 key: &P::Key,
200 ) -> &'q mut BatchQueue<P> {
201 batch_queues.get_mut(key).expect("batch queue should exist")
202 }
203
204 fn process_generation(&mut self, key: P::Key, generation: Generation) {
205 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
206
207 batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
208 }
209
210 fn process_next_ready_batch(&mut self, key: &P::Key) {
211 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
212
213 batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
214 }
215
216 fn process_next_batch(&mut self, key: &P::Key) {
217 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
218
219 batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
220 }
221
222 fn on_timeout(&mut self, key: P::Key, generation: Generation) {
223 let Some(batch_queue) = self.batch_queues.get_mut(&key) else {
226 debug!("Timeout for a batch queue which no longer exists. Ignoring.");
227 return;
228 };
229
230 match self.batching_policy.on_timeout(generation, batch_queue) {
231 OnGenerationEvent::Process => {
232 self.process_generation(key, generation);
233 }
234 OnGenerationEvent::DoNothing => {}
235 }
236 }
237
238 fn on_resource_acquired(
239 &mut self,
240 key: P::Key,
241 generation: Generation,
242 resources: P::Resources,
243 span: Span,
244 ) {
245 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
246
247 batch_queue.resources_acquired(generation, resources, span);
248
249 match self
250 .batching_policy
251 .on_resources_acquired(generation, batch_queue)
252 {
253 OnGenerationEvent::Process => {
254 self.process_generation(key, generation);
255 }
256 OnGenerationEvent::DoNothing => {}
257 }
258 }
259
260 fn on_resource_acquisition_failed(
261 &mut self,
262 key: P::Key,
263 generation: Generation,
264 err: BatchError<P::Error>,
265 ) {
266 let batch_queue = Self::queue_mut(&mut self.batch_queues, &key);
267
268 batch_queue.fail_generation(generation, err);
269
270 self.process_next_and_clean_up(&key);
271 }
272
273 fn on_batch_finished(&mut self, key: &P::Key) {
274 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
275
276 batch_queue.mark_processed();
277
278 self.process_next_and_clean_up(key);
279 }
280
281 fn process_next_and_clean_up(&mut self, key: &P::Key) {
284 let batch_queue = Self::queue_mut(&mut self.batch_queues, key);
285
286 match self.batching_policy.on_finish(batch_queue) {
287 OnFinish::ProcessNextReady => {
288 self.process_next_ready_batch(key);
289 }
290 OnFinish::ProcessNext => {
291 self.process_next_batch(key);
292 }
293 OnFinish::DoNothing => {}
294 }
295
296 if Self::queue_mut(&mut self.batch_queues, key).is_idle() {
300 self.batch_queues.remove(key);
301 }
302 }
303
304 fn ready_to_shut_down(&self) -> bool {
305 self.shutting_down
306 && self.batch_queues.values().all(|q| q.is_empty())
307 && !self.batch_queues.values().any(|q| q.is_processing())
308 }
309
310 async fn run(&mut self) {
312 loop {
313 tokio::select! {
314 Some(msg) = self.shutdown_notifier_rx.recv() => {
315 match msg {
316 ShutdownMessage::Register(notifier) => {
317 self.shutdown_notifiers.push(notifier.0);
318 }
319 ShutdownMessage::ShutDown => {
320 self.shutting_down = true;
321 }
322 }
323 }
324
325 Some(item) = self.item_rx.recv() => {
326 self.add(item);
327 }
328
329 Some(msg) = self.msg_rx.recv() => {
330 match msg {
331 Message::ResourcesAcquired(key, generation, resources, span) => {
332 self.on_resource_acquired(key, generation, resources, span);
333 }
334 Message::ResourceAcquisitionFailed(key, generation, err) => {
335 self.on_resource_acquisition_failed(key, generation, err);
336 }
337 Message::TimedOut(key, generation) => {
338 self.on_timeout(key, generation);
339 }
340 Message::Finished(key) => {
341 self.on_batch_finished(&key);
342 }
343 }
344 }
345 }
346
347 if self.ready_to_shut_down() {
348 info!("Batch worker '{}' is shutting down", &self.batcher_name);
349 return;
350 }
351 }
352 }
353}
354
355impl WorkerHandle {
356 pub async fn shut_down(&self) {
369 info!("Sending shut down signal to batch worker");
370 let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
372 }
373
374 pub async fn wait_for_shutdown(&self) {
376 let (notifier_tx, notifier_rx) = oneshot::channel();
378 let _ = self
379 .shutdown_tx
380 .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
381 .await;
382 let _ = notifier_rx.await;
384 }
385}
386
387impl Drop for WorkerDropGuard {
388 fn drop(&mut self) {
389 info!("Aborting batch worker");
390 self.handle.abort();
391 }
392}
393
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Test processor that needs no resources and appends " processed" to each input.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            Ok(inputs.map(|input| format!("{input} processed")).collect())
        }
    }

    /// Builds a worker without spawning it, so tests can drive its methods directly.
    ///
    /// The item and shutdown senders are dropped when this returns; the worker
    /// keeps its own `msg_tx`.
    fn new_worker() -> Worker<SimpleBatchProcessor> {
        let (msg_tx, msg_rx) = mpsc::channel(1);
        let (_item_tx, item_rx) = mpsc::channel(1);
        let (_shutdown_tx, shutdown_rx) = mpsc::channel(1);

        Worker {
            batcher_name: String::from("test"),
            batch_queues: HashMap::new(),
            batching_policy: BatchingPolicy::Size,
            limits: Limits::builder().max_batch_size(1).build(),
            item_rx,
            msg_rx,
            msg_tx,
            processor: SimpleBatchProcessor,
            shutdown_notifier_rx: shutdown_rx,
            shutdown_notifiers: Vec::new(),
            shutting_down: false,
        }
    }

    #[tokio::test]
    async fn removes_batch_queue_when_key_becomes_idle() {
        let mut worker = new_worker();

        let (result_tx, result_rx) = oneshot::channel();
        let submitted = BatchItem {
            key: "K1".to_string(),
            input: "I1".to_string(),
            submitted_at: tokio::time::Instant::now(),
            tx: result_tx,
            requesting_span: Span::none(),
        };
        worker.add(submitted);

        let (result, _) = result_rx.await.unwrap();
        assert_eq!(result.unwrap(), "I1 processed");

        let event = worker.msg_rx.recv().await.unwrap();
        match event {
            Message::Finished(finished_key) => worker.on_batch_finished(&finished_key),
            other => panic!("expected Finished message, got {:?}", other),
        }

        assert!(
            worker.batch_queues.is_empty(),
            "the batch queue for an idle key should be removed"
        );
    }

    #[tokio::test]
    async fn ignores_timeout_for_removed_batch_queue() {
        // No queue exists for "K1", so the timeout must be ignored rather than panic.
        let mut fresh_worker = new_worker();

        fresh_worker.on_timeout(String::from("K1"), Generation::default());
    }

    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_handle, _guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        let inputs = ["I1", "I2"];

        // Submit every item first so they can be batched together.
        let mut receivers = Vec::with_capacity(inputs.len());
        for input in inputs {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    submitted_at: tokio::time::Instant::now(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();
            receivers.push(rx);
        }

        for (input, rx) in inputs.into_iter().zip(receivers) {
            let output = rx.await.unwrap().0.unwrap();
            assert_eq!(output, format!("{input} processed"));
        }
    }
}