use crate::channel::EndOfStreamError;
pub use crate::multi_buf_batch::Stats;
use crate::multi_buf_batch::{MultiBufBatch, PollResult};
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
use std::fmt::Debug;
use std::hash::Hash;
use std::time::Duration;
use std::vec::Drain;
/// A message sent to a [`MultiBufBatchChannel`] through its [`CommandSender`].
#[derive(Debug)]
pub enum Command<K, I>
where
    K: Debug + Ord + Hash,
    I: Debug,
{
    /// Buffers the item under the given key.
    Append(K, I),
    /// Asks the consumer to emit the buffer for the given key right away, if it has one.
    Flush(K),
}
/// Sending half of the channel that feeds a [`MultiBufBatchChannel`]; created by
/// [`MultiBufBatchChannel::new`].
pub type CommandSender<K, I> = Sender<Command<K, I>>;
/// Consumer side of a command channel: applies incoming [`Command`]s to a
/// [`MultiBufBatch`] and hands out per-key batches from [`MultiBufBatchChannel::next`].
#[derive(Debug)]
pub struct MultiBufBatchChannel<K, I>
where
    K: Debug + Ord + Hash,
    I: Debug,
{
    /// Receiving half of the command channel.
    channel: Receiver<Command<K, I>>,
    /// Per-key buffers that `Append`/`Flush` commands are applied to.
    batch: MultiBufBatch<K, I>,
    /// Keys still to be emitted after the channel disconnected; `None` until then.
    flush: Option<std::vec::IntoIter<K>>,
}
impl<K, I> MultiBufBatchChannel<K, I>
where
K: Debug + Ord + Hash + Send + Clone + 'static,
I: Debug + Send + 'static,
{
pub fn new(
max_size: usize,
max_duration: Duration,
channel_size: usize,
) -> (CommandSender<K, I>, MultiBufBatchChannel<K, I>) {
let (sender, receiver) = crossbeam_channel::bounded(channel_size);
(
sender,
MultiBufBatchChannel {
channel: receiver,
batch: MultiBufBatch::new(max_size, max_duration),
flush: None,
},
)
}
pub fn with_producer_thread(
max_size: usize,
max_duration: Duration,
channel_size: usize,
producer: impl FnOnce(CommandSender<K, I>) -> () + Send + 'static,
) -> MultiBufBatchChannel<K, I> {
let (sender, batch) = MultiBufBatchChannel::new(max_size, max_duration, channel_size);
std::thread::Builder::new().name("MultiBufBatchChannel producer".to_string()).spawn(move || producer(sender)).expect("failed to start producer thread");
batch
}
pub fn next<'i>(&'i mut self) -> Result<(K, Drain<I>), EndOfStreamError> {
loop {
if self.flush.is_some() {
let keys = self.flush.as_mut().unwrap();
if let Some(key) = keys.next() {
let batch = self.drain(&key).expect("flushing key that does not exist");
return Ok((key, batch));
}
return Err(EndOfStreamError);
}
let ready_after = match self.batch.poll() {
PollResult::Ready(key) => {
let batch = self.batch.drain(&key).expect("ready key not found");
return Ok((key, batch));
}
PollResult::NotReady(ready_after) => ready_after,
};
let recv_result = if let Some(ready_after) = ready_after {
match self.channel.recv_timeout(ready_after) {
Ok(item) => Ok(item),
Err(RecvTimeoutError::Timeout) => continue,
Err(RecvTimeoutError::Disconnected) => Err(EndOfStreamError),
}
} else {
self.channel.recv().map_err(|_| EndOfStreamError)
};
match recv_result {
Ok(Command::Flush(key)) => {
if self.batch.get(&key).is_some() {
let batch = self.batch.drain(&key).unwrap();
return Ok((key, batch));
}
continue;
}
Ok(Command::Append(key, item)) => {
self.batch.append(key, item);
continue;
}
Err(_eos) => {
let keys: Vec<K> = self.batch.outstanding().cloned().collect();
self.batch.clear_cache();
self.flush = Some(keys.into_iter());
continue;
}
}
}
}
pub fn outstanding(&self) -> impl Iterator<Item = &K> {
self.batch.outstanding()
}
pub fn clear(&mut self, key: &K) {
self.batch.clear(key)
}
pub fn drain(&mut self, key: &K) -> Option<Drain<I>> {
self.batch.drain(key)
}
pub fn flush(&mut self) -> Vec<(K, Vec<I>)> {
self.batch.flush()
}
pub fn get(&self, key: &K) -> Option<&[I]> {
self.batch.get(key)
}
pub fn clear_cache(&mut self) {
self.batch.clear_cache()
}
pub fn stats(&self) -> Stats {
self.batch.stats()
}
pub fn split(self) -> (MultiBufBatch<K, I>, Receiver<Command<K, I>>) {
(self.batch, self.channel)
}
}
#[cfg(test)]
mod tests {
    use super::Command::*;
    use super::*;
    use assert_matches::assert_matches;
    use std::time::Duration;

    /// Pulls the next batch from `channel` and checks both its key and its items.
    fn expect_batch(channel: &mut MultiBufBatchChannel<i32, i32>, key: i32, items: &[i32]) {
        assert_matches!(channel.next(), Ok((got_key, drain)) => {
            assert_eq!(got_key, key);
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), items);
        });
    }

    #[test]
    fn test_batch_max_size() {
        let (sender, mut channel) = MultiBufBatchChannel::new(4, Duration::from_secs(10), 20);
        for item in 1..=5 {
            sender.send(Append(0, item)).unwrap();
        }
        expect_batch(&mut channel, 0, &[1, 2, 3, 4]);

        for item in 1..=5 {
            sender.send(Append(1, item)).unwrap();
        }
        expect_batch(&mut channel, 1, &[1, 2, 3, 4]);

        // Interleave the two keys; key 1 fills up first.
        for item in 6..=8 {
            sender.send(Append(1, item)).unwrap();
            sender.send(Append(0, item)).unwrap();
        }
        expect_batch(&mut channel, 1, &[5, 6, 7, 8]);
        expect_batch(&mut channel, 0, &[5, 6, 7, 8]);
    }

    #[test]
    fn test_batch_with_producer_thread() {
        let mut channel =
            MultiBufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 20, |sender| {
                for item in 1..=2 {
                    sender.send(Append(0, item)).unwrap();
                    sender.send(Append(1, item)).unwrap();
                }
            });
        expect_batch(&mut channel, 0, &[1, 2]);
        expect_batch(&mut channel, 1, &[1, 2]);
    }

    #[test]
    fn test_batch_max_duration() {
        let mut channel = MultiBufBatchChannel::with_producer_thread(
            2,
            Duration::from_millis(100),
            10,
            |sender| {
                sender.send(Append(0, 1)).unwrap();
                std::thread::sleep(Duration::from_millis(500));
                sender.send(Append(0, 2)).unwrap();
            },
        );
        expect_batch(&mut channel, 0, &[1]);
        expect_batch(&mut channel, 0, &[2]);
    }

    #[test]
    fn test_batch_disconnected() {
        let (sender, mut channel) = MultiBufBatchChannel::new(2, Duration::from_secs(10), 20);
        for item in 1..=3 {
            sender.send(Append(0, item)).unwrap();
            sender.send(Append(1, item)).unwrap();
        }
        drop(sender);
        expect_batch(&mut channel, 0, &[1, 2]);
        expect_batch(&mut channel, 1, &[1, 2]);
        // Leftovers are flushed once the channel is disconnected.
        expect_batch(&mut channel, 0, &[3]);
        expect_batch(&mut channel, 1, &[3]);
        assert_matches!(channel.next(), Err(EndOfStreamError));
    }

    #[test]
    fn test_batch_drain() {
        let (sender, mut channel) = MultiBufBatchChannel::new(2, Duration::from_secs(10), 20);
        sender.send(Append(0, 1)).unwrap();
        sender.send(Append(1, 1)).unwrap();
        sender.send(Flush(0)).unwrap();
        for item in 2..=3 {
            sender.send(Append(0, item)).unwrap();
            sender.send(Append(1, item)).unwrap();
        }
        sender.send(Flush(1)).unwrap();
        expect_batch(&mut channel, 0, &[1]);
        expect_batch(&mut channel, 1, &[1, 2]);
        expect_batch(&mut channel, 0, &[2, 3]);
        expect_batch(&mut channel, 1, &[3]);
    }
}