use std::io::Error as StdIOError;
use bytes::BytesMut;
use tokio::{
io::{AsyncRead, AsyncReadExt, Error as TokioIOError},
process::Child,
spawn,
task::{JoinError, JoinHandle},
};
/// "Read everything" helpers for any `Send + Unpin` [`AsyncRead`] source.
///
/// Note: tokio's [`AsyncReadExt`] also has a method named `read_to_end`. When both traits are
/// in scope for a concrete reader, disambiguate with `TokioReadExt::read_to_end(&mut reader)`.
pub trait TokioReadExt: AsyncRead + Unpin + Send {
    /// Reads from `self` until EOF and yields every byte read, in order, as one buffer.
    fn read_to_end(&mut self)
    -> impl Future<Output = Result<BytesMut, TokioIOError>> + Send;
    /// Same as [`TokioReadExt::read_to_end`], but `f(all_bytes_so_far, newly_read_bytes)` is
    /// invoked after every read that returned data.
    fn read_to_end_with<F: Fn(&[u8], &[u8]) + Send>(
        &mut self, f: F,
    ) -> impl Future<Output = Result<BytesMut, TokioIOError>> + Send;
}
/// Runs the [`TokioReadExt`] helpers on a freshly spawned tokio task that owns the reader.
pub trait TokioReadTaskExt: TokioReadExt {
    /// Moves `self` into a new task that reads it to EOF; awaiting the handle yields the bytes.
    fn spawn_read(self) -> JoinHandle<Result<BytesMut, TokioIOError>>
    where Self: Sized + Send + 'static;
    /// Same as [`TokioReadTaskExt::spawn_read`], but the task calls
    /// `f(all_bytes_so_far, newly_read_bytes)` after every read that returned data.
    fn spawn_read_with<F: Fn(&[u8], &[u8]) + Send + 'static>(
        self, f: F,
    ) -> JoinHandle<Result<BytesMut, TokioIOError>>
    where Self: Sized + Send + 'static;
    /// Uses [`TokioReadTaskExt::spawn_read_with`] when `f` is `Some`, otherwise
    /// [`TokioReadTaskExt::spawn_read`].
    ///
    /// With `None`, the callback type still has to be named, for example
    /// `reader.spawn_read_opt(None::<fn(&[u8], &[u8])>)`.
    fn spawn_read_opt<F: Fn(&[u8], &[u8]) + Send + 'static>(
        self, f: Option<F>,
    ) -> JoinHandle<Result<BytesMut, TokioIOError>>
    where Self: Sized + Send + 'static;
}
pub trait TokioChildExt {
async fn read_out<F1, F2>(
&mut self, on_stdout: Option<F1>, on_stderr: Option<F2>,
) -> Result<
(Result<BytesMut, StdIOError>, Result<BytesMut, StdIOError>),
JoinError,
>
where
F1: Fn(&[u8], &[u8]) + Send + 'static,
F2: Fn(&[u8], &[u8]) + Send + 'static;
}
impl TokioChildExt for Child {
async fn read_out<F1, F2>(
&mut self, on_stdout: Option<F1>, on_stderr: Option<F2>,
) -> Result<
(Result<BytesMut, StdIOError>, Result<BytesMut, StdIOError>),
JoinError,
>
where
F1: Fn(&[u8], &[u8]) + Send + 'static,
F2: Fn(&[u8], &[u8]) + Send + 'static,
{
let stdout_task = self.stdout.take().unwrap().spawn_read_opt(on_stdout);
let stderr_task = self.stderr.take().unwrap().spawn_read_opt(on_stderr);
tokio::try_join!(stdout_task, stderr_task)
}
}
/// Blanket implementation: every `Send + Unpin` [`AsyncRead`] gets the read-to-end helpers.
impl<T: AsyncRead + Send + Unpin> TokioReadExt for T {
    /// Appends to a growing [`BytesMut`] with `read_buf` until a read returns 0 bytes.
    async fn read_to_end(&mut self) -> Result<BytesMut, TokioIOError> {
        let mut collected = BytesMut::new();
        loop {
            // `BytesMut` reserves more room on demand, so a 0-byte read here means EOF.
            if self.read_buf(&mut collected).await? == 0 {
                return Ok(collected);
            }
        }
    }
    /// Like `read_to_end`, but reports each read to `f` as `(whole_buffer, new_chunk)`.
    async fn read_to_end_with<F>(
        &mut self, f: F,
    ) -> Result<BytesMut, TokioIOError>
    where F: Fn(&[u8], &[u8]) + Send {
        let mut collected = BytesMut::new();
        loop {
            // Remember where this read's bytes will begin so the new chunk can be sliced out.
            let start = collected.len();
            if self.read_buf(&mut collected).await? == 0 {
                return Ok(collected);
            }
            f(&collected, &collected[start..]);
        }
    }
}
impl<T: TokioReadExt> TokioReadTaskExt for T {
fn spawn_read(mut self) -> JoinHandle<Result<BytesMut, TokioIOError>>
where Self: Sized + Send + 'static {
spawn(async move { self.read_to_end().await })
}
fn spawn_read_with<F>(
mut self, f: F,
) -> JoinHandle<Result<BytesMut, TokioIOError>>
where
Self: Sized + Send + 'static,
F: Fn(&[u8], &[u8]) + Send + 'static,
{
spawn(async move { self.read_to_end_with(f).await })
}
fn spawn_read_opt<F>(
self, f: Option<F>,
) -> JoinHandle<Result<BytesMut, TokioIOError>>
where
Self: Sized + Send + 'static,
F: Fn(&[u8], &[u8]) + Send + 'static,
{
match f {
Some(func) => self.spawn_read_with(func),
None => self.spawn_read(),
}
}
}