use std::{fmt::Debug, time::Duration};

use crate::{Processor, batch_inner::Generation, batch_queue::BatchQueue, error::RejectionReason};
4
5#[derive(Debug)]
7#[non_exhaustive]
8pub enum BatchingPolicy {
9 Immediate,
25
26 Size,
30
31 Duration(Duration, OnFull),
38}
39
/// Size and concurrency limits applied to each per-key batch queue.
///
/// Construct via [`Limits::default`] and the `with_*` builder methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Limits {
    /// Maximum number of items in a single batch (default: 100).
    pub(crate) max_batch_size: usize,
    /// Maximum number of batches for one key handled concurrently
    /// (default: 10).
    pub(crate) max_key_concurrency: usize,
}
57
/// What a [`BatchingPolicy::Duration`] policy does when a batch fills up
/// before its timer fires.
#[derive(Debug)]
#[non_exhaustive]
pub enum OnFull {
    /// Process the full batch immediately instead of waiting for the timer.
    Process,

    /// Keep waiting for the timer; once the queue itself is full, further
    /// items are rejected.
    Reject,
}
69
70#[derive(Debug)]
71pub(crate) enum OnAdd {
72 AddAndProcess,
73 AddAndAcquireResources,
74 AddAndProcessAfter(Duration),
75 Reject(RejectionReason),
76 Add,
77}
78
/// Whether a queue should be processed in response to an event
/// (timeout, resource acquisition, or batch completion).
#[derive(Debug)]
pub(crate) enum ProcessAction {
    /// Take the next ready batch and process it.
    Process,
    /// No batch should be processed right now.
    DoNothing,
}
84
85impl Limits {
86 pub fn with_max_batch_size(self, max: usize) -> Self {
88 Self {
89 max_batch_size: max,
90 ..self
91 }
92 }
93
94 pub fn with_max_key_concurrency(self, max: usize) -> Self {
96 Self {
97 max_key_concurrency: max,
98 ..self
99 }
100 }
101}
102
103impl Default for Limits {
104 fn default() -> Self {
105 Self {
106 max_batch_size: 100,
107 max_key_concurrency: 10,
108 }
109 }
110}
111
112impl BatchingPolicy {
113 pub(crate) fn on_add<P: Processor>(&self, batch_queue: &BatchQueue<P>) -> OnAdd {
115 if let Some(rejection) = self.should_reject(batch_queue) {
116 return OnAdd::Reject(rejection);
117 }
118
119 self.determine_action(batch_queue)
120 }
121
122 fn should_reject<P: Processor>(&self, batch_queue: &BatchQueue<P>) -> Option<RejectionReason> {
124 if batch_queue.is_full() {
125 if batch_queue.at_max_processing_capacity() {
126 Some(RejectionReason::MaxConcurrency)
127 } else {
128 Some(RejectionReason::BatchFull)
129 }
130 } else {
131 None
132 }
133 }
134
135 fn determine_action<P: Processor>(&self, batch_queue: &BatchQueue<P>) -> OnAdd {
137 match self {
138 Self::Size if batch_queue.last_space_in_batch() => self.add_or_process(batch_queue),
139
140 Self::Duration(_dur, on_full) if batch_queue.last_space_in_batch() => {
141 if matches!(on_full, OnFull::Process) {
142 self.add_or_process(batch_queue)
143 } else {
144 OnAdd::Add
145 }
146 }
147
148 Self::Duration(dur, _on_full) if batch_queue.adding_to_new_batch() => {
149 OnAdd::AddAndProcessAfter(*dur)
150 }
151
152 Self::Immediate => {
153 if batch_queue.at_max_processing_capacity() {
154 OnAdd::Add
155 } else {
156 if batch_queue.adding_to_new_batch() && !batch_queue.at_max_acquiring_capacity()
158 {
159 OnAdd::AddAndAcquireResources
160 } else {
161 OnAdd::Add
162 }
163 }
164 }
165
166 BatchingPolicy::Size | BatchingPolicy::Duration(_, _) => OnAdd::Add,
167 }
168 }
169
170 fn add_or_process<P: Processor>(&self, batch_queue: &BatchQueue<P>) -> OnAdd {
172 if batch_queue.at_max_processing_capacity() {
173 OnAdd::Add
175 } else {
176 OnAdd::AddAndProcess
177 }
178 }
179
180 pub(crate) fn on_timeout<P: Processor>(
181 &self,
182 generation: Generation,
183 batch_queue: &BatchQueue<P>,
184 ) -> ProcessAction {
185 if batch_queue.at_max_processing_capacity() {
186 ProcessAction::DoNothing
187 } else {
188 Self::process_generation_if_ready(generation, batch_queue)
189 }
190 }
191
192 pub(crate) fn on_resources_acquired<P: Processor>(
193 &self,
194 generation: Generation,
195 batch_queue: &BatchQueue<P>,
196 ) -> ProcessAction {
197 if batch_queue.at_max_processing_capacity() {
198 ProcessAction::DoNothing
199 } else {
200 Self::process_generation_if_ready(generation, batch_queue)
201 }
202 }
203
204 pub(crate) fn on_finish<P: Processor>(&self, batch_queue: &BatchQueue<P>) -> ProcessAction {
205 if dbg!(batch_queue.at_max_processing_capacity()) {
206 return ProcessAction::DoNothing;
207 }
208 match self {
209 BatchingPolicy::Immediate => Self::process_if_any_ready(batch_queue),
210
211 BatchingPolicy::Duration(_, _) if batch_queue.has_next_batch_timeout_expired() => {
212 ProcessAction::Process
213 }
214
215 BatchingPolicy::Duration(_, _) | BatchingPolicy::Size => {
216 if batch_queue.is_next_batch_full() {
217 ProcessAction::Process
218 } else {
219 ProcessAction::DoNothing
220 }
221 }
222 }
223 }
224
225 fn process_generation_if_ready<P: Processor>(
226 generation: Generation,
227 batch_queue: &BatchQueue<P>,
228 ) -> ProcessAction {
229 if batch_queue.is_generation_ready(generation) {
230 ProcessAction::Process
231 } else {
232 ProcessAction::DoNothing
233 }
234 }
235
236 fn process_if_any_ready<P: Processor>(batch_queue: &BatchQueue<P>) -> ProcessAction {
237 if batch_queue.has_batch_ready() {
238 ProcessAction::Process
239 } else {
240 ProcessAction::DoNothing
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicUsize};

    use assert_matches::assert_matches;
    use tokio::sync::{Mutex, Notify, futures::OwnedNotified, mpsc};
    use tracing::Span;

    use crate::{Processor, batch::BatchItem, batch_queue::BatchQueue, worker::Message};

    use super::*;

    /// Trivial processor that echoes its inputs back as outputs.
    #[derive(Clone)]
    struct TestProcessor;

    impl Processor for TestProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        async fn process(
            &self,
            _key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            Ok(inputs.collect())
        }
    }

    /// Processor whose acquisition and processing can be blocked by the test:
    /// the n-th `acquire_resources` call waits on `acquire_locks[n]`, and each
    /// item's `process` waits on the item's own `Notify`.
    #[derive(Default, Clone)]
    struct ControlledProcessor {
        acquire_locks: Vec<Arc<Mutex<()>>>,
        acquire_counter: Arc<AtomicUsize>,
    }

    impl Processor for ControlledProcessor {
        type Key = ();
        type Input = OwnedNotified;
        type Output = ();
        type Error = String;
        type Resources = ();

        async fn acquire_resources(&self, _key: ()) -> Result<(), String> {
            let n = self
                .acquire_counter
                .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
            if let Some(lock) = self.acquire_locks.get(n) {
                let _guard = lock.lock().await;
            }
            Ok(())
        }

        async fn process(
            &self,
            _key: (),
            inputs: impl Iterator<Item = OwnedNotified> + Send,
            _resources: (),
        ) -> Result<Vec<()>, String> {
            let mut outputs = vec![];
            for item in inputs {
                item.await;
                outputs.push(());
            }
            Ok(outputs)
        }
    }

    /// Builds a batch item whose result channel is immediately dropped.
    fn new_item<P: Processor>(key: P::Key, input: P::Input) -> BatchItem<P> {
        let (tx, _rx) = tokio::sync::oneshot::channel();
        BatchItem {
            key,
            input,
            tx,
            requesting_span: Span::none(),
        }
    }

    #[test]
    fn limits_builder_methods() {
        let limits = Limits::default()
            .with_max_batch_size(50)
            .with_max_key_concurrency(5);

        assert_eq!(limits.max_batch_size, 50);
        assert_eq!(limits.max_key_concurrency, 5);
    }

    #[test]
    fn size_policy_waits_for_full_batch_when_empty() {
        let limits = Limits::default()
            .with_max_batch_size(3)
            .with_max_key_concurrency(2);
        let queue = BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        let policy = BatchingPolicy::Size;
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::Add);
    }

    #[test]
    fn immediate_policy_acquires_resources_when_empty() {
        let limits = Limits::default()
            .with_max_batch_size(3)
            .with_max_key_concurrency(2);
        let queue = BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        let policy = BatchingPolicy::Immediate;
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::AddAndAcquireResources);
    }

    #[test]
    fn duration_policy_schedules_timeout_when_empty() {
        let limits = Limits::default().with_max_batch_size(2);
        let queue = BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        let duration = Duration::from_millis(100);
        let policy = BatchingPolicy::Duration(duration, OnFull::Process);
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::AddAndProcessAfter(d) if d == duration);
    }

    #[test]
    fn size_policy_processes_when_batch_becomes_full() {
        let limits = Limits::default().with_max_batch_size(2);
        let mut queue =
            BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        queue.push(new_item("key".to_string(), "item1".to_string()));

        let policy = BatchingPolicy::Size;
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::AddAndProcess);
    }

    #[tokio::test]
    async fn immediate_policy_adds_when_at_max_capacity() {
        let limits = Limits::default()
            .with_max_batch_size(1)
            .with_max_key_concurrency(1);
        let mut queue =
            BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        queue.push(new_item("key".to_string(), "item1".to_string()));

        let batch = queue.take_next_ready_batch().unwrap();

        let (on_finished, _rx) = tokio::sync::mpsc::channel(1);
        batch.process(TestProcessor, on_finished);

        let policy = BatchingPolicy::Immediate;

        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::Add);
    }

    #[tokio::test]
    async fn size_policy_rejects_when_full_and_at_capacity() {
        let limits = Limits::default()
            .with_max_batch_size(1)
            .with_max_key_concurrency(1);
        let mut queue =
            BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        queue.push(new_item("key".to_string(), "item1".to_string()));

        let batch = queue.take_next_ready_batch().unwrap();
        let (on_finished, _rx) = tokio::sync::mpsc::channel(1);
        batch.process(TestProcessor, on_finished);

        queue.push(new_item("key".to_string(), "item2".to_string()));

        let policy = BatchingPolicy::Size;
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::Reject(RejectionReason::MaxConcurrency));
    }

    #[test]
    fn duration_policy_onfull_reject_rejects_when_full_but_not_processing() {
        let limits = Limits::default()
            .with_max_batch_size(1)
            .with_max_key_concurrency(1);
        let mut queue =
            BatchQueue::<TestProcessor>::new("test".to_string(), "key".to_string(), limits);

        queue.push(new_item("key".to_string(), "item1".to_string()));

        let policy = BatchingPolicy::Duration(Duration::from_millis(100), OnFull::Reject);
        let result = policy.on_add(&queue);

        assert_matches!(result, OnAdd::Reject(RejectionReason::BatchFull));
    }

    #[tokio::test]
    async fn scenario_duration_policy_timeout_while_processing() {
        let processor = ControlledProcessor::default();
        let limits = Limits::default()
            .with_max_batch_size(2)
            .with_max_key_concurrency(1);
        let mut queue = BatchQueue::<ControlledProcessor>::new("test".to_string(), (), limits);
        let policy = BatchingPolicy::Duration(Duration::from_millis(100), OnFull::Process);

        // First item of a fresh batch starts the batch timer.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndProcessAfter(_));
        let notify1 = Arc::new(Notify::new());
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));

        // Second item fills the batch; OnFull::Process processes it early.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndProcess);
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));

        let first_batch = queue.take_next_ready_batch().unwrap();
        let (on_finished, mut rx) = mpsc::channel(1);
        first_batch.process(processor, on_finished);

        // A new batch starts while the first is still processing.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndProcessAfter(_));
        let notify2 = Arc::new(Notify::new());
        queue.push(new_item((), notify2.notified_owned()));
        let (tx, mut timeout_rx) = mpsc::channel(1);
        queue.process_after(Duration::from_millis(1), tx);

        // The second batch's timer fires, but processing capacity is taken.
        let msg = timeout_rx.recv().await.unwrap();
        let second_gen = Generation::default().next();
        assert_matches!(msg, Message::TimedOut(_, generation) => {
            assert_eq!(generation, second_gen);
        });
        let result = policy.on_timeout(second_gen, &queue);
        assert_matches!(result, ProcessAction::DoNothing);

        // Let the first batch finish; the timed-out batch can now run.
        notify1.notify_waiters();
        let msg = rx.recv().await.unwrap();
        assert_matches!(msg, Message::Finished(_));

        let result = policy.on_finish(&queue);
        assert_matches!(result, ProcessAction::Process);
    }

    #[tokio::test]
    async fn scenario_immediate_policy_processes_after_finish() {
        let processor = ControlledProcessor::default();
        let limits = Limits::default()
            .with_max_batch_size(2)
            .with_max_key_concurrency(1);
        let mut queue = BatchQueue::<ControlledProcessor>::new("test".to_string(), (), limits);
        let policy = BatchingPolicy::Immediate;

        let notify1 = Arc::new(Notify::new());
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));

        let first_batch = queue.take_next_ready_batch().unwrap();
        let (on_finished, mut rx) = mpsc::channel(1);
        first_batch.process(processor, on_finished);

        // At max processing capacity: new item is only enqueued.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::Add);
        let notify2 = Arc::new(Notify::new());
        queue.push(new_item((), Arc::clone(&notify2).notified_owned()));

        // First batch completes, freeing capacity for the waiting item.
        notify1.notify_waiters();
        let msg = rx.recv().await.unwrap();
        assert_matches!(msg, Message::Finished(_));

        let result = policy.on_finish(&queue);
        assert_matches!(result, ProcessAction::Process);
    }

    #[tokio::test]
    async fn scenario_size_policy_waits_for_full_batch() {
        let processor = ControlledProcessor::default();
        let limits = Limits::default()
            .with_max_batch_size(3)
            .with_max_key_concurrency(2);
        let mut queue = BatchQueue::<ControlledProcessor>::new("test".to_string(), (), limits);
        let policy = BatchingPolicy::Size;

        let notify1 = Arc::new(Notify::new());
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));
        queue.push(new_item((), Arc::clone(&notify1).notified_owned()));

        let first_batch = queue.take_next_ready_batch().unwrap();
        let (on_finished, mut rx) = mpsc::channel(1);
        first_batch.process(processor, on_finished);

        // Only two of three slots filled in the next batch.
        let notify2 = Arc::new(Notify::new());
        queue.push(new_item((), Arc::clone(&notify2).notified_owned()));
        queue.push(new_item((), Arc::clone(&notify2).notified_owned()));
        notify1.notify_waiters();
        let msg = rx.recv().await.unwrap();
        assert_matches!(msg, Message::Finished(_));

        // Size policy keeps waiting after a finish; the item that fills the
        // batch triggers processing.
        let result = policy.on_finish(&queue);
        assert_matches!(result, ProcessAction::DoNothing);
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndProcess);
    }

    #[tokio::test]
    async fn scenario_out_of_order_acquisition() {
        let mut processor = ControlledProcessor::default();
        let limits = Limits::default()
            .with_max_batch_size(2)
            .with_max_key_concurrency(2);
        let mut queue = BatchQueue::<ControlledProcessor>::new("test".to_string(), (), limits);
        let policy = BatchingPolicy::Immediate;

        // First batch starts acquiring; its acquisition is blocked by lock 1.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndAcquireResources);
        queue.push(new_item((), Arc::new(Notify::new()).notified_owned()));

        let acquire_lock1 = Arc::new(Mutex::new(()));
        let lock_guard1 = acquire_lock1.lock().await;
        processor.acquire_locks.push(Arc::clone(&acquire_lock1));
        let (tx, mut acquired1) = mpsc::channel(1);
        queue.pre_acquire_resources(processor.clone(), tx);

        // Second item joins the first batch: no new acquisition.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::Add);
        queue.push(new_item((), Arc::new(Notify::new()).notified_owned()));

        // Third item starts a second batch, blocked by lock 2.
        let result = policy.on_add(&queue);
        assert_matches!(result, OnAdd::AddAndAcquireResources);
        queue.push(new_item((), Arc::new(Notify::new()).notified_owned()));

        let acquire_lock2 = Arc::new(Mutex::new(()));
        let lock_guard2 = acquire_lock2.lock().await;
        processor.acquire_locks.push(Arc::clone(&acquire_lock2));
        let (tx, mut acquired2) = mpsc::channel(1);
        queue.pre_acquire_resources(processor.clone(), tx);

        // Release the SECOND batch's acquisition first (out of order).
        drop(lock_guard2);
        let msg = acquired2.recv().await.unwrap();
        let second_gen = Generation::default().next();
        assert_matches!(msg, Message::ResourcesAcquired(_, generation) => {
            assert_eq!(generation, second_gen);
        });

        let result = policy.on_resources_acquired(second_gen, &queue);
        assert_matches!(result, ProcessAction::Process);
        drop(lock_guard1);

        let msg = acquired1.recv().await.unwrap();
        let first_gen = Generation::default();
        assert_matches!(msg, Message::ResourcesAcquired(_, generation) => {
            assert_eq!(generation, first_gen);
        });

        let result = policy.on_resources_acquired(first_gen, &queue);
        assert_matches!(result, ProcessAction::Process);
    }
}