use linked_hash_map::LinkedHashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::time::{Duration, Instant};
use std::vec::Drain;
/// Items accumulated so far for one key, plus the moment the batch was opened.
#[derive(Debug)]
struct OutstandingBatch<I>
where
    I: Debug,
{
    /// Items appended to this batch, in arrival order.
    items: Vec<I>,
    /// When the batch was opened; compared against the maximum batch age.
    created: Instant,
}
impl<I: Debug> OutstandingBatch<I> {
fn new() -> OutstandingBatch<I> {
OutstandingBatch {
items: Vec::new(),
created: Instant::now(),
}
}
fn from_cache(mut items: Vec<I>) -> OutstandingBatch<I> {
items.clear();
OutstandingBatch {
items,
created: Instant::now(),
}
}
}
/// Outcome of polling a `MultiBufBatch`.
#[derive(Debug)]
pub enum PollResult<K>
where
    K: Debug,
{
    /// The batch stored under this key should be consumed now.
    Ready(K),
    /// Nothing is ready yet. Carries the time remaining until the oldest
    /// outstanding batch expires, or `None` when no batch is outstanding.
    NotReady(Option<Duration>),
}
/// Snapshot of a `MultiBufBatch`'s bookkeeping, as returned by its `stats` method.
///
/// Plain data, so it is `Copy` and comparable; that lets callers store
/// snapshots and diff them without extra boilerplate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Stats {
    /// Number of batches currently open.
    pub outstanding: usize,
    /// Number of recycled item buffers held in the cache.
    pub cached_buffers: usize,
}
/// Collects items into per-key batches that become ready once they reach a
/// maximum size or a maximum age.
#[derive(Debug)]
pub struct MultiBufBatch<K, I>
where
    K: Debug + Ord + Hash,
    I: Debug,
{
    /// Number of items at which a batch is considered full.
    max_size: usize,
    /// Age at which the oldest outstanding batch is reported as ready.
    max_duration: Duration,
    /// Recycled item buffers, reused when a new batch is opened.
    cache: Vec<Vec<I>>,
    /// Open batches keyed by `K`, kept in the order they were opened.
    outstanding: LinkedHashMap<K, OutstandingBatch<I>>,
    /// Key of the batch that reached `max_size` and still awaits consumption.
    full: Option<K>,
}
impl<K, I> MultiBufBatch<K, I>
where
K: Debug + Ord + Hash + Clone,
I: Debug,
{
pub fn new(max_size: usize, max_duration: Duration) -> MultiBufBatch<K, I> {
assert!(max_size > 0, "MultiBufBatch::new bad max_size");
MultiBufBatch {
max_size,
max_duration,
cache: Default::default(),
outstanding: Default::default(),
full: Default::default(),
}
}
pub fn poll(&self) -> PollResult<K> {
if let Some(key) = &self.full {
return PollResult::Ready(key.clone());
}
if let Some((key, batch)) = self.outstanding.front() {
let since_start = Instant::now().duration_since(batch.created);
if since_start >= self.max_duration {
return PollResult::Ready(key.clone());
}
return PollResult::NotReady(Some(self.max_duration - since_start));
}
return PollResult::NotReady(None);
}
pub fn append(&mut self, key: K, item: I) {
assert!(
self.full.is_none(),
"MultiBufBatch::append unconsumed full batch"
);
if let Some(batch) = self.outstanding.get_mut(&key) {
assert!(
batch.items.len() < self.max_size,
"MultiBufBatch::append on full batch"
);
batch.items.push(item);
if batch.items.len() >= self.max_size {
self.full = Some(key);
}
} else {
let mut batch = if let Some(items) = self.cache.pop() {
OutstandingBatch::from_cache(items)
} else {
OutstandingBatch::new()
};
batch.items.push(item);
self.outstanding.insert(key, batch);
}
}
fn move_to_cache(&mut self, key: &K) -> Option<&mut Vec<I>> {
if self.full.as_ref().filter(|fkey| *fkey == key).is_some() {
self.full.take();
}
let items = self.outstanding.remove(key)?.items;
self.cache.push(items);
self.cache.last_mut()
}
pub fn outstanding(&self) -> impl Iterator<Item = &K> {
self.outstanding.keys()
}
pub fn clear(&mut self, key: &K) {
self.move_to_cache(key).map(|items| items.clear());
}
pub fn drain(&mut self, key: &K) -> Option<Drain<I>> {
self.move_to_cache(key).map(|items| items.drain(0..))
}
pub fn flush(&mut self) -> Vec<(K, Vec<I>)> {
let cache = &mut self.cache;
let outstanding = &mut self.outstanding;
outstanding
.entries()
.map(|entry| {
let key = entry.key().clone();
let items = entry.remove().items;
cache.push(items);
let items = cache.last_mut().unwrap();
let items = items.split_off(0);
(key, items)
})
.collect()
}
pub fn get(&self, key: &K) -> Option<&[I]> {
self.outstanding
.get(key)
.map(|batch| batch.items.as_slice())
}
pub fn clear_cache(&mut self) {
self.cache.clear();
}
pub fn stats(&self) -> Stats {
Stats {
outstanding: self.outstanding.len(),
cached_buffers: self.cache.len(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::time::Duration;

    /// Drains the batch stored under `key` into a `Vec`, failing the test if
    /// no such batch is outstanding.
    fn drain_all(batch: &mut MultiBufBatch<i32, i32>, key: i32) -> Vec<i32> {
        batch
            .drain(&key)
            .expect("expected an outstanding batch")
            .collect()
    }

    #[test]
    fn test_batch_poll() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));
        assert_matches!(batch.poll(), PollResult::NotReady(None));

        batch.append(0, 1);
        assert_matches!(batch.poll(), PollResult::NotReady(Some(_)));

        for item in 2..=4 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [1, 2, 3, 4])
        );
        assert_matches!(batch.poll(), PollResult::NotReady(None));
    }

    #[test]
    fn test_batch_max_size() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));
        for item in 1..=4 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [1, 2, 3, 4])
        );

        for item in 5..=8 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [5, 6, 7, 8])
        );

        // Interleave two keys; key 1 fills up first.
        for (one, zero) in (1..=3).zip(9..=11) {
            batch.append(1, one);
            batch.append(0, zero);
        }
        batch.append(1, 4);
        assert_matches!(batch.poll(), PollResult::Ready(1) =>
            assert_eq!(drain_all(&mut batch, 1), [1, 2, 3, 4])
        );

        batch.append(0, 12);
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [9, 10, 11, 12])
        );
    }

    #[test]
    fn test_batch_max_duration() {
        let mut batch = MultiBufBatch::new(4, Duration::from_millis(100));
        batch.append(0, 1);
        batch.append(0, 2);
        let ready_after = if let PollResult::NotReady(Some(wait)) = batch.poll() {
            wait
        } else {
            panic!("expected NotReady with instant")
        };
        std::thread::sleep(ready_after);
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [1, 2])
        );

        for item in 3..=6 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [3, 4, 5, 6])
        );
    }

    #[test]
    fn test_drain_stream() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));
        for item in 1..=3 {
            batch.append(0, item);
        }
        for item in 1..=2 {
            batch.append(1, item);
        }
        assert_eq!(drain_all(&mut batch, 1), [1, 2]);
        assert_eq!(drain_all(&mut batch, 0), [1, 2, 3]);

        for item in 5..=8 {
            batch.append(0, item);
        }
        assert_matches!(batch.poll(), PollResult::Ready(0) =>
            assert_eq!(drain_all(&mut batch, 0), [5, 6, 7, 8])
        );
    }

    #[test]
    fn test_flush() {
        let mut batch = MultiBufBatch::new(4, Duration::from_secs(10));
        for (key, item) in [(0, 1), (1, 1), (0, 2), (1, 2), (0, 3)].iter().copied() {
            batch.append(key, item);
        }
        let batches = batch.flush();
        assert_eq!(batches[0], (0, vec![1, 2, 3]));
        assert_eq!(batches[1], (1, vec![1, 2]));
    }
}