1use linked_hash_map::LinkedHashMap;
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::time::{Duration, Instant};
7use std::vec::Drain;
8
/// A batch of items collected under a single key, together with the moment it was opened.
#[derive(Debug)]
struct OutstandingBatch<I>
where
    I: Debug,
{
    /// Items appended to this batch, in arrival order.
    items: Vec<I>,
    /// When this batch was opened; compared against the batcher's maximum duration.
    created: Instant,
}

impl<I> OutstandingBatch<I>
where
    I: Debug,
{
    /// Opens an empty batch backed by a new, unallocated `Vec`.
    fn new() -> OutstandingBatch<I> {
        Self::from_cache(Vec::new())
    }

    /// Opens an empty batch that reuses a previously allocated buffer.
    ///
    /// Any items still in `items` are dropped; the buffer's capacity is kept.
    fn from_cache(mut items: Vec<I>) -> OutstandingBatch<I> {
        items.clear();
        Self {
            created: Instant::now(),
            items,
        }
    }
}
34
/// Outcome of polling a `MultiBufBatch` for a batch that should be consumed.
#[derive(Debug)]
pub enum PollResult<K>
where
    K: Debug,
{
    /// The batch stored under this key is ready: it is either full or the oldest
    /// outstanding batch has been open for at least the maximum duration.
    Ready(K),
    /// No batch is ready. `Some(d)` is the time left until the oldest outstanding batch
    /// reaches the maximum duration; `None` means no batch is outstanding at all.
    NotReady(Option<Duration>),
}
45
/// Point-in-time counters describing a `MultiBufBatch`.
///
/// This is a plain value snapshot, so it is `Copy` and comparable; the default is all zeros.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Stats {
    /// Number of batches currently outstanding (opened and not yet consumed).
    pub outstanding: usize,
    /// Number of emptied buffers held for reuse by future batches.
    pub cached_buffers: usize,
}
54
55#[derive(Debug)]
63pub struct MultiBufBatch<K: Debug + Ord + Hash, I: Debug> {
64 max_size: usize,
65 max_duration: Duration,
66 cache: Vec<Vec<I>>,
68 outstanding: LinkedHashMap<K, OutstandingBatch<I>>,
70 full: Option<K>,
72}
73
74impl<K, I> MultiBufBatch<K, I>
75where
76 K: Debug + Ord + Hash + Clone,
77 I: Debug,
78{
79 pub fn new(max_size: usize, max_duration: Duration) -> MultiBufBatch<K, I> {
84 assert!(max_size > 0, "MultiBufBatch::new bad max_size");
85
86 MultiBufBatch {
87 max_size,
88 max_duration,
89 cache: Default::default(),
90 outstanding: Default::default(),
91 full: Default::default(),
92 }
93 }
94
95 pub fn poll(&self) -> PollResult<K> {
102 if let Some(key) = &self.full {
104 return PollResult::Ready(key.clone());
105 }
106
107 if let Some((key, batch)) = self.outstanding.front() {
109 let since_start = Instant::now().duration_since(batch.created);
110
111 if since_start >= self.max_duration {
112 return PollResult::Ready(key.clone());
113 }
114
115 return PollResult::NotReady(Some(self.max_duration - since_start));
116 }
117
118 return PollResult::NotReady(None);
119 }
120
121 pub fn append(&mut self, key: K, item: I) {
127 assert!(
128 self.full.is_none(),
129 "MultiBufBatch::append unconsumed full batch"
130 );
131
132 if let Some(batch) = self.outstanding.get_mut(&key) {
134 assert!(
135 batch.items.len() < self.max_size,
136 "MultiBufBatch::append on full batch"
137 );
138
139 batch.items.push(item);
140
141 if batch.items.len() >= self.max_size {
143 self.full = Some(key);
144 }
145 } else {
146 let mut batch = if let Some(items) = self.cache.pop() {
147 OutstandingBatch::from_cache(items)
148 } else {
149 OutstandingBatch::new()
150 };
151
152 batch.items.push(item);
153 self.outstanding.insert(key, batch);
154 }
155 }
156
157 fn move_to_cache(&mut self, key: &K) -> Option<&mut Vec<I>> {
159 if self.full.as_ref().filter(|fkey| *fkey == key).is_some() {
161 self.full.take();
162 }
163
164 let items = self.outstanding.remove(key)?.items;
166 self.cache.push(items);
167 self.cache.last_mut()
168 }
169
170 pub fn outstanding(&self) -> impl Iterator<Item = &K> {
172 self.outstanding.keys()
173 }
174
175 pub fn clear(&mut self, key: &K) {
177 self.move_to_cache(key).map(|items| items.clear());
178 }
179
180 pub fn drain(&mut self, key: &K) -> Option<Drain<I>> {
182 self.move_to_cache(key).map(|items| items.drain(0..))
183 }
184
185 pub fn flush(&mut self) -> Vec<(K, Vec<I>)> {
187 let cache = &mut self.cache;
188 let outstanding = &mut self.outstanding;
189
190 outstanding
191 .entries()
192 .map(|entry| {
193 let key = entry.key().clone();
194
195 let items = entry.remove().items;
197 cache.push(items);
198 let items = cache.last_mut().unwrap();
199
200 let items = items.split_off(0);
202
203 (key, items)
204 })
205 .collect()
206 }
207
208 pub fn get(&self, key: &K) -> Option<&[I]> {
210 self.outstanding
211 .get(key)
212 .map(|batch| batch.items.as_slice())
213 }
214
215 pub fn clear_cache(&mut self) {
217 self.cache.clear();
218 }
219
220 pub fn stats(&self) -> Stats {
222 Stats {
223 outstanding: self.outstanding.len(),
224 cached_buffers: self.cache.len(),
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::time::Duration;

    /// Drains the batch stored under `key` into a `Vec`; fails the test if it is missing.
    fn drained(batch: &mut MultiBufBatch<i32, i32>, key: i32) -> Vec<i32> {
        batch.drain(&key).unwrap().collect()
    }

    #[test]
    fn test_batch_poll() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));

        // Nothing outstanding yet, so there is no deadline either.
        assert_matches!(batch.poll(), PollResult::NotReady(None));

        batch.append(0, 1);

        // One partial batch: not ready, but a deadline is reported.
        assert_matches!(batch.poll(), PollResult::NotReady(Some(_)));

        for item in 2..=4 {
            batch.append(0, item);
        }

        // The fourth item fills the batch.
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [1, 2, 3, 4]);

        assert_matches!(batch.poll(), PollResult::NotReady(None));
    }

    #[test]
    fn test_batch_max_size() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));

        for item in 1..=4 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [1, 2, 3, 4]);

        // A drained key can start a fresh batch.
        for item in 5..=8 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [5, 6, 7, 8]);

        // Interleave two keys; key 1 fills up first.
        for &(key, item) in &[(1, 1), (0, 9), (1, 2), (0, 10), (1, 3), (0, 11), (1, 4)] {
            batch.append(key, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(1));
        assert_eq!(drained(&mut batch, 1), [1, 2, 3, 4]);

        batch.append(0, 12);
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [9, 10, 11, 12]);
    }

    #[test]
    fn test_batch_max_duration() {
        let mut batch = MultiBufBatch::new(4, Duration::from_millis(100));

        batch.append(0, 1);
        batch.append(0, 2);

        let ready_after = if let PollResult::NotReady(Some(remaining)) = batch.poll() {
            remaining
        } else {
            panic!("expected NotReady with instant")
        };

        // Once the reported deadline has passed, the partial batch becomes ready.
        std::thread::sleep(ready_after);

        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [1, 2]);

        for item in 3..=6 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [3, 4, 5, 6]);
    }

    #[test]
    fn test_drain_stream() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));

        for item in 1..=3 {
            batch.append(0, item);
        }
        for item in 1..=2 {
            batch.append(1, item);
        }

        // Partial batches can be drained before they are ready, in any order.
        assert_eq!(drained(&mut batch, 1), [1, 2]);
        assert_eq!(drained(&mut batch, 0), [1, 2, 3]);

        for item in 5..=8 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0));
        assert_eq!(drained(&mut batch, 0), [5, 6, 7, 8]);
    }

    #[test]
    fn test_flush() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));

        for &(key, item) in &[(0, 1), (1, 1), (0, 2), (1, 2), (0, 3)] {
            batch.append(key, item);
        }

        let batches = batch.flush();

        // Batches come back oldest first.
        assert_eq!(batches[0].0, 0);
        assert_eq!(batches[0].1.as_slice(), [1, 2, 3]);

        assert_eq!(batches[1].0, 1);
        assert_eq!(batches[1].1.as_slice(), [1, 2]);
    }
}