1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4};
5
6use tokio::{
7 sync::{mpsc, oneshot},
8 task::JoinHandle,
9};
10use tracing::{debug, info};
11
12use crate::{
13 BatchError,
14 batch::BatchItem,
15 batch_inner::Generation,
16 batch_queue::BatchQueue,
17 limits::Limits,
18 policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
19 processor::Processor,
20};
21
/// State owned by the event loop of a single batcher.
///
/// A `Worker` is created and moved into a spawned task by `Worker::spawn`; the rest of the
/// system talks to it only through channels.
pub(crate) struct Worker<P: Processor> {
    /// Name of the batcher. Used in log messages and passed to every newly created
    /// [`BatchQueue`].
    batcher_name: String,

    /// Incoming items submitted for batching. Each received item is handed to `add`.
    item_rx: mpsc::Receiver<BatchItem<P>>,
    /// The processor. A clone is passed to the batch queue whenever resources are
    /// pre-acquired or a batch is started.
    processor: P,

    /// Sending half of the internal message channel. A clone is passed to the batch queue on
    /// most calls — presumably so work started by the queue can report back to this worker
    /// (the queue's internals are not visible in this file).
    msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
    /// Receiving half of the internal message channel, drained by the event loop.
    msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,

    /// Receives shutdown requests and shutdown-notifier registrations sent through a
    /// [`WorkerHandle`].
    shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,

    /// Senders registered by callers waiting for shutdown. Nothing in this file sends on
    /// them; waiters are released when the senders are dropped together with the worker.
    shutdown_notifiers: Vec<oneshot::Sender<()>>,

    /// Set once a `ShutdownMessage::ShutDown` has been received.
    shutting_down: bool,

    /// Limits passed to every newly created [`BatchQueue`].
    limits: Limits,
    /// Policy consulted when an item is added, a timeout fires, resources are acquired, or a
    /// batch finishes.
    batching_policy: BatchingPolicy,

    /// One queue per key, created lazily on the first item for that key. Entries are never
    /// removed in this file.
    batch_queues: HashMap<P::Key, BatchQueue<P>>,
}
50
/// Internal events delivered to the worker's event loop over its message channel.
///
/// `K` is the batch key type and `E` the processor's error type. Each variant is dispatched
/// to the matching `Worker::on_*` handler.
#[derive(Debug)]
pub(crate) enum Message<K, E: Display + Debug> {
    /// A timeout elapsed for the given generation of the key's queue.
    TimedOut(K, Generation),
    /// Resources were acquired for the given generation of the key's queue.
    ResourcesAcquired(K, Generation),
    /// Acquiring resources for the given generation failed with the attached error.
    ResourceAcquisitionFailed(K, Generation, BatchError<E>),
    /// A batch for the key reached the given terminal state.
    Finished(K, BatchTerminalState),
}
58
/// How a batch ended, as reported in `Message::Finished`.
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
    /// The batch was processed; the worker calls `mark_processed` on the queue.
    Processed,
    /// Resource acquisition failed; the worker calls `mark_resource_acquisition_finished`
    /// on the queue.
    FailedAcquiring,
}
64
/// Messages sent from a [`WorkerHandle`] to the worker over the shutdown channel.
pub(crate) enum ShutdownMessage {
    /// Register a notifier that the worker keeps until it is dropped.
    Register(ShutdownNotifier),
    /// Ask the worker to shut down. The worker sets its `shutting_down` flag and exits once
    /// `ready_to_shut_down` holds.
    ShutDown,
}
69
/// Wrapper around the oneshot sender registered by `WorkerHandle::wait_for_shutdown`.
///
/// The worker stores the inner sender. Nothing in this file sends on it, so the paired
/// receiver resolves when the sender is dropped.
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
71
/// Cloneable handle used to request shutdown of a worker and to wait for it to finish.
#[derive(Debug, Clone)]
pub struct WorkerHandle {
    /// Sending half of the worker's shutdown channel.
    shutdown_tx: mpsc::Sender<ShutdownMessage>,
}
79
/// Guard that owns the worker task's [`JoinHandle`].
///
/// Dropping the guard aborts the task (see the `Drop` impl).
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
    /// Handle of the spawned worker task.
    handle: JoinHandle<()>,
}
85
86impl<P: Processor> Worker<P> {
87 pub fn spawn(
88 batcher_name: String,
89 processor: P,
90 limits: Limits,
91 batching_policy: BatchingPolicy,
92 ) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
93 let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
96 let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
97
98 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
99
100 let mut worker = Worker {
101 batcher_name,
102
103 item_rx,
104 processor,
105
106 msg_tx,
107 msg_rx,
108
109 shutdown_notifier_rx: shutdown_rx,
110 shutdown_notifiers: Vec::new(),
111
112 shutting_down: false,
113
114 limits,
115 batching_policy,
116
117 batch_queues: HashMap::new(),
118 };
119
120 let handle = tokio::spawn(async move {
121 worker.run().await;
122 });
123
124 (
125 WorkerHandle { shutdown_tx },
126 WorkerDropGuard { handle },
127 item_tx,
128 )
129 }
130
131 fn add(&mut self, item: BatchItem<P>) {
133 let key = item.key.clone();
134
135 let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
136 BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
137 });
138
139 match self.batching_policy.on_add(batch_queue) {
140 OnAdd::AddAndProcess => {
141 batch_queue.push(item);
142
143 self.process_next_batch(&key);
144 }
145 OnAdd::AddAndAcquireResources => {
146 batch_queue.push(item);
147
148 batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
149 }
150 OnAdd::AddAndProcessAfter(duration) => {
151 batch_queue.push(item);
152
153 batch_queue.process_after(duration, self.msg_tx.clone());
154 }
155 OnAdd::Add => {
156 batch_queue.push(item);
157 }
158 OnAdd::Reject(reason) => {
159 if item
160 .tx
161 .send((Err(BatchError::Rejected(reason)), None))
162 .is_err()
163 {
164 debug!(
167 "Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
168 self.batcher_name
169 );
170 }
171 }
172 }
173 }
174
175 fn process_generation(&mut self, key: P::Key, generation: Generation) {
176 let batch_queue = self.batch_queues.get_mut(&key).expect("batch should exist");
177
178 batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
179 }
180
181 fn process_next_ready_batch(&mut self, key: &P::Key) {
182 let batch_queue = self
183 .batch_queues
184 .get_mut(key)
185 .expect("batch queue should exist");
186
187 batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
188 }
189
190 fn process_next_batch(&mut self, key: &P::Key) {
191 let batch_queue = self
192 .batch_queues
193 .get_mut(key)
194 .expect("batch queue should exist");
195
196 batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
197 }
198
199 fn on_timeout(&mut self, key: P::Key, generation: Generation) {
200 let batch_queue = self
201 .batch_queues
202 .get_mut(&key)
203 .expect("batch queue should exist");
204
205 match self.batching_policy.on_timeout(generation, batch_queue) {
206 OnGenerationEvent::Process => {
207 self.process_generation(key, generation);
208 }
209 OnGenerationEvent::DoNothing => {}
210 }
211 }
212
213 fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
214 let batch_queue = self
215 .batch_queues
216 .get_mut(&key)
217 .expect("batch queue should exist");
218
219 batch_queue.mark_resource_acquisition_finished();
220
221 match self
222 .batching_policy
223 .on_resources_acquired(generation, batch_queue)
224 {
225 OnGenerationEvent::Process => {
226 self.process_generation(key, generation);
227 }
228 OnGenerationEvent::DoNothing => {}
229 }
230 }
231
232 fn on_resource_acquisition_failed(
233 &mut self,
234 key: P::Key,
235 generation: Generation,
236 err: BatchError<P::Error>,
237 ) {
238 let batch_queue = self
239 .batch_queues
240 .get_mut(&key)
241 .expect("batch queue should exist");
242
243 batch_queue.fail_generation(generation, err.clone(), self.msg_tx.clone());
244 }
245
246 fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
247 let batch_queue = self
248 .batch_queues
249 .get_mut(key)
250 .expect("batch queue should exist");
251
252 match terminal_state {
253 BatchTerminalState::Processed => {
254 batch_queue.mark_processed();
255 }
256 BatchTerminalState::FailedAcquiring => {
257 batch_queue.mark_resource_acquisition_finished();
258 }
259 }
260
261 match self.batching_policy.on_finish(batch_queue) {
262 OnFinish::ProcessNextReady => {
263 self.process_next_ready_batch(key);
264 }
265 OnFinish::ProcessNext => {
266 self.process_next_batch(key);
267 }
268 OnFinish::DoNothing => {}
269 }
270 }
271
272 fn ready_to_shut_down(&self) -> bool {
273 self.shutting_down
274 && self.batch_queues.values().all(|q| q.is_empty())
275 && !self.batch_queues.values().any(|q| q.is_processing())
276 }
277
278 async fn run(&mut self) {
280 loop {
281 tokio::select! {
282 Some(msg) = self.shutdown_notifier_rx.recv() => {
283 match msg {
284 ShutdownMessage::Register(notifier) => {
285 self.shutdown_notifiers.push(notifier.0);
286 }
287 ShutdownMessage::ShutDown => {
288 self.shutting_down = true;
289 }
290 }
291 }
292
293 Some(item) = self.item_rx.recv() => {
294 self.add(item);
295 }
296
297 Some(msg) = self.msg_rx.recv() => {
298 match msg {
299 Message::ResourcesAcquired(key, generation) => {
300 self.on_resource_acquired(key, generation);
301 }
302 Message::ResourceAcquisitionFailed(key, generation, err) => {
303 self.on_resource_acquisition_failed(key, generation, err);
304 }
305 Message::TimedOut(key, generation) => {
306 self.on_timeout(key, generation);
307 }
308 Message::Finished(key, terminal_state) => {
309 self.on_batch_finished(&key, terminal_state);
310 }
311 }
312 }
313 }
314
315 if self.ready_to_shut_down() {
316 info!("Batch worker '{}' is shutting down", &self.batcher_name);
317 return;
318 }
319 }
320 }
321}
322
323impl WorkerHandle {
324 pub async fn shut_down(&self) {
328 let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
330 }
331
332 pub async fn wait_for_shutdown(&self) {
334 let (notifier_tx, notifier_rx) = oneshot::channel();
336 let _ = self
337 .shutdown_tx
338 .send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
339 .await;
340 let _ = notifier_rx.await;
342 }
343}
344
/// Aborts the worker task when the guard goes out of scope.
impl Drop for WorkerDropGuard {
    fn drop(&mut self) {
        // Abort rather than detach, so the worker task does not outlive its guard.
        self.handle.abort();
    }
}
350
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;
    use tracing::Span;

    use super::*;

    /// Minimal processor for tests: needs no resources and tags every input as processed.
    #[derive(Debug, Clone)]
    struct SimpleBatchProcessor;

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            let outputs: Vec<String> = inputs
                .map(|input| format!("{input} processed"))
                .collect();

            Ok(outputs)
        }
    }

    /// Two items for the same key, sent straight over the worker's item channel, both come
    /// back processed.
    #[tokio::test]
    async fn simple_test_over_channel() {
        let limits = Limits::builder().max_batch_size(2).build();

        let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
            "test".to_string(),
            SimpleBatchProcessor,
            limits,
            BatchingPolicy::Size,
        );

        // Submit both items before awaiting either result.
        let mut receivers = Vec::new();
        for input in ["I1", "I2"] {
            let (tx, rx) = oneshot::channel();
            let item = BatchItem {
                key: "K1".to_string(),
                input: input.to_string(),
                submitted_at: tokio::time::Instant::now(),
                tx,
                requesting_span: Span::none(),
            };

            item_tx.send(item).await.unwrap();
            receivers.push(rx);
        }

        let mut outputs = Vec::new();
        for rx in receivers {
            outputs.push(rx.await.unwrap().0.unwrap());
        }

        assert_eq!(outputs, ["I1 processed", "I2 processed"]);
    }
}