1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4};
5
6use tokio::{
7 sync::{mpsc, oneshot},
8 task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13 BatchError,
14 batch::BatchItem,
15 batch_inner::Generation,
16 batch_queue::BatchQueue,
17 limits::Limits,
18 policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
19 processor::Processor,
20};
21
/// State owned by the background task that runs a batcher's event loop.
///
/// The worker receives new items and internal [`Message`]s over channels and keeps one
/// [`BatchQueue`] per key.
pub(crate) struct Worker<P: Processor> {
    /// Name of the batcher; used in log output and handed to every newly created queue.
    batcher_name: String,

    /// Receives newly submitted batch items.
    item_rx: mpsc::Receiver<BatchItem<P>>,
    /// The processor; a clone is handed to a batch queue whenever work is started.
    processor: P,

    /// Sending half of the internal message channel; clones are handed to the batch queues.
    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
    /// Receiving half of the internal message channel, drained by the event loop.
    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,

    /// Receives shutdown-related requests sent through a [`WorkerHandle`].
    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,

    /// Senders registered by callers waiting for shutdown. Nothing in this file sends on
    /// them; waiters are released when the senders are dropped together with the worker.
    shutdown_notifiers: Vec<oneshot::Sender<()>>,

    /// Set once a shutdown has been requested.
    shutting_down: bool,

    /// Limits used to size the channels and passed to every newly created queue.
    limits: Limits,
    /// Policy consulted on add / timeout / resources-acquired / finish events.
    batching_policy: BatchingPolicy,

    /// One queue per key, created lazily when the first item for that key arrives.
    batch_queues: HashMap<P::Key, BatchQueue<P>>,
}
50
/// Internal events delivered to the worker's event loop over its message channel.
///
/// `K` is the batch key type and `E` the processor's error type.
#[derive(Debug)]
pub(crate) enum Message<K, E: Display + Debug> {
    /// A timeout fired for the given generation of the queue for key `K`.
    TimedOut(K, Generation),
    /// Resources were acquired for the given generation of the queue for key `K`.
    ResourcesAcquired(K, Generation),
    /// Acquiring resources for the given generation failed with the attached error.
    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
    /// A batch for key `K` reached the given terminal state.
    Finished(K, BatchTerminalState),
}
58
/// How a batch ended, as reported in [`Message::Finished`].
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
    /// The batch finished processing; the worker marks the queue as processed.
    Processed,
    /// The batch ended while acquiring resources; the worker marks acquisition as finished.
    FailedAcquiring,
}
64
/// Shutdown-related requests sent from a [`WorkerHandle`] to the worker.
pub(crate) enum ShutdownMessage {
    /// Register a notifier; the worker stores it until the worker itself is dropped.
    Register(ShutdownNotifier),
    /// Ask the worker to start shutting down.
    ShutDown,
}
69
/// Wrapper around the oneshot sender that a shutdown waiter hands to the worker.
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
71
/// Cloneable handle used to request a worker shutdown and to wait for it to complete.
#[derive(Debug, Clone)]
pub struct WorkerHandle {
    /// Channel over which shutdown requests and notifier registrations are sent.
    shutdown_tx: mpsc::Sender<ShutdownMessage>,
}
79
/// Guard that aborts the worker task when it is dropped.
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
    /// Join handle of the spawned worker task.
    handle: JoinHandle<()>,
}
85
86impl<P: Processor> Worker<P> {
87 pub fn spawn(
88 batcher_name: String,
89 processor: P,
90 limits: Limits,
91 batching_policy: BatchingPolicy,
92 ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
93 let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
96 let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
97
98 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
99
100 let mut worker = Worker {
101 batcher_name,
102
103 item_rx,
104 processor,
105
106 msg_tx,
107 msg_rx,
108
109 shutdown_notifier_rx: shutdown_rx,
110 shutdown_notifiers: Vec::new(),
111
112 shutting_down: false,
113
114 limits,
115 batching_policy,
116
117 batch_queues: HashMap::new(),
118 };
119
120 let handle = tokio::spawn(async move {
121 worker.run().await;
122 });
123
124 (
125 WorkerHandle { shutdown_tx },
126 WorkerDropGuard { handle },
127 item_tx,
128 )
129 }
130
131 fn add(&mut self, item: BatchItem<P>) {
133 let key = item.key.clone();
134
135 let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
136 BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
137 });
138
139 match self.batching_policy.on_add(batch_queue) {
140 OnAdd::AddAndProcess => {
141 batch_queue.push(item);
142
143 self.process_next_batch(&key);
144 }
145 OnAdd::AddAndAcquireResources => {
146 batch_queue.push(item);
147
148 batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
149 }
150 OnAdd::AddAndProcessAfter(duration) => {
151 batch_queue.push(item);
152
153 batch_queue.process_after(duration, self.msg_tx.clone());
154 }
155 OnAdd::Add => {
156 batch_queue.push(item);
157 }
158 OnAdd::Reject(reason) => {
159 if item
160 .tx
161 .send((Err(BatchError::Rejected(reason)), None))
162 .is_err()
163 {
164 debug!(
167 "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
168 self.batcher_name
169 );
170 }
171 }
172 }
173 }
174
175 fn process_generation(&mut self, key: P::Key, generation: Generation) {
176 let batch_queue = self.batch_queues.get_mut(&key).expect("batch should exist");
177
178 batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
179 }
180
181 fn process_next_ready_batch(&mut self, key: &P::Key) {
182 let batch_queue = self
183 .batch_queues
184 .get_mut(key)
185 .expect("batch queue should exist");
186
187 batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
188 }
189
190 fn process_next_batch(&mut self, key: &P::Key) {
191 let batch_queue = self
192 .batch_queues
193 .get_mut(key)
194 .expect("batch queue should exist");
195
196 batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
197 }
198
199 fn on_timeout(&mut self, key: P::Key, generation: Generation) {
200 let batch_queue = self
201 .batch_queues
202 .get_mut(&key)
203 .expect("batch queue should exist");
204
205 match self.batching_policy.on_timeout(generation, batch_queue) {
206 OnGenerationEvent::Process => {
207 self.process_generation(key, generation);
208 }
209 OnGenerationEvent::DoNothing => {}
210 }
211 }
212
213 fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
214 let batch_queue = self
215 .batch_queues
216 .get_mut(&key)
217 .expect("batch queue should exist");
218
219 batch_queue.mark_resource_acquisition_finished();
220
221 match self
222 .batching_policy
223 .on_resources_acquired(generation, batch_queue)
224 {
225 OnGenerationEvent::Process => {
226 self.process_generation(key, generation);
227 }
228 OnGenerationEvent::DoNothing => {}
229 }
230 }
231
232 fn on_resource_acquisition_failed(
233 &mut self,
234 key: P::Key,
235 generation: Generation,
236 err: BatchError<P::Error>,
237 ) {
238 let batch_queue = self
239 .batch_queues
240 .get_mut(&key)
241 .expect("batch queue should exist");
242
243 batch_queue.fail_generation(generation, err.clone(), self.msg_tx.clone());
244 }
245
246 fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
247 let batch_queue = self
248 .batch_queues
249 .get_mut(key)
250 .expect("batch queue should exist");
251
252 match terminal_state {
253 BatchTerminalState::Processed => {
254 batch_queue.mark_processed();
255 }
256 BatchTerminalState::FailedAcquiring => {
257 batch_queue.mark_resource_acquisition_finished();
258 }
259 }
260
261 match self.batching_policy.on_finish(batch_queue) {
262 OnFinish::ProcessNextReady => {
263 self.process_next_ready_batch(key);
264 }
265 OnFinish::ProcessNext => {
266 self.process_next_batch(key);
267 }
268 OnFinish::DoNothing => {}
269 }
270 }
271
272 fn ready_to_shut_down(&self) -> bool {
273 self.shutting_down
274 && self.batch_queues.values().all(|q| q.is_empty())
275 && !self.batch_queues.values().any(|q| q.is_processing())
276 }
277
278 async fn run(&mut self) {
280 loop {
281 tokio::select! {
282 Some(msg) = self.shutdown_notifier_rx.recv() => {
283 match msg {
284 ShutdownMessage::Register(notifier) => {
285 self.shutdown_notifiers.push(notifier.0);
286 }
287 ShutdownMessage::ShutDown => {
288 self.shutting_down = true;
289 }
290 }
291 }
292
293 Some(item) = self.item_rx.recv() => {
294 self.add(item);
295 }
296
297 Some(msg) = self.msg_rx.recv() => {
298 match msg {
299 Message::ResourcesAcquired(key, generation) => {
300 self.on_resource_acquired(key, generation);
301 }
302 Message::ResourceAcquisitionFailed(key, generation, err) => {
303 self.on_resource_acquisition_failed(key, generation, err);
304 }
305 Message::TimedOut(key, generation) => {
306 self.on_timeout(key, generation);
307 }
308 Message::Finished(key, terminal_state) => {
309 self.on_batch_finished(&key, terminal_state);
310 }
311 }
312 }
313 }
314
315 if self.ready_to_shut_down() {
316 info!("Batch worker '{}' is shutting down", &self.batcher_name);
317 return;
318 }
319 }
320 }
321}
322
323impl WorkerHandle {
324 pub async fn shut_down(&self) {
328 info!("Sending shut down signal to batch worker");
329 let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
331 }
332
333 pub async fn wait_for_shutdown(&self) {
335 let (notifier_tx, notifier_rx) = oneshot::channel();
337 let _ = self
338 .shutdown_tx
339 .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
340 .await;
341 let _ = notifier_rx.await;
343 }
344}
345
346impl Drop for WorkerDropGuard {
347 fn drop(&mut self) {
348 info!("Aborting batch worker");
349 self.handle.abort();
350 }
351}
352
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Processor that appends `" processed"` to every input and needs no resources.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs = inputs.map(|input| input + " processed").collect();
            Ok(outputs)
        }
    }

    /// Two items for the same key fill a size-2 batch and each gets its own output back.
    #[tokio::test]
    async fn simple_test_over_channel() {
        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            Limits::builder().max_batch_size(2).build(),
            BatchingPolicy::Size,
        );

        // Submit both inputs under the same key, keeping the response receivers in order.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            item_tx
                .send(BatchItem {
                    key: "K1".to_string(),
                    input: input.to_string(),
                    submitted_at: tokio::time::Instant::now(),
                    tx,
                    requesting_span: Span::none(),
                })
                .await
                .unwrap();

            receivers.push(rx);
        }

        // Await the responses in submission order.
        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(
            outputs,
            vec!["I1 processed".to_string(), "I2 processed".to_string()]
        );
    }
}
431}