1use std::time::Duration;
45
46use tokio::sync::mpsc;
47
/// Tuning knobs shared by a [`BatchAccumulator`] and its [`BatchDrainer`].
#[derive(Debug, Clone)]
pub struct AccumulatorConfig {
    /// Capacity of the bounded channel that connects producers to the drainer.
    pub channel_capacity: usize,
    /// A batch is released as soon as it holds at least this many items.
    pub max_items: usize,
    /// A batch is released as soon as the caller-reported byte sizes of its
    /// items add up to at least this value.
    pub max_bytes: usize,
    /// Wait time used by the drainer's flush timer.
    pub max_wait: Duration,
}
60
61impl Default for AccumulatorConfig {
62 fn default() -> Self {
63 Self {
64 channel_capacity: 10_000,
65 max_items: 100,
66 max_bytes: 1024 * 1024, max_wait: Duration::from_millis(10),
68 }
69 }
70}
71
72#[derive(Clone)]
74pub struct BatchAccumulator<T> {
75 tx: mpsc::Sender<(T, usize)>, }
77
78pub struct BatchDrainer<T> {
80 rx: mpsc::Receiver<(T, usize)>,
81 config: AccumulatorConfig,
82 buffer: Vec<T>,
83 buffer_bytes: usize,
84}
85
86#[derive(Debug, thiserror::Error)]
88#[error("accumulator full -- backpressure active ({capacity} items buffered)")]
89pub struct AccumulatorFull {
90 pub capacity: usize,
91}
92
93impl<T: Send + 'static> BatchAccumulator<T> {
94 #[must_use]
100 pub fn new(config: AccumulatorConfig) -> (Self, BatchDrainer<T>) {
101 let (tx, rx) = mpsc::channel(config.channel_capacity);
102 let drainer = BatchDrainer {
103 rx,
104 buffer: Vec::with_capacity(config.max_items),
105 buffer_bytes: 0,
106 config: config.clone(),
107 };
108 (Self { tx }, drainer)
109 }
110
111 pub async fn push(&self, item: T, byte_size: usize) -> Result<(), AccumulatorFull> {
119 self.tx
120 .try_send((item, byte_size))
121 .map_err(|_| AccumulatorFull {
122 capacity: self.tx.max_capacity(),
127 })
128 }
129
130 #[must_use]
132 pub fn is_closed(&self) -> bool {
133 self.tx.is_closed()
134 }
135}
136
/// Combines `records` and `commit_tokens` into a work batch by forwarding
/// both, unchanged, to `WorkBatch::new`.
#[cfg(feature = "transport")]
#[must_use]
pub fn records_into_work_batch<T: crate::transport::CommitToken>(
    records: Vec<crate::transport::Record>,
    commit_tokens: Vec<T>,
) -> crate::transport::WorkBatch<T> {
    use crate::transport::WorkBatch;

    WorkBatch::new(records, commit_tokens)
}
155
156impl<T> BatchDrainer<T> {
157 pub async fn next_batch(&mut self) -> Vec<T> {
162 if self.threshold_met() {
164 return self.take_buffer();
165 }
166
167 let sleep = tokio::time::sleep(self.config.max_wait);
175 tokio::pin!(sleep);
176
177 loop {
178 tokio::select! {
179 biased;
180
181 () = &mut sleep => {
183 if self.buffer.is_empty() {
184 sleep
187 .as_mut()
188 .reset(tokio::time::Instant::now() + self.config.max_wait);
189 continue;
190 }
191 return self.take_buffer();
192 }
193
194 item = self.rx.recv() => {
196 match item {
197 Some((val, size)) => {
198 self.buffer_bytes += size;
199 self.buffer.push(val);
200 if self.threshold_met() {
201 return self.take_buffer();
202 }
203 }
204 None => {
205 return self.take_buffer();
207 }
208 }
209 }
210 }
211 }
212 }
213
214 pub fn drain_remaining(&mut self) -> Vec<T> {
216 while let Ok((val, size)) = self.rx.try_recv() {
218 self.buffer_bytes += size;
219 self.buffer.push(val);
220 }
221 self.take_buffer()
222 }
223
224 fn threshold_met(&self) -> bool {
225 self.buffer.len() >= self.config.max_items || self.buffer_bytes >= self.config.max_bytes
226 }
227
228 fn take_buffer(&mut self) -> Vec<T> {
229 self.buffer_bytes = 0;
230 std::mem::take(&mut self.buffer)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Wait time of one minute, used by tests that flush on another trigger.
    fn long_wait() -> Duration {
        Duration::from_mins(1)
    }

    /// Builds a config from its four settings.
    fn config(
        channel_capacity: usize,
        max_items: usize,
        max_bytes: usize,
        max_wait: Duration,
    ) -> AccumulatorConfig {
        AccumulatorConfig {
            channel_capacity,
            max_items,
            max_bytes,
            max_wait,
        }
    }

    /// Reaching `max_items` releases the batch, in push order.
    #[tokio::test]
    async fn test_drain_on_item_count() {
        let (producer, mut drainer) =
            BatchAccumulator::new(config(100, 5, usize::MAX, long_wait()));

        for value in 0..5 {
            producer.push(value, 1).await.unwrap();
        }

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 5);
        assert_eq!(batch, vec![0, 1, 2, 3, 4]);
    }

    /// Four items of 3 bytes each cross the 10-byte limit.
    #[tokio::test]
    async fn test_drain_on_byte_threshold() {
        let (producer, mut drainer) = BatchAccumulator::new(config(100, 1000, 10, long_wait()));

        for value in 0..4 {
            producer.push(value, 3).await.unwrap();
        }

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 4);
    }

    /// With both size limits out of reach, the 50 ms timer releases the batch.
    #[tokio::test]
    async fn test_drain_on_time_threshold() {
        let (producer, mut drainer) =
            BatchAccumulator::new(config(100, 1000, usize::MAX, Duration::from_millis(50)));

        producer.push(1, 1).await.unwrap();
        producer.push(2, 1).await.unwrap();

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 2);
    }

    /// A fourth push into a 3-slot channel that nobody drains is rejected.
    #[tokio::test]
    async fn test_backpressure_when_full() {
        // The drainer is kept alive so the channel stays open.
        let (producer, _drainer) =
            BatchAccumulator::<i32>::new(config(3, 100, usize::MAX, long_wait()));

        for value in 1..=3 {
            producer.push(value, 1).await.unwrap();
        }

        let outcome = producer.push(4, 1).await;
        assert!(outcome.is_err());
    }

    /// The rejection error carries the configured channel capacity.
    #[tokio::test]
    async fn test_backpressure_error_reports_configured_capacity() {
        let (producer, _drainer) =
            BatchAccumulator::<i32>::new(config(3, 100, usize::MAX, long_wait()));

        for value in 1..=3 {
            producer.push(value, 1).await.unwrap();
        }

        let err = producer
            .push(4, 1)
            .await
            .expect_err("channel full -> error");
        assert_eq!(err.capacity, 3);
    }

    /// Items trickling in every 40 ms must not push back the 100 ms deadline.
    #[tokio::test(start_paused = true)]
    async fn test_trickle_traffic_flushes_on_fixed_deadline() {
        let (producer, mut drainer) =
            BatchAccumulator::<i32>::new(config(100, 1000, usize::MAX, Duration::from_millis(100)));

        tokio::spawn(async move {
            let gap = Duration::from_millis(40);
            for value in 0..6 {
                producer.push(value, 1).await.unwrap();
                tokio::time::sleep(gap).await;
            }
        });

        let batch = drainer.next_batch().await;
        assert!(
            !batch.is_empty(),
            "should flush items buffered within the window"
        );
        assert!(
            batch.len() < 6,
            "expected a partial flush at the fixed deadline, got all {} items \
             (timer reset on each arrival?)",
            batch.len()
        );
    }

    /// After the producer is dropped, buffered items are returned first and
    /// later calls return an empty batch.
    #[tokio::test]
    async fn test_shutdown_drains_remaining() {
        let (producer, mut drainer) =
            BatchAccumulator::new(config(100, 1000, usize::MAX, long_wait()));

        producer.push(10, 1).await.unwrap();
        producer.push(20, 1).await.unwrap();

        drop(producer);

        let first = drainer.next_batch().await;
        assert_eq!(first, vec![10, 20]);

        let second = drainer.next_batch().await;
        assert!(second.is_empty());
    }

    /// Seven items with `max_items = 3` come out as batches of 3, 3, 1, then empty.
    #[tokio::test]
    async fn test_multiple_batches() {
        let (producer, mut drainer) =
            BatchAccumulator::new(config(100, 3, usize::MAX, long_wait()));

        for value in 0..7 {
            producer.push(value, 1).await.unwrap();
        }
        drop(producer);

        let first = drainer.next_batch().await;
        assert_eq!(first.len(), 3);

        let second = drainer.next_batch().await;
        assert_eq!(second.len(), 3);

        let third = drainer.next_batch().await;
        assert_eq!(third.len(), 1);

        let fourth = drainer.next_batch().await;
        assert!(fourth.is_empty());
    }

    /// A cloned handle feeds the same drainer as the original.
    #[tokio::test]
    async fn test_push_handle_is_clone() {
        let (producer, mut drainer) = BatchAccumulator::new(AccumulatorConfig::default());

        let second_producer = producer.clone();

        producer.push(1, 1).await.unwrap();
        second_producer.push(2, 1).await.unwrap();

        drop(producer);
        drop(second_producer);

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 2);
    }

    /// `drain_remaining` returns everything queued, in push order, without waiting.
    #[tokio::test]
    async fn test_drain_remaining_on_shutdown() {
        let (producer, mut drainer) =
            BatchAccumulator::new(config(100, 1000, usize::MAX, long_wait()));

        for value in 1..=3 {
            producer.push(value, 1).await.unwrap();
        }
        drop(producer);

        let leftovers = drainer.drain_remaining();
        assert_eq!(leftovers, vec![1, 2, 3]);
    }

    /// Draining when nothing was pushed yields an empty batch.
    #[tokio::test]
    async fn test_empty_drain_returns_empty() {
        let (_producer, mut drainer) =
            BatchAccumulator::<i32>::new(AccumulatorConfig::default());

        let leftovers = drainer.drain_remaining();
        assert!(leftovers.is_empty());
    }

    /// A drained batch of records can be turned into a work batch with an
    /// independent number of commit tokens.
    #[cfg(feature = "transport")]
    #[tokio::test]
    async fn test_records_drain_into_work_batch() {
        use crate::transport::{CommitToken, PayloadFormat, Record, RecordMeta};
        use bytes::Bytes;

        #[derive(Debug, Clone)]
        struct PushTok(u64);
        impl std::fmt::Display for PushTok {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "push-{}", self.0)
            }
        }
        impl CommitToken for PushTok {}

        let make_record = |payload: &'static [u8]| Record {
            payload: Bytes::from_static(payload),
            key: None,
            headers: vec![],
            metadata: RecordMeta {
                timestamp_ms: None,
                format: PayloadFormat::Json,
            },
        };

        let (producer, mut drainer) =
            BatchAccumulator::<Record>::new(config(100, 3, usize::MAX, long_wait()));
        producer.push(make_record(b"{\"a\":1}"), 7).await.unwrap();
        producer.push(make_record(b"{\"b\":2}"), 7).await.unwrap();
        producer.push(make_record(b"{\"c\":3}"), 7).await.unwrap();

        let records = drainer.next_batch().await;
        assert_eq!(records.len(), 3);

        let tokens = vec![PushTok(1), PushTok(2)];
        let work_batch = records_into_work_batch(records, tokens);
        assert_eq!(work_batch.record_count(), 3);
        assert_eq!(work_batch.commit_tokens.len(), 2);
        assert!(work_batch.dlq_entries.is_empty());
    }
}