1use std::time::Duration;
45
46use tokio::sync::mpsc;
47
/// Tuning knobs shared by [`BatchAccumulator`] and [`BatchDrainer`].
///
/// [`BatchDrainer::next_batch`] hands out a batch once it holds at least `max_items` items or
/// at least `max_bytes` reported bytes; `max_wait` is the timeout it uses to flush a partially
/// filled batch.
#[derive(Clone, Debug)]
pub struct AccumulatorConfig {
    /// Number of pushed-but-not-yet-received items the channel can hold before
    /// [`BatchAccumulator::push`] starts failing. Must be non-zero: tokio panics when asked
    /// to create a zero-capacity bounded channel.
    pub channel_capacity: usize,
    /// Item-count threshold: a batch is flushed once it contains this many items.
    pub max_items: usize,
    /// Byte threshold: a batch is flushed once the sizes reported to
    /// [`BatchAccumulator::push`] for its items add up to at least this value.
    pub max_bytes: usize,
    /// Timeout used by [`BatchDrainer::next_batch`] to flush a non-empty batch that has
    /// reached neither threshold.
    pub max_wait: Duration,
}

impl Default for AccumulatorConfig {
    /// 10 000-slot channel, flush thresholds of 100 items or 1 MiB, and a 10 ms timeout.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;

        Self {
            channel_capacity: 10_000,
            max_items: 100,
            max_bytes: ONE_MIB,
            max_wait: Duration::from_millis(10),
        }
    }
}
71
72#[derive(Clone)]
74pub struct BatchAccumulator<T> {
75 tx: mpsc::Sender<(T, usize)>, }
77
78pub struct BatchDrainer<T> {
80 rx: mpsc::Receiver<(T, usize)>,
81 config: AccumulatorConfig,
82 buffer: Vec<T>,
83 buffer_bytes: usize,
84}
85
86#[derive(Debug, thiserror::Error)]
88#[error("accumulator full -- backpressure active ({capacity} items buffered)")]
89pub struct AccumulatorFull {
90 pub capacity: usize,
91}
92
93impl<T: Send + 'static> BatchAccumulator<T> {
94 #[must_use]
100 pub fn new(config: AccumulatorConfig) -> (Self, BatchDrainer<T>) {
101 let (tx, rx) = mpsc::channel(config.channel_capacity);
102 let drainer = BatchDrainer {
103 rx,
104 buffer: Vec::with_capacity(config.max_items),
105 buffer_bytes: 0,
106 config: config.clone(),
107 };
108 (Self { tx }, drainer)
109 }
110
111 pub async fn push(&self, item: T, byte_size: usize) -> Result<(), AccumulatorFull> {
119 self.tx
120 .try_send((item, byte_size))
121 .map_err(|_| AccumulatorFull {
122 capacity: self.tx.capacity(),
123 })
124 }
125
126 #[must_use]
128 pub fn is_closed(&self) -> bool {
129 self.tx.is_closed()
130 }
131}
132
/// Packages a batch of records (for example one returned by [`BatchDrainer::next_batch`])
/// together with the given commit tokens into a [`crate::transport::WorkBatch`].
///
/// Thin wrapper over `WorkBatch::new(records, commit_tokens)`; both vectors are passed
/// through unchanged.
#[cfg(feature = "transport")]
#[must_use]
pub fn records_into_work_batch<T>(
    records: Vec<crate::transport::Record>,
    commit_tokens: Vec<T>,
) -> crate::transport::WorkBatch<T>
where
    T: crate::transport::CommitToken,
{
    use crate::transport::WorkBatch;

    WorkBatch::new(records, commit_tokens)
}
157
158impl<T> BatchDrainer<T> {
159 pub async fn next_batch(&mut self) -> Vec<T> {
164 if self.threshold_met() {
166 return self.take_buffer();
167 }
168
169 loop {
171 let timeout = tokio::time::sleep(self.config.max_wait);
172
173 tokio::select! {
174 biased;
175
176 () = timeout => {
178 if self.buffer.is_empty() {
179 continue;
181 }
182 return self.take_buffer();
183 }
184
185 item = self.rx.recv() => {
187 match item {
188 Some((val, size)) => {
189 self.buffer_bytes += size;
190 self.buffer.push(val);
191 if self.threshold_met() {
192 return self.take_buffer();
193 }
194 }
195 None => {
196 return self.take_buffer();
198 }
199 }
200 }
201 }
202 }
203 }
204
205 pub fn drain_remaining(&mut self) -> Vec<T> {
207 while let Ok((val, size)) = self.rx.try_recv() {
209 self.buffer_bytes += size;
210 self.buffer.push(val);
211 }
212 self.take_buffer()
213 }
214
215 fn threshold_met(&self) -> bool {
216 self.buffer.len() >= self.config.max_items || self.buffer_bytes >= self.config.max_bytes
217 }
218
219 fn take_buffer(&mut self) -> Vec<T> {
220 self.buffer_bytes = 0;
221 std::mem::take(&mut self.buffer)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `AccumulatorConfig` with every field given positionally.
    fn make_config(
        channel_capacity: usize,
        max_items: usize,
        max_bytes: usize,
        max_wait: Duration,
    ) -> AccumulatorConfig {
        AccumulatorConfig {
            channel_capacity,
            max_items,
            max_bytes,
            max_wait,
        }
    }

    /// Reaching `max_items` releases the batch, in push order.
    #[tokio::test]
    async fn test_drain_on_item_count() {
        let (acc, mut drainer) =
            BatchAccumulator::new(make_config(100, 5, usize::MAX, Duration::from_mins(1)));

        for value in 0..5 {
            acc.push(value, 1).await.unwrap();
        }

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 5);
        assert_eq!(batch, vec![0, 1, 2, 3, 4]);
    }

    /// Four 3-byte items (12 bytes) cross the 10-byte threshold on the fourth push.
    #[tokio::test]
    async fn test_drain_on_byte_threshold() {
        let (acc, mut drainer) =
            BatchAccumulator::new(make_config(100, 1000, 10, Duration::from_mins(1)));

        for value in 0..4 {
            acc.push(value, 3).await.unwrap();
        }

        assert_eq!(drainer.next_batch().await.len(), 4);
    }

    /// Neither threshold is reached, so the 50 ms timeout flushes the partial batch.
    #[tokio::test]
    async fn test_drain_on_time_threshold() {
        let cfg = make_config(100, 1000, usize::MAX, Duration::from_millis(50));
        let (acc, mut drainer) = BatchAccumulator::new(cfg);

        acc.push(1, 1).await.unwrap();
        acc.push(2, 1).await.unwrap();

        assert_eq!(drainer.next_batch().await.len(), 2);
    }

    /// With a 3-slot channel and an idle drainer, the fourth push is rejected.
    #[tokio::test]
    async fn test_backpressure_when_full() {
        // Keep the drainer alive (named binding, not `_`) so the rejection comes from a full
        // channel rather than a closed one.
        let cfg = make_config(3, 100, usize::MAX, Duration::from_mins(1));
        let (acc, _drainer) = BatchAccumulator::<i32>::new(cfg);

        for value in 1..=3 {
            acc.push(value, 1).await.unwrap();
        }

        assert!(acc.push(4, 1).await.is_err());
    }

    /// After every handle is dropped, buffered items come out first, then an empty batch.
    #[tokio::test]
    async fn test_shutdown_drains_remaining() {
        let (acc, mut drainer) =
            BatchAccumulator::new(make_config(100, 1000, usize::MAX, Duration::from_mins(1)));

        acc.push(10, 1).await.unwrap();
        acc.push(20, 1).await.unwrap();
        drop(acc);

        assert_eq!(drainer.next_batch().await, vec![10, 20]);
        assert!(drainer.next_batch().await.is_empty());
    }

    /// Seven items with `max_items = 3` split into batches of 3, 3 and 1, then empty.
    #[tokio::test]
    async fn test_multiple_batches() {
        let (acc, mut drainer) =
            BatchAccumulator::new(make_config(100, 3, usize::MAX, Duration::from_mins(1)));

        for value in 0..7 {
            acc.push(value, 1).await.unwrap();
        }
        drop(acc);

        for expected_len in [3, 3, 1] {
            assert_eq!(drainer.next_batch().await.len(), expected_len);
        }
        assert!(drainer.next_batch().await.is_empty());
    }

    /// Pushes through a cloned handle land in the same drainer.
    #[tokio::test]
    async fn test_push_handle_is_clone() {
        let (acc, mut drainer) = BatchAccumulator::new(AccumulatorConfig::default());
        let acc2 = acc.clone();

        acc.push(1, 1).await.unwrap();
        acc2.push(2, 1).await.unwrap();

        drop(acc);
        drop(acc2);

        assert_eq!(drainer.next_batch().await.len(), 2);
    }

    /// `drain_remaining` returns everything still queued, without waiting.
    #[tokio::test]
    async fn test_drain_remaining_on_shutdown() {
        let (acc, mut drainer) =
            BatchAccumulator::new(make_config(100, 1000, usize::MAX, Duration::from_mins(1)));

        for value in 1..=3 {
            acc.push(value, 1).await.unwrap();
        }
        drop(acc);

        assert_eq!(drainer.drain_remaining(), vec![1, 2, 3]);
    }

    /// Draining an untouched accumulator yields nothing.
    #[tokio::test]
    async fn test_empty_drain_returns_empty() {
        let (_acc, mut drainer) = BatchAccumulator::<i32>::new(AccumulatorConfig::default());

        assert!(drainer.drain_remaining().is_empty());
    }

    /// A drained batch of records converts into a `WorkBatch` carrying the given tokens.
    #[cfg(feature = "transport")]
    #[tokio::test]
    async fn test_records_drain_into_work_batch() {
        use crate::transport::{CommitToken, PayloadFormat, Record, RecordMeta};
        use bytes::Bytes;

        #[derive(Debug, Clone)]
        struct PushTok(u64);

        impl std::fmt::Display for PushTok {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "push-{}", self.0)
            }
        }

        impl CommitToken for PushTok {}

        let record = |payload: &'static [u8]| Record {
            payload: Bytes::from_static(payload),
            key: None,
            headers: vec![],
            metadata: RecordMeta {
                timestamp_ms: None,
                format: PayloadFormat::Json,
            },
        };

        let cfg = make_config(100, 3, usize::MAX, Duration::from_mins(1));
        let (acc, mut drainer) = BatchAccumulator::<Record>::new(cfg);

        for payload in [b"{\"a\":1}", b"{\"b\":2}", b"{\"c\":3}"] {
            acc.push(record(payload), 7).await.unwrap();
        }

        let block = drainer.next_batch().await;
        assert_eq!(block.len(), 3);

        // The token count (2) deliberately differs from the record count (3); both lists
        // are expected to come through the conversion as given.
        let tokens = vec![PushTok(1), PushTok(2)];
        let wb = records_into_work_batch(block, tokens);
        assert_eq!(wb.record_count(), 3);
        assert_eq!(wb.commit_tokens.len(), 2);
        assert!(wb.dlq_entries.is_empty());
    }
}