use crate::buf_batch::{BufBatch, PollResult};
use crate::channel::EndOfStreamError;
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
use std::fmt::Debug;
use std::time::Duration;
use std::vec::Drain;
/// A message sent by a producer to a [`TxBufBatchChannel`].
#[derive(Debug)]
pub enum Command<I: Debug> {
/// Append an item to the current batch.
Append(I),
/// End the current batch: when received, `TxBufBatchChannel::next` returns
/// `TxBufBatchChannelResult::Complete` regardless of the batch size.
Flush,
}
/// Sending half of the (bounded, crossbeam) command channel handed to producers by
/// `TxBufBatchChannel::new`; dropping every sender ends the stream.
pub type CommandSender<I> = Sender<Command<I>>;
/// Handle to a finished batch, carried by `TxBufBatchChannelResult::Complete` and
/// `TxBufBatchChannelResult::BufferedComplete`. It mutably borrows the channel so the caller can
/// commit, drain or retry the batch before asking for the next event.
#[derive(Debug)]
pub struct Complete<'i, I: Debug>(&'i mut TxBufBatchChannel<I>);
impl<'i, I: Debug> Complete<'i, I> {
pub fn retry(&mut self) {
self.0.retry()
}
pub fn commit(&mut self) {
self.0.clear()
}
pub fn drain(&mut self) -> Drain<I> {
self.0.drain()
}
}
/// Event produced by `TxBufBatchChannel::next`.
#[derive(Debug)]
pub enum TxBufBatchChannelResult<'i, I: Debug> {
/// An item of the current batch: either freshly received from the channel or re-yielded
/// while replaying after `retry()`.
Item(&'i I),
/// The current batch has ended (the buffer reported ready, a `Command::Flush` arrived, or the
/// channel disconnected); use the handle to commit, drain or retry it.
Complete(Complete<'i, I>),
/// A replay started by `retry()` has re-yielded every buffered item; later `next()` calls
/// continue reading from the channel.
BufferedComplete(Complete<'i, I>),
}
/// Consumer side of a batching channel. It receives `Command`s, buffers appended items into the
/// current batch, and lets the caller either commit the batch or replay (`retry`) it.
#[derive(Debug)]
pub struct TxBufBatchChannel<I: Debug> {
// Receiving end of the command channel fed by producers.
channel: Receiver<Command<I>>,
// Items of the current, not yet committed, batch.
batch: BufBatch<I>,
// `Some(n)` while replaying after `retry()`: the last `n` items of `batch` still have to be
// re-yielded by `next()`. `None` when not replaying.
retry: Option<usize>,
// Set once the channel reported disconnection; after any pending replay, `next()` then
// returns `EndOfStreamError`.
disconnected: bool,
}
impl<I: Debug> TxBufBatchChannel<I> {
    /// Creates a batching consumer together with the sender producers use to feed it.
    ///
    /// `max_size` and `max_duration` are passed to `BufBatch::new` and govern when
    /// [`next`](Self::next) reports the batch as complete. Per the tests, that happens after
    /// `max_size` items, or once `max_duration` elapsed with a partial batch. `channel_size` is
    /// the capacity of the bounded command channel.
    pub fn new(
        max_size: usize,
        max_duration: Duration,
        channel_size: usize,
    ) -> (CommandSender<I>, TxBufBatchChannel<I>) {
        let (sender, receiver) = crossbeam_channel::bounded(channel_size);
        let consumer = TxBufBatchChannel {
            channel: receiver,
            batch: BufBatch::new(max_size, max_duration),
            retry: None,
            disconnected: false,
        };
        (sender, consumer)
    }

    /// Like [`new`](Self::new), but runs `producer` on a newly spawned, named thread and hands it
    /// the sender. When `producer` returns, its sender is dropped. Unless clones of it were kept
    /// elsewhere, that ends the stream.
    ///
    /// # Panics
    ///
    /// Panics if the producer thread cannot be spawned.
    pub fn with_producer_thread(
        max_size: usize,
        max_duration: Duration,
        channel_size: usize,
        producer: impl FnOnce(CommandSender<I>) + Send + 'static,
    ) -> TxBufBatchChannel<I>
    where
        I: Send + 'static,
    {
        let (sender, batch) = TxBufBatchChannel::new(max_size, max_duration, channel_size);
        std::thread::Builder::new()
            .name("TxBufBatchChannel producer".to_string())
            .spawn(move || producer(sender))
            .expect("failed to start producer thread");
        batch
    }

    /// Returns the next event of the batching stream.
    ///
    /// * While replaying after [`retry`](Self::retry), it yields each buffered item as
    ///   [`TxBufBatchChannelResult::Item`], followed by one
    ///   [`TxBufBatchChannelResult::BufferedComplete`].
    /// * Otherwise, it appends the next received item to the batch and yields it as `Item`.
    ///   It yields [`TxBufBatchChannelResult::Complete`] instead when:
    ///   - the batch reports it is ready,
    ///   - a [`Command::Flush`] arrives, or
    ///   - the channel disconnects.
    ///
    /// # Errors
    ///
    /// Returns [`EndOfStreamError`] once the disconnect has already been reported through a
    /// `Complete` event and no replay is pending.
    pub fn next(&mut self) -> Result<TxBufBatchChannelResult<I>, EndOfStreamError> {
        // A pending replay takes priority over both the channel and the disconnect flag.
        if let Some(retry) = self.retry {
            if retry == 0 {
                self.retry = None;
                return Ok(TxBufBatchChannelResult::BufferedComplete(Complete(self)));
            }
            // The `retry` items still to be replayed are the tail of the buffer.
            let item = &self.batch.as_slice()[self.batch.as_slice().len() - retry];
            self.retry = Some(retry - 1);
            return Ok(TxBufBatchChannelResult::Item(item));
        }
        if self.disconnected {
            return Err(EndOfStreamError);
        }
        loop {
            let ready_after = match self.batch.poll() {
                PollResult::Ready => return Ok(TxBufBatchChannelResult::Complete(Complete(self))),
                PollResult::NotReady(ready_after) => ready_after,
            };
            // With a deadline, wait at most that long and re-poll on timeout. Without one
            // (presumably an empty batch), block until a command arrives.
            let recv_result = if let Some(ready_after) = ready_after {
                match self.channel.recv_timeout(ready_after) {
                    Ok(item) => Ok(item),
                    Err(RecvTimeoutError::Timeout) => continue,
                    Err(RecvTimeoutError::Disconnected) => Err(EndOfStreamError),
                }
            } else {
                self.channel.recv().map_err(|_| EndOfStreamError)
            };
            match recv_result {
                Ok(Command::Append(item)) => {
                    let item = self.batch.append(item);
                    return Ok(TxBufBatchChannelResult::Item(item));
                }
                Ok(Command::Flush) => {
                    return Ok(TxBufBatchChannelResult::Complete(Complete(self)));
                }
                Err(_eos) => {
                    // Report whatever is buffered as a final batch. The next call (after any
                    // replay) returns `EndOfStreamError`.
                    self.disconnected = true;
                    return Ok(TxBufBatchChannelResult::Complete(Complete(self)));
                }
            }
        }
    }

    /// Removes and returns the most recently yielded item of the current batch.
    ///
    /// During a replay, only items that were already re-yielded can be popped. It returns `None`
    /// when none has been re-yielded yet.
    pub fn pop(&mut self) -> Option<I> {
        if let Some(retry) = self.retry {
            if retry == self.batch.as_slice().len() {
                return None;
            }
            // The last re-yielded item sits just before the `retry` items not yet replayed.
            return Some(self.batch.remove(self.batch.as_slice().len() - (retry + 1)));
        }
        self.batch.pop()
    }

    /// Returns `true` once the command channel has reported disconnection.
    pub fn is_disconnected(&self) -> bool {
        self.disconnected
    }

    /// Rewinds the stream so the following [`next`](Self::next) calls re-yield every buffered
    /// item before reading further from the channel.
    pub fn retry(&mut self) {
        self.retry = Some(self.as_slice().len());
    }

    /// Discards every buffered item of the current batch.
    ///
    /// This also cancels any pending replay. The replay cursor indexes into the buffer, so
    /// leaving it set after emptying the buffer made `next`/`pop` underflow and panic.
    pub fn clear(&mut self) {
        self.retry = None;
        self.batch.clear()
    }

    /// Removes the buffered items of the current batch and returns them by value.
    ///
    /// Like [`clear`](Self::clear), this cancels any pending replay, because the cursor would
    /// otherwise point past the emptied buffer.
    pub fn drain(&mut self) -> Drain<I> {
        self.retry = None;
        self.batch.drain()
    }

    /// Borrows all buffered items of the current batch, including ones not yet replayed.
    pub fn as_slice(&self) -> &[I] {
        self.batch.as_slice()
    }

    /// Number of buffered items, including ones not yet replayed.
    pub fn len(&self) -> usize {
        self.batch.as_slice().len()
    }

    /// Returns `true` if no items are buffered.
    pub fn is_empty(&self) -> bool {
        self.batch.as_slice().is_empty()
    }

    /// Number of items of the current batch yielded so far. During a replay only the items
    /// already re-yielded count; otherwise every buffered item counts.
    pub fn batch_len(&self) -> usize {
        if let Some(retry) = self.retry {
            self.batch.as_slice().len() - retry
        } else {
            self.len()
        }
    }

    /// Consumes the consumer and returns the buffered items. The receiver is dropped.
    pub fn into_vec(self) -> Vec<I> {
        self.batch.into_vec()
    }

    /// Consumes the consumer and returns an iterator over the buffered items. After those it
    /// yields every item still appended through the channel, skipping `Flush` commands and
    /// blocking until all senders are gone.
    pub fn drain_to_end(self) -> DrainToEnd<I> {
        let (buffer, channel) = self.split();
        DrainToEnd(buffer.into_vec().into_iter(), channel)
    }

    /// Consumes the consumer and returns its buffer and command receiver. The replay cursor
    /// and the disconnect flag are discarded.
    pub fn split(self) -> (BufBatch<I>, Receiver<Command<I>>) {
        (self.batch, self.channel)
    }
}
/// Iterator returned by `TxBufBatchChannel::drain_to_end`. It yields the items that were
/// buffered (field 0), then items still appended through the command receiver (field 1).
#[derive(Debug)]
pub struct DrainToEnd<I: Debug>(std::vec::IntoIter<I>, Receiver<Command<I>>);
impl<I: Debug> Iterator for DrainToEnd<I> {
type Item = I;
fn next(&mut self) -> Option<I> {
self.0.next().or_else(|| {
loop {
match self.1.recv() {
Ok(Command::Append(i)) => return Some(i),
Ok(Command::Flush) => (),
Err(_) => return None
}
}
})
}
}
// Unit tests for `TxBufBatchChannel`. Some pre-fill the bounded channel before consuming;
// others feed it from a producer thread via `with_producer_thread`.
#[cfg(test)]
mod tests {
pub use super::*;
use assert_matches::assert_matches;
use std::time::Duration;
// Replays via `retry` at several points: mid-batch, right after a replay started, and after
// the batch completed. With max_size = 4, the 4th item makes the batch ready.
#[test]
fn test_batch_retry() {
let (sender, mut batch) = TxBufBatchChannel::new(4, Duration::from_secs(10), 10);
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Append(3)).unwrap();
sender.send(Command::Append(4)).unwrap();
sender.send(Command::Append(5)).unwrap();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
// End of the replay, then the 4th item is read from the channel and fills the batch.
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_))); assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) => complete.retry());
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_)));
// The buffer is still full, so it is immediately reported complete again.
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(_)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) => complete.commit());
// Retrying an empty (committed) batch yields only BufferedComplete.
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_)));
sender.send(Command::Append(5)).unwrap();
// The Append(5) sent at the start is still queued ahead of this one, so that one is received.
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(5)));
}
// Clearing then retrying gives an empty replay. The producer thread sleeps and then exits, so
// the consumer ends with a Complete followed by EndOfStreamError.
#[test]
fn test_batch_empty() {
let mut batch =
TxBufBatchChannel::with_producer_thread(2, Duration::from_millis(100), 10, |sender| {
sender.send(Command::Append(1)).unwrap();
std::thread::sleep(Duration::from_millis(500));
});
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
batch.clear();
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_))); assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(_))); assert_matches!(batch.next(), Err(EndOfStreamError));
}
// max_size = 2: the batch completes after two items. `clear` commits it, and a replayed item
// can be committed from the BufferedComplete handle.
#[test]
fn test_batch_commit() {
let (sender, mut batch) = TxBufBatchChannel::new(2, Duration::from_secs(10), 10);
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Append(3)).unwrap();
sender.send(Command::Append(4)).unwrap();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(
batch.next(),
Ok(TxBufBatchChannelResult::Complete(_complete))
);
batch.clear();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(mut complete)) => complete.commit());
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) =>
assert_eq!(complete.drain().collect::<Vec<_>>().as_slice(), [4])
); }
// Same scenario as test_batch_commit, with items coming from a producer thread.
#[test]
fn test_batch_with_producer_thread() {
let mut batch =
TxBufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Append(3)).unwrap();
sender.send(Command::Append(4)).unwrap();
});
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(
batch.next(),
Ok(TxBufBatchChannelResult::Complete(_complete))
);
batch.clear();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(mut complete)) => complete.commit());
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) =>
assert_eq!(complete.drain().collect::<Vec<_>>().as_slice(), [4])
); }
// A partial batch (1 of 2 items) is expected to complete once max_duration (100ms) passes,
// while the producer is still alive (sleeping 500ms).
#[test]
fn test_batch_max_duration() {
let mut batch =
TxBufBatchChannel::with_producer_thread(2, Duration::from_millis(100), 10, |sender| {
sender.send(Command::Append(1)).unwrap();
std::thread::sleep(Duration::from_millis(500));
});
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) =>
assert_eq!(complete.drain().collect::<Vec<_>>().as_slice(), [1])
); }
// The producer exits after one item: disconnection completes the partial batch, then
// next() reports EndOfStreamError.
#[test]
fn test_batch_disconnected() {
let mut batch =
TxBufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
sender.send(Command::Append(1)).unwrap();
});
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) =>
assert_eq!(complete.drain().collect::<Vec<_>>().as_slice(), [1])
); assert_matches!(batch.next(), Err(EndOfStreamError));
}
// Command::Flush completes the batch before max_size is reached.
#[test]
fn test_batch_command_complete() {
let mut batch =
TxBufBatchChannel::with_producer_thread(2, Duration::from_secs(10), 10, |sender| {
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Flush).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Flush).unwrap();
});
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(
batch.next(),
Ok(TxBufBatchChannelResult::Complete(_complete))
);
batch.clear();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Complete(mut complete)) =>
assert_eq!(complete.drain().collect::<Vec<_>>().as_slice(), [2])
); }
// drain_to_end yields every buffered item (even mid-replay) followed by the rest of the
// channel; the sender is dropped so the iterator terminates.
#[test]
fn test_drain_to_end() {
let (sender, mut batch) = TxBufBatchChannel::new(4, Duration::from_secs(10), 10);
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Append(3)).unwrap();
sender.send(Command::Append(4)).unwrap();
sender.send(Command::Append(5)).unwrap();
drop(sender);
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_eq!(batch.drain_to_end().collect::<Vec<_>>().as_slice(), [1, 2, 3, 4, 5]);
}
// pop removes the most recently yielded item, in normal mode and during a replay, and
// batch_len tracks the number of yielded items.
#[test]
fn test_pop() {
let (sender, mut batch) = TxBufBatchChannel::new(40, Duration::from_secs(10), 10);
sender.send(Command::Append(1)).unwrap();
sender.send(Command::Append(2)).unwrap();
sender.send(Command::Append(3)).unwrap();
sender.send(Command::Append(4)).unwrap();
sender.send(Command::Append(5)).unwrap();
sender.send(Command::Append(6)).unwrap();
sender.send(Command::Append(7)).unwrap();
drop(sender);
assert_eq!(batch.batch_len(), 0);
assert_matches!(batch.pop(), None);
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(2)));
assert_eq!(batch.batch_len(), 2);
assert_matches!(batch.pop(), Some(2));
assert_eq!(batch.batch_len(), 1);
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(5)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(6)));
assert_eq!(batch.batch_len(), 5);
assert_matches!(batch.pop(), Some(6));
assert_matches!(batch.pop(), Some(5));
assert_eq!(batch.batch_len(), 3);
// In replay mode nothing has been re-yielded yet, so there is nothing to pop.
batch.retry();
assert_eq!(batch.batch_len(), 0);
assert_matches!(batch.pop(), None);
assert_eq!(batch.batch_len(), 0);
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(1)));
assert_eq!(batch.batch_len(), 1);
assert_matches!(batch.pop(), Some(1));
assert_eq!(batch.batch_len(), 0);
assert_matches!(batch.pop(), None);
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_eq!(batch.batch_len(), 2);
batch.retry();
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(3)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(4)));
assert_matches!(batch.pop(), Some(4));
assert_matches!(batch.pop(), Some(3));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::BufferedComplete(_)));
assert_matches!(batch.next(), Ok(TxBufBatchChannelResult::Item(7)));
assert_eq!(batch.batch_len(), 1);
}
}