use std::{
    collections::HashMap,
    fmt::{Debug, Display},
};

use tokio::{
    sync::{mpsc, oneshot},
    task::JoinHandle,
};
use tracing::{debug, info};

use crate::{
    BatchError,
    batch::BatchItem,
    batch_inner::Generation,
    batch_queue::BatchQueue,
    policies::{BatchingPolicy, Limits, OnAdd, ProcessAction},
    processor::Processor,
};

/// Background task that owns the per-key [`BatchQueue`]s for a batcher and
/// drives them in response to new items, timeouts, acquired resources and
/// finished batches.
pub(crate) struct Worker<P: Processor> {
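    /// Name of the owning batcher, used in log messages and passed to each queue.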
    batcher_name: String,

    /// Receives new items submitted for batching.
    item_rx: mpsc::Receiver<BatchItem<P>>,
    /// The user-supplied processor, cloned for each batch that runs.
    processor: P,

    /// Sender handed to batches and timers so they can report back to the worker.
    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
    /// Receives those internal messages in the event loop.
    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,

    /// Receives shutdown requests and notifier registrations from handles.
    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,

    /// Dropped when the worker exits, waking `WorkerHandle::wait_for_shutdown` callers.
    shutdown_notifiers: Vec<oneshot::Sender<()>>,

    /// Set once a shutdown request has been received.
    shutting_down: bool,

    /// Limits passed to each new [`BatchQueue`].
    limits: Limits,
    /// Policy deciding when queued batches are processed.
    batching_policy: BatchingPolicy,

    /// One queue of pending batches per key.
    batch_queues: HashMap<P::Key, BatchQueue<P>>,
}

/// Messages sent back to the worker's event loop by its batches and timers.
#[derive(Debug)]
pub(crate) enum Message<K, E: Display + Debug> {
    /// The process-after timer for this key and generation fired.
    TimedOut(K, Generation),
    /// Resources were pre-acquired for this key and generation.
    ResourcesAcquired(K, Generation),
    /// Acquiring resources for this key and generation failed.
    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
    /// A batch for this key has finished.
    Finished(K),
}

/// Control messages sent from [`WorkerHandle`]s to the worker.
pub(crate) enum ShutdownMessage {
    /// Register a notifier to be woken once the worker has exited.
    Register(ShutdownNotifier),
    /// Request shutdown; the worker exits once its queues are empty and idle.
    ShutDown,
}

/// Wraps the oneshot sender whose drop signals that the worker has exited.
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);

/// Cloneable handle for requesting and awaiting shutdown of a batch worker.
#[derive(Debug, Clone)]
pub struct WorkerHandle {
    shutdown_tx: mpsc::Sender<ShutdownMessage>,
}

/// Aborts the spawned worker task when dropped.
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
    handle: JoinHandle<()>,
}

impl<P: Processor> Worker<P> {
    /// Spawn the worker task. Returns a handle for requesting shutdown, a guard
    /// that aborts the task when dropped, and the sender used to submit items.
    pub fn spawn(
        batcher_name: String,
        processor: P,
        limits: Limits,
        batching_policy: BatchingPolicy,
    ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
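        // Channels for submitting items, for internal messages from batches and
        // timers, and for shutdown control.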
        let (item_tx, item_rx) = mpsc::channel(10);

        let (timeout_tx, timeout_rx) = mpsc::channel(10);

        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);

        let mut worker = Worker {
            batcher_name,

            item_rx,
            processor,

            msg_tx: timeout_tx,
            msg_rx: timeout_rx,

            shutdown_notifier_rx: shutdown_rx,
            shutdown_notifiers: Vec::new(),

            shutting_down: false,

            limits,
            batching_policy,

            batch_queues: HashMap::new(),
        };

        let handle = tokio::spawn(async move {
            worker.run().await;
        });

        (
            WorkerHandle { shutdown_tx },
            WorkerDropGuard { handle },
            item_tx,
        )
    }

    fn add(&mut self, item: BatchItem<P>) {
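        // Look up (or create) the queue for this item's key, then let the batching
        // policy decide what to do with the new item.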
        let key = item.key.clone();

        let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
            BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
        });

        match self.batching_policy.on_add(batch_queue) {
            OnAdd::AddAndProcess => {
                batch_queue.push(item);

                self.process_next_batch(&key);
            }
            OnAdd::AddAndAcquireResources => {
                batch_queue.push(item);

                batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
            }
            OnAdd::AddAndProcessAfter(duration) => {
                batch_queue.push(item);

                batch_queue.process_after(duration, self.msg_tx.clone());
            }
            OnAdd::Add => {
                batch_queue.push(item);
            }
            OnAdd::Reject(reason) => {
                if item
                    .tx
                    .send((Err(BatchError::Rejected(reason)), None))
                    .is_err()
                {
                    // The receiver was dropped; the caller no longer cares about the result.
                    debug!(
                        "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
                        self.batcher_name
                    );
                }
            }
        }
    }

    fn process_generation(&mut self, key: P::Key, generation: Generation) {
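        // Take the batch for this exact generation, if it is still queued, and hand
        // it to the processor.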
        let batch_queue = self.batch_queues.get_mut(&key).expect("batch should exist");

        if let Some(batch) = batch_queue.take_generation(generation) {
            let on_finished = self.msg_tx.clone();

            batch.process(self.processor.clone(), on_finished);
        }
    }

    fn process_next_batch(&mut self, key: &P::Key) {
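        // Take the next batch that is ready to run for this key, if any, and start
        // processing it.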
        let batch_queue = self
            .batch_queues
            .get_mut(key)
            .expect("batch queue should exist");

        if let Some(batch) = batch_queue.take_next_ready_batch() {
            let on_finished = self.msg_tx.clone();

            batch.process(self.processor.clone(), on_finished);

            debug_assert!(
                batch_queue.within_processing_capacity(),
                "processing count should not exceed max key concurrency"
            );
        }
    }

    fn on_timeout(&mut self, key: P::Key, generation: Generation) {
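        // A batch's timer fired; the policy decides whether that generation should
        // be processed now.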
        let batch_queue = self
            .batch_queues
            .get_mut(&key)
            .expect("batch queue should exist");

        match self.batching_policy.on_timeout(generation, batch_queue) {
            ProcessAction::Process => {
                self.process_generation(key, generation);
            }
            ProcessAction::DoNothing => {}
        }
    }

    fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
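        // Resources were pre-acquired for a generation; the policy decides whether
        // to process it immediately.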
        let batch_queue = self
            .batch_queues
            .get_mut(&key)
            .expect("batch queue should exist");

        match self
            .batching_policy
            .on_resources_acquired(generation, batch_queue)
        {
            ProcessAction::Process => {
                self.process_generation(key, generation);
            }
            ProcessAction::DoNothing => {}
        }
    }

    fn on_batch_finished(&mut self, key: &P::Key) {
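        // A batch finished; the policy may want to start the next queued batch for
        // this key.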
        let batch_queue = self
            .batch_queues
            .get_mut(key)
            .expect("batch queue should exist");

        match self.batching_policy.on_finish(batch_queue) {
            ProcessAction::Process => {
                self.process_next_batch(key);
            }
            ProcessAction::DoNothing => {}
        }
    }

    fn fail_batch(&mut self, key: P::Key, generation: Generation, err: BatchError<P::Error>) {
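        // Remove the batch for this generation, if it is still queued, and fail all
        // of its items with the given error.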
        let batch_queue = self
            .batch_queues
            .get_mut(&key)
            .expect("batch queue should exist");

        if let Some(batch) = batch_queue.take_generation(generation) {
            let on_finished = self.msg_tx.clone();
            batch.fail(err, on_finished);
        }
    }

    fn ready_to_shut_down(&self) -> bool {
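        // Only exit once shutdown was requested, every queue is empty, and nothing
        // is still processing.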
        self.shutting_down
            && self.batch_queues.values().all(|q| q.is_empty())
            && !self.batch_queues.values().any(|q| q.is_processing())
    }

    async fn run(&mut self) {
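        // Event loop: multiplex shutdown control, incoming items and internal batch
        // messages until we are ready to shut down.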
        loop {
            tokio::select! {
                Some(msg) = self.shutdown_notifier_rx.recv() => {
                    match msg {
                        ShutdownMessage::Register(notifier) => {
                            self.shutdown_notifiers.push(notifier.0);
                        }
                        ShutdownMessage::ShutDown => {
                            self.shutting_down = true;
                        }
                    }
                }

                Some(item) = self.item_rx.recv() => {
                    self.add(item);
                }

                Some(msg) = self.msg_rx.recv() => {
                    match msg {
                        Message::ResourcesAcquired(key, generation) => {
                            self.on_resource_acquired(key, generation);
                        }
                        Message::ResourceAcquisitionFailed(key, generation, err) => {
                            self.fail_batch(key, generation, err);
                        }
                        Message::TimedOut(key, generation) => {
                            self.on_timeout(key, generation);
                        }
                        Message::Finished(key) => {
                            self.on_batch_finished(&key);
                        }
                    }
                }
            }

            if self.ready_to_shut_down() {
                info!("Batch worker '{}' is shutting down", &self.batcher_name);
                // Returning ends the spawned task; dropping the worker also drops the
                // registered shutdown notifiers, which wakes `wait_for_shutdown` callers.
                return;
            }
        }
    }
}

impl WorkerHandle {
    /// Ask the worker to shut down once all of its queues have drained.
    pub async fn shut_down(&self) {
        let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
    }

    /// Wait until the worker task has fully stopped.
    pub async fn wait_for_shutdown(&self) {
        let (notifier_tx, notifier_rx) = oneshot::channel();
        let _ = self
            .shutdown_tx
            .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
            .await;
        // The worker never sends on the notifier; the receiver resolves (with an
        // error) once the worker drops its end on exit, which is all we need here.
        let _ = notifier_rx.await;
    }
}

impl Drop for WorkerDropGuard {
    fn drop(&mut self) {
        // Abort the spawned worker task so it does not outlive the guard.
        self.handle.abort();
    }
}

#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            Ok(inputs.map(|s| s + " processed").collect())
        }
    }

    #[tokio::test]
    async fn simple_test_over_channel() {
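        // Two items for the same key with a max batch size of 2 and a size-based
        // policy: both results should come back processed.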
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::default().with_max_batch_size(2),
            BatchingPolicy::Size,
        );

        let rx1 = {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: "I1".to_string(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            rx
        };

        let rx2 = {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: "I2".to_string(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            rx
        };

        let o1 = rx1.await.unwrap().0.unwrap();
        let o2 = rx2.await.unwrap().0.unwrap();

        assert_eq!(o1, "I1 processed".to_string());
        assert_eq!(o2, "I2 processed".to_string());
    }
}