use carboncopy::{BoxFuture, Sink};
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{stdout, AsyncWriteExt, Stdout};
use tokio::runtime::Runtime;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender as DropTx};
use tokio::sync::watch::{channel as watch_channel, Receiver as WatchRx};
use tokio::sync::Mutex;
use tokio::time::sleep;
/// A [`Sink`] that writes log entries to an async writer, optionally through an in-memory
/// buffer that a background task flushes periodically.
pub struct BufSink<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Runtime that `sink_blocking` blocks on and that hosts the background flush task.
    rt: Arc<Runtime>,
    /// Buffer and output writer, shared with the background flush task.
    interior: Arc<Mutex<Interior<T>>>,
    /// Used by `Drop` to signal the background flush task.
    drop_chan_tx: DropTx<EmptySignal>,
    /// Carries the outcome of the most recent background flush.
    last_flush_err_chan_rx: WatchRx<Option<Arc<std::io::Error>>>,
}
impl<T: AsyncWriteExt + Unpin + Send + 'static> Sink for BufSink<T> {
    /// Synchronous wrapper around [`Sink::sink`] that drives the future to completion on the
    /// sink's own runtime.
    ///
    /// # Panics
    ///
    /// `Runtime::block_on` panics when called from within an asynchronous execution context, so
    /// this must not be called from code already running on a Tokio runtime.
    fn sink_blocking(&self, entry: String) -> std::io::Result<()> {
        self.rt.block_on(self.sink(entry))
    }

    /// Records one log entry.
    ///
    /// In buffered mode the entry's bytes are appended to the in-memory buffer, which the
    /// background task flushes later; this path never fails. In unbuffered mode the whole entry
    /// is written straight to the output writer.
    ///
    /// # Errors
    ///
    /// In unbuffered mode, returns the I/O error reported by the output writer.
    fn sink(&self, entry: String) -> BoxFuture<std::io::Result<()>> {
        Box::pin(async move {
            let mut guard = self.interior.lock().await;
            // Reborrow once so the `buf` and `output_writer` fields can be borrowed disjointly.
            let inner = &mut *guard;
            match inner.buf.as_mut() {
                Some((buf, _)) => buf.extend_from_slice(entry.as_bytes()),
                // `write` may accept only part of the slice; `write_all` keeps writing until the
                // whole entry has been handed to the writer, so nothing is silently dropped.
                None => inner.output_writer.write_all(entry.as_bytes()).await?,
            }
            Ok(())
        })
    }
}
impl<T> Drop for BufSink<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Signals the background flush task over the drop channel.
    ///
    /// The send only fails when the receiving end has already been dropped, so the outcome is
    /// intentionally ignored.
    fn drop(&mut self) {
        self.drop_chan_tx.send(EmptySignal).ok();
    }
}
impl<T: AsyncWriteExt + Unpin + Send + 'static> BufSink<T> {
pub fn new(opts: SinkOptions<T>) -> Self {
let interior = Arc::new(Mutex::new(Interior {
backlogged: false,
buf: if opts.buffer.is_none() {
None
} else {
let cap = opts.buffer.as_ref().unwrap();
Some((Vec::with_capacity(cap.0), cap.0))
},
output_writer: opts.output_writer,
}));
let (drop_tx, mut drop_rx) = unbounded_channel();
let (err_tx, err_rx) = watch_channel(None);
let rt = opts.tokio_runtime.clone();
let interior_clone = interior.clone();
let timeout_ms = opts.flush_timeout_ms;
rt.spawn(async move {
if interior_clone.lock().await.buf.is_some() {
loop {
let overflow = async {
loop {
{
let interior_check = interior_clone.lock().await;
if interior_check.buf.as_ref().unwrap().0.len()
>= interior_check.buf.as_ref().unwrap().1
{
return;
}
}
if timeout_ms > 1 {
sleep(Duration::from_millis(1)).await;
}
}
};
let timeout = async move {
sleep(Duration::from_millis(timeout_ms)).await;
};
tokio::select! {
_ = overflow => {
if let Err(io_err) = interior_clone.lock().await.flush().await {
let _ = err_tx.send(Some(Arc::new(io_err)));
} else {
let _ = err_tx.send(None);
};
}
_ = timeout => {
if let Err(io_err) = interior_clone.lock().await.flush().await {
let _ = err_tx.send(Some(Arc::new(io_err)));
} else {
let _ = err_tx.send(None);
};
}
_ = drop_rx.recv() => {
return;
}
}
}
} else {
return;
}
});
Self {
rt: rt,
interior: interior,
drop_chan_tx: drop_tx,
last_flush_err_chan_rx: err_rx,
}
}
pub async fn flush(&self) -> std::io::Result<usize> {
self.interior.lock().await.flush().await
}
pub async fn backlogged(&self) -> bool {
self.interior.lock().await.backlogged()
}
pub fn last_flush_err(&self) -> Option<Arc<std::io::Error>> {
self.last_flush_err_chan_rx.borrow().clone()
}
}
/// Mutable state shared between a [`BufSink`] and its background flush task.
struct Interior<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Tracks whether a flush left bytes unwritten; maintained by `Interior::flush`.
    backlogged: bool,
    /// `Some((pending bytes, overflow threshold))` in buffered mode. `None` means entries are
    /// written straight to `output_writer`.
    buf: Option<(Vec<u8>, usize)>,
    /// Destination for all log output.
    output_writer: T,
}
impl<T: AsyncWriteExt + Unpin + Send + 'static> Interior<T> {
    /// Writes the buffered bytes to `output_writer` until the buffer is empty.
    ///
    /// Returns the number of bytes written by this call. Returns `Ok(0)` when the sink is
    /// unbuffered or the buffer is already empty.
    ///
    /// `backlogged` is set to `true` when the flush stops early, either on an I/O error or
    /// when the writer accepts zero bytes. In the zero-byte case, the bytes written so far are
    /// returned. `backlogged` is set to `false` once the buffer has been drained completely.
    ///
    /// Note: this only hands bytes to the writer; it does not call the writer's own `flush`.
    ///
    /// # Errors
    ///
    /// Returns the writer's I/O error unchanged. Bytes that were not written stay buffered, so
    /// a later flush can retry them.
    async fn flush(&mut self) -> Result<usize, std::io::Error> {
        let buf = match self.buf.as_mut() {
            Some((buf, _)) => buf,
            None => return Ok(0),
        };
        let mut written = 0;
        // Loop on the *current* buffer length so partial writes are retried with the
        // remaining bytes and the loop ends exactly when everything has been written.
        while !buf.is_empty() {
            match self.output_writer.write(buf.as_slice()).await {
                Ok(0) => {
                    // The writer cannot accept more right now; keep the rest for later.
                    self.backlogged = true;
                    return Ok(written);
                }
                Ok(delta) => {
                    // Drop the written prefix while keeping the allocation for reuse.
                    buf.drain(..delta);
                    written += delta;
                }
                Err(err) => {
                    self.backlogged = true;
                    return Err(err);
                }
            }
        }
        self.backlogged = false;
        Ok(written)
    }

    /// Returns the flag maintained by [`Interior::flush`].
    fn backlogged(&self) -> bool {
        self.backlogged
    }
}
/// Configuration consumed by [`BufSink::new`].
pub struct SinkOptions<T>
where
    T: AsyncWriteExt + Unpin + Send + 'static,
{
    /// Flush threshold for the in-memory buffer; `None` disables buffering.
    pub buffer: Option<BufferOverflowThreshold>,
    /// Milliseconds the background task waits after its previous flush attempt before flushing
    /// a buffer that has not reached its threshold.
    pub flush_timeout_ms: u64,
    /// Runtime that runs the background flush task and backs `sink_blocking`.
    pub tokio_runtime: Arc<Runtime>,
    /// Destination for log output.
    pub output_writer: T,
}
impl Default for SinkOptions<Stdout> {
    /// Buffered output to the process's standard output.
    ///
    /// Uses a 64 KiB overflow threshold, a 100 ms flush timeout, and a newly created Tokio
    /// runtime.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be created.
    fn default() -> Self {
        Self {
            buffer: Some(
                BufferOverflowThreshold::new(64 * 1024)
                    .expect("64 KiB lies within the allowed 1KB..=1GB threshold range"),
            ),
            flush_timeout_ms: 100,
            tokio_runtime: Arc::new(Runtime::new().expect("failed to create the Tokio runtime")),
            output_writer: stdout(),
        }
    }
}
/// Buffer size, in bytes, at which a buffered [`BufSink`] flushes.
///
/// Only values from 1024 bytes (1KB) up to 1024 * 1024 * 1024 bytes (1GB), inclusive, can be
/// constructed.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub struct BufferOverflowThreshold(usize);

impl BufferOverflowThreshold {
    /// Validates `cap` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns [`ThresholdError::LessThan1KB`] below 1024 bytes, and
    /// [`ThresholdError::MoreThan1GB`] above 1024 * 1024 * 1024 bytes.
    pub fn new(cap: usize) -> Result<Self, ThresholdError> {
        const ONE_KB: usize = 1024;
        const ONE_GB: usize = ONE_KB * ONE_KB * ONE_KB;
        match cap {
            too_small if too_small < ONE_KB => Err(ThresholdError::LessThan1KB),
            too_large if too_large > ONE_GB => Err(ThresholdError::MoreThan1GB),
            _ => Ok(Self(cap)),
        }
    }
}
/// Reason a [`BufferOverflowThreshold`] could not be constructed.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub enum ThresholdError {
    /// The requested threshold was below 1024 bytes.
    LessThan1KB,
    /// The requested threshold was above 1024 * 1024 * 1024 bytes.
    MoreThan1GB,
}

impl fmt::Display for ThresholdError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            Self::LessThan1KB => "buffer overflow threshold can't be less than 1024 bytes (1KB)",
            Self::MoreThan1GB => {
                "buffer overflow threshold can't be greater than 1024 * 1024 * 1024 bytes (1GB)"
            }
        };
        f.write_str(msg)
    }
}

/// Makes `ThresholdError` a standard error, so it can be propagated with `?` into
/// `Box<dyn std::error::Error>` and similar error types.
impl std::error::Error for ThresholdError {}
/// Payload-free message that `BufSink`'s `Drop` impl sends over the drop channel; the background
/// flush task stops when it receives it.
struct EmptySignal;
// Unit tests for the sink. Several of them are timing-based: they sleep for a few milliseconds
// and then inspect the in-memory writer, assuming the background flush task has (or has not)
// run by then.
#[cfg(test)]
mod tests {
use super::*;
// Thresholds just outside the 1KB..=1GB range are rejected with the matching error.
#[test]
fn overflow_threshold() {
assert_eq!(
BufferOverflowThreshold::new(1000).err().unwrap(),
ThresholdError::LessThan1KB
);
assert_eq!(
BufferOverflowThreshold::new(1024 * 1024 * 1024 + 1)
.err()
.unwrap(),
ThresholdError::MoreThan1GB
);
}
// Building the default options (runtime + stdout writer) must not panic.
#[test]
fn default_options_dont_panic() {
assert_eq!(100, SinkOptions::default().flush_timeout_ms);
}
// Without a buffer, every entry reaches the writer immediately, in order.
#[test]
fn no_buffer() {
let rt = Arc::new(Runtime::new().unwrap());
let opts = SinkOptions {
buffer: None,
flush_timeout_ms: 30,
tokio_runtime: rt.clone(),
output_writer: Vec::new(),
};
let mem_sink = Arc::new(BufSink::new(opts));
for i in 0..5 {
assert!(rt
.block_on(async {
mem_sink
.clone()
.sink(String::from(format!("hello world {}\n", i)))
.await
})
.is_ok());
}
let ref_output =
"hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n";
let output =
rt.block_on(async { mem_sink.clone().interior.lock().await.output_writer.clone() });
assert_eq!(ref_output, std::str::from_utf8(output.as_ref()).unwrap());
}
// Small entries stay buffered until the 30ms flush timeout elapses.
#[test]
fn timeout_flush() {
let rt = Arc::new(Runtime::new().unwrap());
let opts = SinkOptions {
buffer: Some(BufferOverflowThreshold::new(64 * 1024).unwrap()),
flush_timeout_ms: 30,
tokio_runtime: rt.clone(),
output_writer: Vec::new(),
};
let mem_sink = Arc::new(BufSink::new(opts));
for i in 0..5 {
assert!(rt
.block_on(async {
mem_sink
.clone()
.sink(String::from(format!("hello world {}\n", i)))
.await
})
.is_ok());
}
let ref_output =
"hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n";
// Read immediately, before the 30ms timeout can have fired.
let output_before_flush_timeout =
rt.block_on(async { mem_sink.clone().interior.lock().await.output_writer.clone() });
assert_ne!(
ref_output,
std::str::from_utf8(output_before_flush_timeout.as_ref()).unwrap()
);
// Wait past the 30ms timeout so the background task flushes the buffer.
rt.block_on(async {
sleep(Duration::from_millis(40)).await;
});
let output_after_flush_timeout =
rt.block_on(async { mem_sink.clone().interior.lock().await.output_writer.clone() });
assert_eq!(
ref_output,
std::str::from_utf8(output_after_flush_timeout.as_ref()).unwrap()
);
}
// Filling the 1KB buffer triggers a flush well before the 30ms timeout.
#[test]
fn overflow_flush() {
let rt = Arc::new(Runtime::new().unwrap());
let opts = SinkOptions {
buffer: Some(BufferOverflowThreshold::new(1 * 1024).unwrap()),
flush_timeout_ms: 30,
tokio_runtime: rt.clone(),
output_writer: Vec::new(),
};
let mem_sink = Arc::new(BufSink::new(opts));
for _ in 0..1024 {
assert!(rt
.block_on(async { mem_sink.clone().sink(String::from("X")).await })
.is_ok());
}
let mut ref_output: String = vec!['X'; 1024].into_iter().collect();
// NOTE(review): the background task flushes once `len >= threshold`, so after the 1024th
// byte it may already flush before this read; the assert_ne! relies on reading the writer
// within the task's ~1ms polling interval. Potentially flaky.
let output_before_buf_overflow =
rt.block_on(async { mem_sink.clone().interior.lock().await.output_writer.clone() });
assert_ne!(
ref_output,
std::str::from_utf8(output_before_buf_overflow.as_ref()).unwrap()
);
assert!(rt
.block_on(async { mem_sink.clone().sink(String::from("X")).await })
.is_ok());
// 10ms is enough for the ~1ms overflow poll, but shorter than the 30ms timeout.
rt.block_on(async {
sleep(Duration::from_millis(1 + 9)).await;
});
ref_output.push('X');
let output_after_buf_overflow =
rt.block_on(async { mem_sink.clone().interior.lock().await.output_writer.clone() });
assert_eq!(
ref_output,
std::str::from_utf8(output_after_buf_overflow.as_ref()).unwrap()
);
}
// A writer whose every operation fails: the error must surface through `last_flush_err`
// once the flush timeout fires.
#[test]
fn flush_err() {
use core::task::{Context, Poll};
use std::io::{Error, ErrorKind};
use std::pin::Pin;
use tokio::io::AsyncWrite;
struct ProblematicWriter;
impl AsyncWrite for ProblematicWriter {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &[u8],
) -> Poll<Result<usize, Error>> {
Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
}
}
let rt = Arc::new(Runtime::new().unwrap());
let opts = SinkOptions {
buffer: Some(BufferOverflowThreshold::new(1 * 1024).unwrap()),
flush_timeout_ms: 20,
tokio_runtime: rt.clone(),
output_writer: ProblematicWriter,
};
let mem_sink = Arc::new(BufSink::new(opts));
assert!(rt
.block_on(async { mem_sink.clone().sink(String::from("hello world\n")).await })
.is_ok());
// No background flush has completed yet, so no error has been published.
assert!(mem_sink.last_flush_err().is_none());
// Wait past the 20ms timeout so the failing flush is attempted.
rt.block_on(async {
sleep(Duration::from_millis(20 + 5)).await;
});
assert!(mem_sink.last_flush_err().is_some());
assert_eq!(ErrorKind::Other, mem_sink.last_flush_err().unwrap().kind());
assert_eq!("kaboom!", format!("{}", mem_sink.last_flush_err().unwrap()));
}
}