use crate::buf_batch::{BufBatch, PollResult};
use crate::channel::EndOfStreamError;
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
use std::fmt::Debug;
use std::time::Duration;
use std::vec::Drain;
/// A message sent by a producer to a [`BufBatchChannel`].
#[derive(Debug)]
pub enum Command<I>
where
    I: Debug,
{
    /// Adds an item to the batch currently being assembled.
    Append(I),
    /// Asks the consumer to release the current batch right away, regardless of its size.
    Flush,
}
pub type CommandSender<I> = Sender<Command<I>>;
/// Consumer side of a batching channel: it receives [`Command`]s from a bounded
/// channel and groups the appended items into batches held in a [`BufBatch`].
#[derive(Debug)]
pub struct BufBatchChannel<I>
where
    I: Debug,
{
    /// Receiving end fed by the producers' [`CommandSender`]s.
    channel: Receiver<Command<I>>,
    /// The batch currently being assembled.
    batch: BufBatch<I>,
    /// Set once the channel has reported that all senders are gone.
    disconnected: bool,
}
impl<I: Debug> BufBatchChannel<I> {
    /// Creates a batching channel whose bounded channel buffers up to
    /// `channel_size` pending commands.
    ///
    /// `max_size` and `max_duration` are forwarded to [`BufBatch::new`] and limit
    /// how many items a batch may hold and how long it may stay open.
    ///
    /// Returns the sending half for producers together with the consumer.
    pub fn new(
        max_size: usize,
        max_duration: Duration,
        channel_size: usize,
    ) -> (CommandSender<I>, BufBatchChannel<I>) {
        let (sender, receiver) = crossbeam_channel::bounded(channel_size);
        (
            sender,
            BufBatchChannel {
                channel: receiver,
                batch: BufBatch::new(max_size, max_duration),
                disconnected: false,
            },
        )
    }

    /// Like [`BufBatchChannel::new`], but also spawns a thread named
    /// `"BufBatchChannel producer"` that runs `producer` with the sending half.
    ///
    /// The sender is moved into `producer`. Unless `producer` hands a clone
    /// elsewhere, the channel disconnects when `producer` returns. The thread is
    /// detached; its join handle is not kept.
    ///
    /// # Panics
    ///
    /// Panics if the producer thread cannot be spawned.
    pub fn with_producer_thread(
        max_size: usize,
        max_duration: Duration,
        channel_size: usize,
        producer: impl FnOnce(CommandSender<I>) + Send + 'static,
    ) -> BufBatchChannel<I>
    where
        I: Send + 'static,
    {
        let (sender, batch) = BufBatchChannel::new(max_size, max_duration, channel_size);
        std::thread::Builder::new()
            .name("BufBatchChannel producer".to_string())
            .spawn(move || producer(sender))
            .expect("failed to start producer thread");
        batch
    }

    /// Blocks until the next batch is complete and returns a draining iterator over it.
    ///
    /// A batch is handed out in three cases:
    /// - [`BufBatch::poll`] reports it as ready.
    /// - A producer sends [`Command::Flush`].
    /// - All senders have been dropped.
    ///
    /// In the last two cases the batch may be empty.
    ///
    /// # Errors
    ///
    /// Returns [`EndOfStreamError`] on every call after the one that observed the
    /// disconnection. That call itself still returns the outstanding items as `Ok`.
    pub fn next(&mut self) -> Result<Drain<'_, I>, EndOfStreamError> {
        if self.disconnected {
            return Err(EndOfStreamError);
        }

        loop {
            // Release the batch if it already hit a limit. Otherwise `poll` tells us
            // how long we may wait for more items; `None` means no deadline applies.
            let ready_after = match self.batch.poll() {
                PollResult::Ready => return Ok(self.batch.drain()),
                PollResult::NotReady(ready_after) => ready_after,
            };

            let recv_result = match ready_after {
                Some(timeout) => match self.channel.recv_timeout(timeout) {
                    Ok(command) => Ok(command),
                    // Deadline reached: loop again so `poll` can report the batch ready.
                    Err(RecvTimeoutError::Timeout) => continue,
                    Err(RecvTimeoutError::Disconnected) => Err(EndOfStreamError),
                },
                None => self.channel.recv().map_err(|_| EndOfStreamError),
            };

            match recv_result {
                // The producer asked for the current batch to be released now.
                Ok(Command::Flush) => return Ok(self.batch.drain()),
                Ok(Command::Append(item)) => {
                    self.batch.append(item);
                }
                // All senders are gone. Hand out whatever is left now, and report
                // end of stream on every later call.
                Err(_) => {
                    self.disconnected = true;
                    return Ok(self.batch.drain());
                }
            }
        }
    }

    /// Returns `true` once [`next`](Self::next) has observed that all senders were dropped.
    pub fn is_disconnected(&self) -> bool {
        self.disconnected
    }

    /// Clears the outstanding batch via [`BufBatch::clear`].
    pub fn clear(&mut self) {
        self.batch.clear()
    }

    /// Borrows the items of the outstanding batch (see [`BufBatch::as_slice`]).
    pub fn as_slice(&self) -> &[I] {
        self.batch.as_slice()
    }

    /// Drains the outstanding batch immediately, without waiting for it to become ready.
    pub fn drain(&mut self) -> Drain<'_, I> {
        self.batch.drain()
    }

    /// Consumes `self` and returns the outstanding batch's items.
    ///
    /// Commands still queued in the channel are not included.
    pub fn into_vec(self) -> Vec<I> {
        self.batch.into_vec()
    }

    /// Consumes `self` and returns an iterator that yields the buffered items and
    /// then every item still arriving on the channel, until all senders are dropped.
    ///
    /// [`Command::Flush`] markers are skipped and batching limits no longer apply.
    /// Iteration blocks while the channel is empty but still connected.
    pub fn drain_to_end(self) -> DrainToEnd<I> {
        let (buffer, channel) = self.split();
        DrainToEnd(buffer.into_vec().into_iter(), channel)
    }

    /// Consumes `self` and returns the outstanding [`BufBatch`] together with the
    /// receiving end of the channel.
    pub fn split(self) -> (BufBatch<I>, Receiver<Command<I>>) {
        (self.batch, self.channel)
    }
}
/// Iterator returned by [`BufBatchChannel::drain_to_end`].
///
/// It first yields the items that were already buffered, then keeps receiving
/// from the channel until it disconnects.
#[derive(Debug)]
pub struct DrainToEnd<I>(std::vec::IntoIter<I>, Receiver<Command<I>>)
where
    I: Debug;
impl<I> Iterator for DrainToEnd<I>
where
    I: Debug,
{
    type Item = I;

    /// Yields the next buffered item, or else blocks on the channel for the next
    /// appended item.
    ///
    /// Returns `None` once the channel is disconnected.
    fn next(&mut self) -> Option<I> {
        if let Some(buffered) = self.0.next() {
            return Some(buffered);
        }

        // The buffer is exhausted, so read straight from the channel. Flush markers
        // carry no item and are skipped.
        loop {
            match self.1.recv() {
                Ok(Command::Append(item)) => break Some(item),
                Ok(Command::Flush) => continue,
                Err(_) => break None,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::time::Duration;

    #[test]
    fn test_batch_max_size() {
        let (sender, mut batch) = BufBatchChannel::new(2, Duration::from_secs(10), 10);
        sender.send(Command::Append(1)).unwrap();
        sender.send(Command::Append(2)).unwrap();
        sender.send(Command::Append(3)).unwrap();

        // Only `max_size` items fit; the third one waits for the next batch.
        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [1, 2])
        );
    }

    #[test]
    fn test_batch_with_producer_thread() {
        let mut batch =
            BufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
                sender.send(Command::Append(1)).unwrap();
                sender.send(Command::Append(2)).unwrap();
                sender.send(Command::Append(3)).unwrap();
                sender.send(Command::Append(4)).unwrap();
            });

        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [1, 2])
        );
        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [3, 4])
        );
    }

    #[test]
    fn test_batch_max_duration() {
        let mut batch =
            BufBatchChannel::with_producer_thread(2, Duration::from_millis(100), 10, |sender| {
                sender.send(Command::Append(1)).unwrap();
                // Keep the sender alive past `max_duration`, so the batch is released
                // by the time limit rather than by disconnection.
                std::thread::sleep(Duration::from_millis(500));
            });

        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [1])
        );
        assert!(!batch.is_disconnected());
    }

    #[test]
    fn test_batch_disconnected() {
        let mut batch =
            BufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
                sender.send(Command::Append(1)).unwrap();
            });

        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [1])
        );
        assert_matches!(batch.next(), Err(EndOfStreamError));
    }

    #[test]
    fn test_batch_disconnected_empty() {
        let (sender, mut batch) = BufBatchChannel::<i32>::new(2, Duration::from_secs(10), 10);
        drop(sender);

        // The call that observes the disconnection hands out the (empty) outstanding
        // batch, and every later call reports end of stream.
        assert_matches!(batch.next(), Ok(drain) => assert_eq!(drain.count(), 0));
        assert!(batch.is_disconnected());
        assert_matches!(batch.next(), Err(EndOfStreamError));
    }

    #[test]
    fn test_batch_command_complete() {
        let mut batch =
            BufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
                sender.send(Command::Append(1)).unwrap();
                sender.send(Command::Flush).unwrap();
                sender.send(Command::Append(2)).unwrap();
                sender.send(Command::Flush).unwrap();
            });

        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [1])
        );
        assert_matches!(batch.next(), Ok(drain) =>
            assert_eq!(drain.collect::<Vec<_>>().as_slice(), [2])
        );
    }

    #[test]
    fn test_drain_to_end() {
        let (sender, batch) = BufBatchChannel::new(4, Duration::from_secs(10), 10);
        sender.send(Command::Append(1)).unwrap();
        sender.send(Command::Flush).unwrap();
        sender.send(Command::Append(2)).unwrap();
        sender.send(Command::Flush).unwrap();
        sender.send(Command::Append(3)).unwrap();
        sender.send(Command::Append(4)).unwrap();
        sender.send(Command::Append(5)).unwrap();
        drop(sender);

        assert_eq!(
            batch.drain_to_end().collect::<Vec<_>>().as_slice(),
            [1, 2, 3, 4, 5]
        );
    }
}