use std::{
io::{self, Read, Result, Write},
ops::{Deref, DerefMut},
};
#[cfg(feature = "async")]
use futures_lite::{AsyncRead, AsyncWrite};
#[cfg(feature = "async")]
use std::{
pin::Pin,
task::{Context, Poll},
};
use crate::process::NonBlocking;
/// A stream wrapper that forwards I/O to an inner stream and records the
/// transferred bytes to a logger.
///
/// `S` is the wrapped stream; `W` is the sink that log lines are written to.
/// The trait implementations in this file decide which operations are logged.
#[derive(Debug)]
pub struct LogStream<S, W> {
/// The wrapped stream that reads, writes and flushes are delegated to.
stream: S,
/// Destination for the `read`/`write` log lines.
logger: W,
}
impl<S, W> LogStream<S, W> {
pub fn new(stream: S, logger: W) -> Self {
Self { stream, logger }
}
}
impl<S, W: Write> LogStream<S, W> {
    /// Records `buf` in the logger under the `"write"` label.
    fn log_write(&mut self, buf: &[u8]) {
        let sink = &mut self.logger;
        log(sink, "write", buf);
    }

    /// Records `buf` in the logger under the `"read"` label.
    fn log_read(&mut self, buf: &[u8]) {
        let sink = &mut self.logger;
        log(sink, "read", buf);
    }
}
/// Blocking writes: every call is delegated to the wrapped stream, and the
/// bytes it reports as written are then passed to the logger.
impl<S: Write, W: Write> Write for LogStream<S, W> {
    /// Writes `buf` to the wrapped stream and logs the prefix that was accepted.
    ///
    /// An error from the wrapped stream is returned as-is and nothing is logged.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let accepted = self.stream.write(buf)?;
        let sent = &buf[..accepted];
        self.log_write(sent);
        Ok(accepted)
    }

    /// Flushes the wrapped stream; the logger is not flushed here.
    fn flush(&mut self) -> Result<()> {
        Write::flush(&mut self.stream)
    }

    /// Vectored write to the wrapped stream, logged as a single entry.
    ///
    /// The reported byte count is spread over `bufs` in order to reconstruct
    /// exactly which bytes were written, and those bytes are logged together.
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> Result<usize> {
        let accepted = self.stream.write_vectored(bufs)?;

        let mut remaining = accepted;
        let mut sent = Vec::new();
        for slice in bufs {
            // Stop once every written byte has been attributed to a buffer.
            if remaining == 0 {
                break;
            }
            let take = slice.len().min(remaining);
            sent.extend_from_slice(&slice[..take]);
            remaining -= take;
        }

        self.log_write(&sent);
        Ok(accepted)
    }
}
/// Blocking reads: delegated to the wrapped stream, with the bytes that were
/// actually read passed to the logger.
impl<S: Read, W: Write> Read for LogStream<S, W> {
    /// Reads into `buf` from the wrapped stream and logs the filled prefix.
    ///
    /// An error from the wrapped stream is returned as-is and nothing is logged.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        match self.stream.read(buf) {
            Ok(count) => {
                let filled = &buf[..count];
                self.log_read(filled);
                Ok(count)
            }
            Err(err) => Err(err),
        }
    }
}
impl<S, W> NonBlocking for LogStream<S, W>
where
S: NonBlocking,
{
fn set_blocking(&mut self, on: bool) -> Result<()> {
self.stream.set_blocking(on)
}
}
/// Gives shared access to the wrapped stream through `*log_stream`.
impl<S, W> Deref for LogStream<S, W> {
    type Target = S;

    /// Borrows the wrapped stream.
    fn deref(&self) -> &S {
        let inner = &self.stream;
        inner
    }
}
/// Gives exclusive access to the wrapped stream through `*log_stream`.
impl<S, W> DerefMut for LogStream<S, W> {
    /// Mutably borrows the wrapped stream.
    fn deref_mut(&mut self) -> &mut S {
        let inner = &mut self.stream;
        inner
    }
}
/// Asynchronous writes: delegated to the wrapped stream, with the bytes it
/// reports as written passed to the logger.
///
/// Logging happens only after the inner poll returns `Poll::Ready(Ok(n))`, and
/// only for the first `n` bytes. A `Pending` poll or an error therefore logs
/// nothing, so a retried write is not logged twice and a partial write does
/// not log bytes that were never accepted. This matches the blocking `Write`
/// implementation and `AsyncRead::poll_read`.
#[cfg(feature = "async")]
impl<S: AsyncWrite + Unpin, W: Write + Unpin> AsyncWrite for LogStream<S, W> {
    /// Polls a write of `buf` on the wrapped stream and logs the accepted prefix.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        // `Self: Unpin` (both `S` and `W` are), so the pin can be unwrapped.
        let this = self.get_mut();
        let result = Pin::new(&mut this.stream).poll_write(cx, buf);
        if let Poll::Ready(Ok(n)) = &result {
            this.log_write(&buf[..*n]);
        }
        result
    }

    /// Polls a flush of the wrapped stream; nothing is logged.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    /// Polls a close of the wrapped stream; nothing is logged.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        Pin::new(&mut self.stream).poll_close(cx)
    }

    /// Polls a vectored write on the wrapped stream and logs the written bytes
    /// as a single entry.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        let this = self.get_mut();
        let result = Pin::new(&mut this.stream).poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(n)) = &result {
            // Spread the reported byte count over `bufs` in order to recover
            // exactly which bytes were written.
            let mut rest = *n;
            let mut bytes = Vec::new();
            for buf in bufs {
                if rest == 0 {
                    break;
                }
                let written = std::cmp::min(buf.len(), rest);
                bytes.extend_from_slice(&buf[..written]);
                rest -= written;
            }
            this.log_write(&bytes);
        }
        result
    }
}
/// Asynchronous reads: delegated to the wrapped stream, with the bytes that
/// were actually read passed to the logger.
#[cfg(feature = "async")]
impl<S: AsyncRead + Unpin, W: Write + Unpin> AsyncRead for LogStream<S, W> {
    /// Polls a read into `buf` on the wrapped stream.
    ///
    /// The filled prefix is logged only when the poll completes successfully;
    /// `Pending` and error results are returned without logging.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        // `Self: Unpin` (both `S` and `W` are), so the pin can be unwrapped.
        let this = self.get_mut();
        let outcome = Pin::new(&mut this.stream).poll_read(cx, buf);
        match &outcome {
            Poll::Ready(Ok(count)) => this.log_read(&buf[..*count]),
            _ => {}
        }
        outcome
    }
}
/// Writes one log line for `data` to `writer`, prefixed with `target`.
///
/// Valid UTF-8 is printed as a debug-quoted string; anything else is printed
/// as a debug-formatted byte slice with a `(bytes)` marker. Logging is
/// best-effort: a failure to write the line is ignored.
fn log(mut writer: impl Write, target: &str, data: &[u8]) {
    let outcome = if let Ok(text) = std::str::from_utf8(data) {
        writeln!(writer, "{}: {:?}", target, text)
    } else {
        writeln!(writer, "{}:(bytes): {:?}", target, data)
    };
    // Deliberately discarded so a broken logger never disturbs the real I/O.
    let _ = outcome;
}