1use std::time::Duration;
45
46use tokio::sync::mpsc;
47
/// Tuning knobs shared by a [`BatchAccumulator`] / [`BatchDrainer`] pair.
#[derive(Clone, Debug)]
pub struct AccumulatorConfig {
    /// Bound of the channel between producers and the drainer. Once this many
    /// items are queued, [`BatchAccumulator::push`] fails with [`AccumulatorFull`].
    pub channel_capacity: usize,
    /// Item-count trigger: a batch is handed out once it holds this many items.
    pub max_items: usize,
    /// Size trigger: a batch is handed out once the caller-reported byte sizes
    /// of its items add up to at least this value.
    pub max_bytes: usize,
    /// Time trigger: once this much time has passed, [`BatchDrainer::next_batch`]
    /// returns a non-empty partial batch even if neither size trigger fired.
    pub max_wait: Duration,
}

impl Default for AccumulatorConfig {
    /// 10 000-slot channel; batches of up to 100 items or 1 MiB; 10 ms `max_wait`.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        AccumulatorConfig {
            max_wait: Duration::from_millis(10),
            max_bytes: MIB,
            max_items: 100,
            channel_capacity: 10_000,
        }
    }
}
71
72#[derive(Clone)]
74pub struct BatchAccumulator<T> {
75 tx: mpsc::Sender<(T, usize)>, }
77
78pub struct BatchDrainer<T> {
80 rx: mpsc::Receiver<(T, usize)>,
81 config: AccumulatorConfig,
82 buffer: Vec<T>,
83 buffer_bytes: usize,
84}
85
86#[derive(Debug, thiserror::Error)]
88#[error("accumulator full -- backpressure active ({capacity} items buffered)")]
89pub struct AccumulatorFull {
90 pub capacity: usize,
91}
92
93impl<T: Send + 'static> BatchAccumulator<T> {
94 #[must_use]
100 pub fn new(config: AccumulatorConfig) -> (Self, BatchDrainer<T>) {
101 let (tx, rx) = mpsc::channel(config.channel_capacity);
102 let drainer = BatchDrainer {
103 rx,
104 buffer: Vec::with_capacity(config.max_items),
105 buffer_bytes: 0,
106 config: config.clone(),
107 };
108 (Self { tx }, drainer)
109 }
110
111 pub async fn push(&self, item: T, byte_size: usize) -> Result<(), AccumulatorFull> {
119 self.tx
120 .try_send((item, byte_size))
121 .map_err(|_| AccumulatorFull {
122 capacity: self.tx.capacity(),
123 })
124 }
125
126 #[must_use]
128 pub fn is_closed(&self) -> bool {
129 self.tx.is_closed()
130 }
131}
132
133impl<T> BatchDrainer<T> {
134 pub async fn next_batch(&mut self) -> Vec<T> {
139 if self.threshold_met() {
141 return self.take_buffer();
142 }
143
144 loop {
146 let timeout = tokio::time::sleep(self.config.max_wait);
147
148 tokio::select! {
149 biased;
150
151 () = timeout => {
153 if self.buffer.is_empty() {
154 continue;
156 }
157 return self.take_buffer();
158 }
159
160 item = self.rx.recv() => {
162 match item {
163 Some((val, size)) => {
164 self.buffer_bytes += size;
165 self.buffer.push(val);
166 if self.threshold_met() {
167 return self.take_buffer();
168 }
169 }
170 None => {
171 return self.take_buffer();
173 }
174 }
175 }
176 }
177 }
178 }
179
180 pub fn drain_remaining(&mut self) -> Vec<T> {
182 while let Ok((val, size)) = self.rx.try_recv() {
184 self.buffer_bytes += size;
185 self.buffer.push(val);
186 }
187 self.take_buffer()
188 }
189
190 fn threshold_met(&self) -> bool {
191 self.buffer.len() >= self.config.max_items || self.buffer_bytes >= self.config.max_bytes
192 }
193
194 fn take_buffer(&mut self) -> Vec<T> {
195 self.buffer_bytes = 0;
196 std::mem::take(&mut self.buffer)
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_drain_on_item_count() {
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 5,
            max_bytes: usize::MAX,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        for i in 0..5 {
            acc.push(i, 1).await.unwrap();
        }

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 5);
        assert_eq!(batch, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_drain_on_byte_threshold() {
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 1000,
            max_bytes: 10,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        // 3 + 3 + 3 = 9 bytes stays under the limit; the fourth item (12) crosses it.
        for i in 0..4 {
            acc.push(i, 3).await.unwrap();
        }

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 4);
    }

    #[tokio::test]
    async fn test_drain_on_time_threshold() {
        let max_wait = Duration::from_millis(50);
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 1000,
            max_bytes: usize::MAX,
            max_wait,
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        // Neither size threshold can be reached, so only the timer can flush.
        acc.push(1, 1).await.unwrap();
        acc.push(2, 1).await.unwrap();

        let started = std::time::Instant::now();
        let batch = drainer.next_batch().await;
        // The partial batch must be held until the timer fires, not returned early.
        assert!(started.elapsed() >= max_wait);
        assert_eq!(batch, vec![1, 2]);
    }

    #[tokio::test]
    async fn test_backpressure_when_full() {
        let config = AccumulatorConfig {
            channel_capacity: 3,
            max_items: 100,
            max_bytes: usize::MAX,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::<i32>::new(config);

        acc.push(1, 1).await.unwrap();
        acc.push(2, 1).await.unwrap();
        acc.push(3, 1).await.unwrap();

        // The channel is at capacity and nothing is draining it.
        let result = acc.push(4, 1).await;
        assert!(result.is_err());
        assert!(!acc.is_closed());

        // Draining frees the slots, so pushes succeed again.
        assert_eq!(drainer.drain_remaining(), vec![1, 2, 3]);
        acc.push(4, 1).await.unwrap();
        assert_eq!(drainer.drain_remaining(), vec![4]);
    }

    #[tokio::test]
    async fn test_shutdown_drains_remaining() {
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 1000,
            max_bytes: usize::MAX,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        acc.push(10, 1).await.unwrap();
        acc.push(20, 1).await.unwrap();

        // Closing the last sender lets the drainer flush the partial batch.
        drop(acc);

        let batch = drainer.next_batch().await;
        assert_eq!(batch, vec![10, 20]);

        // Once closed and drained, further calls return empty batches.
        let batch = drainer.next_batch().await;
        assert!(batch.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 3,
            max_bytes: usize::MAX,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        for i in 0..7 {
            acc.push(i, 1).await.unwrap();
        }
        drop(acc);

        let b1 = drainer.next_batch().await;
        assert_eq!(b1.len(), 3);

        let b2 = drainer.next_batch().await;
        assert_eq!(b2.len(), 3);

        // The leftover item is flushed when the channel closes.
        let b3 = drainer.next_batch().await;
        assert_eq!(b3.len(), 1);

        let b4 = drainer.next_batch().await;
        assert!(b4.is_empty());
    }

    #[tokio::test]
    async fn test_push_handle_is_clone() {
        let config = AccumulatorConfig::default();
        let (acc, mut drainer) = BatchAccumulator::new(config);

        let acc2 = acc.clone();

        acc.push(1, 1).await.unwrap();
        acc2.push(2, 1).await.unwrap();

        drop(acc);
        drop(acc2);

        let batch = drainer.next_batch().await;
        assert_eq!(batch.len(), 2);
    }

    #[tokio::test]
    async fn test_drain_remaining_on_shutdown() {
        let config = AccumulatorConfig {
            channel_capacity: 100,
            max_items: 1000,
            max_bytes: usize::MAX,
            max_wait: Duration::from_mins(1),
        };
        let (acc, mut drainer) = BatchAccumulator::new(config);

        acc.push(1, 1).await.unwrap();
        acc.push(2, 1).await.unwrap();
        acc.push(3, 1).await.unwrap();
        drop(acc);

        let remaining = drainer.drain_remaining();
        assert_eq!(remaining, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_empty_drain_returns_empty() {
        let config = AccumulatorConfig::default();
        let (_acc, mut drainer) = BatchAccumulator::<i32>::new(config);

        let remaining = drainer.drain_remaining();
        assert!(remaining.is_empty());
    }
}