use std::{
ffi::OsStr,
io,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use log::debug;
use pretty_hex::PrettyHex;
use tokio::{
io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
BufReader, ReadBuf,
},
net::{TcpStream, ToSocketAddrs},
time,
};
use crate::utils::{Interactive, RecvUntil};
use super::ProcessTube;
/// An async, buffered, bidirectional byte stream with debug logging of all traffic.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself; every impl in
/// this module requires `T: AsyncBufRead + AsyncWrite + Unpin` (or builds such a `T`).
#[derive(Debug)]
pub struct Tube<T> {
    /// The wrapped buffered stream that all reads and writes go through.
    pub inner: T,
    /// Deadline applied to each individual `recv`, `recv_line` and `recv_until` call.
    /// When it elapses, the call returns the data gathered so far instead of an error.
    pub timeout: Duration,
    /// How many leading bytes of `inner`'s current read buffer have already been written
    /// to the `Tube::recv` debug log, so that each received byte is dumped only once.
    read_buf_logged: usize,
}
/// Line terminator used by `recv_line` and appended by `send_line` (ASCII LF, 0x0A).
const NEW_LINE: u8 = b'\n';
impl<T> Tube<BufReader<T>>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps an unbuffered stream in a [`BufReader`] and never times out on receives.
    pub fn new(inner: T) -> Self {
        Self::with_timeout(inner, Duration::MAX)
    }

    /// Wraps an unbuffered stream in a [`BufReader`], bounding every receive call by `timeout`.
    pub fn with_timeout(inner: T, timeout: Duration) -> Self {
        let reader = BufReader::new(inner);
        Self {
            read_buf_logged: 0,
            timeout,
            inner: reader,
        }
    }
}
impl Tube<BufReader<ProcessTube>> {
    /// Builds a [`ProcessTube`] for `program` and wraps it in a buffered tube with no timeout.
    ///
    /// # Errors
    /// Propagates any error returned by `ProcessTube::new`.
    pub fn process<S: AsRef<OsStr>>(program: S) -> io::Result<Self> {
        let child = ProcessTube::new(program)?;
        Ok(Self::new(child))
    }
}
impl Tube<BufReader<TcpStream>> {
    /// Opens a TCP connection to `addr` and wraps it in a buffered tube with no timeout.
    ///
    /// # Errors
    /// Returns the error from `TcpStream::connect` if the connection cannot be established.
    pub async fn remote<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        TcpStream::connect(addr).await.map(Self::new)
    }
}
impl<T> Tube<T>
where
    T: AsyncBufRead + AsyncWrite + Unpin,
{
    /// Wraps a stream that is already buffered; the resulting tube never times out.
    pub fn from_buffered(inner: T) -> Self {
        Self {
            read_buf_logged: 0,
            timeout: Duration::MAX,
            inner,
        }
    }

    /// Performs a single read of at most `len` bytes.
    ///
    /// Returns an empty vector if `self.timeout` elapses first. I/O errors are propagated.
    pub async fn recv(&mut self, len: usize) -> io::Result<Vec<u8>> {
        let mut data = vec![0u8; len];
        let limit = self.timeout;
        let got = match time::timeout(limit, self.read(&mut data[..])).await {
            Ok(outcome) => outcome?,
            Err(_) => 0,
        };
        data.truncate(got);
        Ok(data)
    }

    /// Reads up to and including the next `NEW_LINE` byte (or until EOF).
    ///
    /// If `self.timeout` elapses, the buffer is returned as it stands at that moment, which
    /// may be partial depending on how far `read_until` got. I/O errors are propagated.
    pub async fn recv_line(&mut self) -> io::Result<Vec<u8>> {
        let mut line = Vec::new();
        let limit = self.timeout;
        if let Ok(outcome) = time::timeout(limit, self.read_until(NEW_LINE, &mut line)).await {
            outcome?;
        }
        Ok(line)
    }

    /// Reads until the byte sequence `delims` has been received, as implemented by `RecvUntil`.
    ///
    /// If `self.timeout` elapses, whatever `RecvUntil` wrote into the buffer so far is
    /// returned. I/O errors are propagated.
    pub async fn recv_until<A: AsRef<[u8]>>(&mut self, delims: A) -> io::Result<Vec<u8>> {
        let mut out = Vec::new();
        let limit = self.timeout;
        let pattern = delims.as_ref();
        if let Ok(outcome) = time::timeout(limit, RecvUntil::new(self, pattern, &mut out)).await {
            outcome?;
        }
        Ok(out)
    }

    /// Writes all of `data` and flushes.
    pub async fn send<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
        let bytes = data.as_ref();
        self.write_all(bytes).await?;
        self.flush().await
    }

    /// Writes all of `data` followed by a single `NEW_LINE` byte, then flushes.
    pub async fn send_line<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
        let payload = data.as_ref();
        self.write_all(payload).await?;
        self.write_all(&[NEW_LINE]).await?;
        self.flush().await
    }

    /// Waits for `pattern` (see [`Tube::recv_until`]) and then sends `data` as a line.
    ///
    /// Returns the bytes received while waiting for `pattern`.
    pub async fn send_line_after<A: AsRef<[u8]>, B: AsRef<[u8]>>(
        &mut self,
        pattern: A,
        data: B,
    ) -> io::Result<Vec<u8>> {
        let received = self.recv_until(pattern).await?;
        self.send_line(data).await.map(|()| received)
    }

    /// Hands this tube to `Interactive` and drives it to completion.
    pub async fn interactive(&mut self) -> io::Result<()> {
        Interactive::new(self).await
    }

    /// Consumes the tube and returns the wrapped stream.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T> AsyncRead for Tube<T>
where
    T: AsyncBufRead + AsyncWrite + Unpin,
{
    /// Reads from the inner stream and logs newly received bytes to the `Tube::recv` target.
    ///
    /// `poll_fill_buf` may already have logged bytes that are still sitting in the inner
    /// buffer; `read_buf_logged` counts them. A read through a buffered reader drains that
    /// buffer first (NOTE(review): assumed for the inner `T`; this holds for tokio's
    /// `BufReader`). So the first `min(read, read_buf_logged)` bytes returned here have
    /// already been dumped. They are skipped, and the counter is reduced accordingly.
    /// Without this, those bytes were logged twice. The stale counter also made later
    /// `poll_fill_buf` calls skip logging fresh data.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let olen = buf.filled().len();
        if Pin::new(&mut this.inner).poll_read(cx, buf)?.is_pending() {
            return Poll::Pending;
        }
        let received = &buf.filled()[olen..];
        let already_logged = received.len().min(this.read_buf_logged);
        this.read_buf_logged -= already_logged;
        let fresh = &received[already_logged..];
        if !fresh.is_empty() {
            debug!(target: "Tube::recv", "Received {:?}", fresh.hex_dump());
        }
        Poll::Ready(Ok(()))
    }
}
impl<T> AsyncWrite for Tube<T>
where
    T: AsyncBufRead + AsyncWrite + Unpin,
{
    /// Writes to the inner stream and logs the bytes that were actually accepted.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let numb = match Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)? {
            Poll::Ready(numb) => numb,
            Poll::Pending => return Poll::Pending,
        };
        debug!(target: "Tube::send", "Sent {:?}", buf[..numb].hex_dump());
        Poll::Ready(Ok(numb))
    }

    /// Flushes the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Shuts down the inner stream's write side.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }

    /// Vectored write to the inner stream, logging the accepted prefix slice by slice.
    ///
    /// `numb` is the total number of bytes written across all slices. Each slice contributes
    /// at most its own length. The previous code indexed every slice with the remaining
    /// total, which panicked whenever a write spanned more than one slice.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[io::IoSlice],
    ) -> Poll<Result<usize, io::Error>> {
        let numb = match Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)? {
            Poll::Ready(numb) => numb,
            Poll::Pending => return Poll::Pending,
        };
        let mut to_log = numb;
        for buf in bufs {
            if to_log == 0 {
                break;
            }
            let take = to_log.min(buf.len());
            if take == 0 {
                // Empty slice: nothing of it was written, move on to the next one.
                continue;
            }
            debug!(target: "Tube::send", "Send {:?}", buf[..take].hex_dump());
            to_log -= take;
        }
        Poll::Ready(Ok(numb))
    }

    /// Reports whether the inner stream has an efficient vectored-write implementation.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
impl<T> AsyncBufRead for Tube<T>
where
    T: AsyncBufRead + AsyncWrite + Unpin,
{
    /// Returns the inner reader's buffered bytes and logs only the part not logged before.
    ///
    /// The same buffered bytes can be returned several times before they are consumed.
    /// `read_buf_logged` counts how many leading bytes of that buffer were already written
    /// to the `Tube::recv` debug log, so each byte is dumped once.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
        let Self {
            inner,
            timeout: _,
            read_buf_logged,
        } = self.get_mut();
        let buf = match Pin::new(inner).poll_fill_buf(cx)? {
            Poll::Ready(buf) => buf,
            Poll::Pending => return Poll::Pending,
        };
        if buf.len() > *read_buf_logged {
            debug!(target: "Tube::recv", "Received {:?}", buf[*read_buf_logged..].hex_dump());
            *read_buf_logged = buf.len();
        }
        Poll::Ready(Ok(buf))
    }

    /// Marks `amt` buffered bytes as consumed, both in the log bookkeeping and in `inner`.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        // Saturate instead of underflowing: a caller consuming more than was logged must not
        // panic here in debug builds. The inner reader decides how to handle `amt` itself.
        this.read_buf_logged = this.read_buf_logged.saturating_sub(amt);
        Pin::new(&mut this.inner).consume(amt);
    }
}
impl<T> From<Tube<BufReader<T>>> for BufReader<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn from(tube: Tube<BufReader<T>>) -> Self {
tube.into_inner()
}
}
impl<T> From<T> for Tube<T>
where
T: AsyncBufRead + AsyncWrite + Unpin,
{
fn from(tube_like: T) -> Self {
Self {
inner: tube_like,
timeout: Duration::MAX,
read_buf_logged: 0,
}
}
}