use std::{
io::{self, IoSliceMut},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures_lite::{
ready, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
};
use crate::{process::Healthcheck, Captures, Error, Needle};
/// Pairs a process handle with the stream used to communicate with it.
///
/// `P` is the process handle type and `S` the raw stream type; both default
/// to the OS-specific types re-exported by the parent module.
#[derive(Debug)]
pub struct Session<P = super::OsProcess, S = super::OsProcessStream> {
/// Process handle; also reachable through `Deref`/`DerefMut` on the session.
process: P,
/// Raw stream wrapped together with its read buffer and `expect` settings.
stream: Stream<S>,
}
impl<P, S> Session<P, S> {
    /// Creates a session from a process handle and the stream attached to it.
    ///
    /// The stream is wrapped in the internal buffered `Stream`, which starts
    /// with its default `expect` settings. The visible implementation always
    /// returns `Ok`.
    pub fn new(process: P, stream: S) -> io::Result<Self> {
        Ok(Self {
            process,
            stream: Stream::new(stream),
        })
    }

    /// Returns a shared reference to the underlying raw stream.
    pub fn get_stream(&self) -> &S {
        self.stream.as_ref()
    }

    /// Returns a mutable reference to the underlying raw stream.
    pub fn get_stream_mut(&mut self) -> &mut S {
        self.stream.as_mut()
    }

    /// Returns a shared reference to the process handle.
    pub fn get_process(&self) -> &P {
        &self.process
    }

    /// Returns a mutable reference to the process handle.
    pub fn get_process_mut(&mut self) -> &mut P {
        &mut self.process
    }

    /// Sets how long `expect` may wait for a match; `None` removes the limit.
    pub fn set_expect_timeout(&mut self, expect_timeout: Option<Duration>) {
        self.stream.set_expect_timeout(expect_timeout);
    }

    /// Selects the matching strategy used by `expect`: lazy when `is_lazy`
    /// is `true`, greedy otherwise.
    pub fn set_expect_lazy(&mut self, is_lazy: bool) {
        self.stream.expect_lazy = is_lazy;
    }

    /// Rebuilds the session around a new stream produced from the old one.
    ///
    /// `new_stream` receives the raw stream by value and returns its
    /// replacement. Bytes that were already read but not yet consumed are
    /// carried over, and so are the `expect` timeout and lazy-mode settings,
    /// so swapping the stream does not silently reset the configuration.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `Session::new`, converted into `Error`.
    pub(crate) fn swap_stream<F: FnOnce(S) -> R, R>(
        mut self,
        new_stream: F,
    ) -> Result<Session<P, R>, Error> {
        // Snapshot everything owned by the old wrapper before dismantling it.
        let expect_timeout = self.stream.expect_timeout;
        let expect_lazy = self.stream.expect_lazy;
        let buf = self.stream.get_available().to_owned();

        let stream = new_stream(self.stream.into_inner());

        let mut session = Session::new(self.process, stream)?;
        session.stream.keep(&buf);
        session.stream.set_expect_timeout(expect_timeout);
        session.stream.expect_lazy = expect_lazy;
        Ok(session)
    }
}
impl<P: Healthcheck, S> Session<P, S> {
    /// Asks the process handle whether the process is still running.
    ///
    /// # Errors
    ///
    /// Returns the health-check failure converted into `Error`.
    pub fn is_alive(&mut self) -> Result<bool, Error> {
        match self.process.is_alive() {
            Ok(alive) => Ok(alive),
            Err(err) => Err(err.into()),
        }
    }
}
impl<P, S: AsyncRead + Unpin> Session<P, S> {
    /// Waits until `needle` matches the stream output and returns the captures.
    ///
    /// The lazy matcher is used when lazy mode is enabled on the session,
    /// the greedy matcher otherwise.
    ///
    #[cfg_attr(windows, doc = "```no_run")]
    #[cfg_attr(unix, doc = "```")]
    #[cfg_attr(windows, doc = "```no_run")]
    #[cfg_attr(unix, doc = "```")]
    pub async fn expect<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
        if self.stream.expect_lazy {
            self.stream.expect_lazy(needle).await
        } else {
            self.stream.expect_gready(needle).await
        }
    }

    /// Makes a single, non-waiting attempt to match `needle`.
    ///
    /// On a match the matched bytes are consumed; without one an empty
    /// `Captures` is returned (or `Error::Eof` once the stream has ended).
    ///
    #[cfg_attr(any(target_os = "macos", windows), doc = "```no_run")]
    #[cfg_attr(not(any(target_os = "macos", windows)), doc = "```")]
    pub async fn check<E: Needle>(&mut self, needle: E) -> Result<Captures, Error> {
        let outcome = self.stream.check(needle).await;
        outcome
    }

    /// Reports whether `needle` currently matches, without consuming anything.
    pub async fn is_matched<E: Needle>(&mut self, needle: E) -> Result<bool, Error> {
        let outcome = self.stream.is_matched(needle).await;
        outcome
    }

    /// Delegates to the internal stream's emptiness check, which does not
    /// wait for new data to arrive.
    pub async fn is_empty(&mut self) -> io::Result<bool> {
        let outcome = self.stream.is_empty().await;
        outcome
    }
}
impl<Proc, S: AsyncWrite + Unpin> Session<Proc, S> {
    /// Writes the whole of `buf` to the stream.
    pub async fn send<B: AsRef<[u8]>>(&mut self, buf: B) -> io::Result<()> {
        let bytes = buf.as_ref();
        self.stream.write_all(bytes).await
    }

    /// Writes `buf` followed by the platform line ending
    /// (`\r\n` on Windows, `\n` elsewhere).
    pub async fn send_line<B: AsRef<[u8]>>(&mut self, buf: B) -> io::Result<()> {
        let line_ending: &[u8] = if cfg!(windows) { b"\r\n" } else { b"\n" };

        self.stream.write_all(buf.as_ref()).await?;
        self.stream.write_all(line_ending).await
    }
}
impl<P, S> Deref for Session<P, S> {
    type Target = P;

    /// Exposes the process handle, so its methods can be called on the session.
    fn deref(&self) -> &P {
        self.get_process()
    }
}
impl<P, S> DerefMut for Session<P, S> {
    /// Exposes the process handle mutably.
    fn deref_mut(&mut self) -> &mut P {
        self.get_process_mut()
    }
}
/// Forwards every write operation to the session's internal stream.
impl<P: Unpin, S: AsyncWrite + Unpin> AsyncWrite for Session<P, S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = Pin::new(&mut self.get_mut().stream);
        stream.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = Pin::new(&mut self.get_mut().stream);
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = Pin::new(&mut self.get_mut().stream);
        stream.poll_close(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let stream = Pin::new(&mut self.get_mut().stream);
        stream.poll_write_vectored(cx, bufs)
    }
}
impl<P: Unpin, S: AsyncRead + Unpin> AsyncRead for Session<P, S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
impl<P: Unpin, S: AsyncRead + Unpin> AsyncBufRead for Session<P, S> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
Pin::new(&mut self.get_mut().stream).poll_fill_buf(cx)
}
fn consume(mut self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut self.stream).consume(amt);
}
}
/// Internal stream wrapper: the buffered raw stream plus the settings used by
/// the `expect` family of methods.
#[derive(Debug)]
struct Stream<S> {
/// Raw stream together with the bytes read from it but not yet consumed.
stream: BufferedStream<S>,
/// Upper bound on how long `expect_gready`/`expect_lazy` wait for a match;
/// `None` means no limit.
expect_timeout: Option<Duration>,
/// When `true`, `Session::expect` uses the lazy matcher instead of the greedy one.
expect_lazy: bool,
}
impl<S> Stream<S> {
fn new(stream: S) -> Self {
Self {
stream: BufferedStream::new(stream),
expect_timeout: Some(Duration::from_millis(10000)),
expect_lazy: false,
}
}
fn as_ref(&self) -> &S {
&self.stream.stream
}
fn as_mut(&mut self) -> &mut S {
&mut self.stream.stream
}
fn set_expect_timeout(&mut self, expect_timeout: Option<Duration>) {
self.expect_timeout = expect_timeout;
}
fn keep(&mut self, buf: &[u8]) {
self.stream.keep(buf);
}
fn get_available(&mut self) -> &[u8] {
self.stream.buffer()
}
fn into_inner(self) -> S {
self.stream.stream
}
}
impl<S: AsyncRead + Unpin> Stream<S> {
async fn expect_gready<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
let expect_timeout = self.expect_timeout;
let expect_future = async {
let mut eof = false;
loop {
let data = self.stream.buffer();
let found = Needle::check(&needle, data, eof)?;
if !found.is_empty() {
let end_index = Captures::right_most_index(&found);
let involved_bytes = data[..end_index].to_vec();
self.stream.consume(end_index);
return Ok(Captures::new(involved_bytes, found));
}
if eof {
return Err(Error::Eof);
}
eof = self.stream.fill().await? == 0;
}
};
if let Some(timeout) = expect_timeout {
let timeout_future = futures_timer::Delay::new(timeout);
futures_lite::future::or(expect_future, async {
timeout_future.await;
Err(Error::ExpectTimeout)
})
.await
} else {
expect_future.await
}
}
async fn expect_lazy<N: Needle>(&mut self, needle: N) -> Result<Captures, Error> {
let expect_timeout = self.expect_timeout;
let expect_future = async {
let mut checked_length = 0;
let mut eof = false;
loop {
let available = self.stream.buffer();
let is_buffer_checked = checked_length == available.len();
if is_buffer_checked {
let n = self.stream.fill().await?;
eof = n == 0;
}
let available = self.stream.buffer();
if checked_length < available.len() {
checked_length += 1;
}
let data = &available[..checked_length];
let found = Needle::check(&needle, data, eof)?;
if !found.is_empty() {
let end_index = Captures::right_most_index(&found);
let involved_bytes = data[..end_index].to_vec();
self.stream.consume(end_index);
return Ok(Captures::new(involved_bytes, found));
}
if eof {
return Err(Error::Eof);
}
}
};
if let Some(timeout) = expect_timeout {
let timeout_future = futures_timer::Delay::new(timeout);
futures_lite::future::or(expect_future, async {
timeout_future.await;
Err(Error::ExpectTimeout)
})
.await
} else {
expect_future.await
}
}
async fn is_matched<E: Needle>(&mut self, needle: E) -> Result<bool, Error> {
let eof = self.try_fill().await?;
let buf = self.stream.buffer();
let found = needle.check(buf, eof)?;
if !found.is_empty() {
return Ok(true);
}
if eof {
return Err(Error::Eof);
}
Ok(false)
}
async fn check<E: Needle>(&mut self, needle: E) -> Result<Captures, Error> {
let eof = self.try_fill().await?;
let buf = self.stream.buffer();
let found = needle.check(buf, eof)?;
if !found.is_empty() {
let end_index = Captures::right_most_index(&found);
let involved_bytes = buf[..end_index].to_vec();
self.stream.consume(end_index);
return Ok(Captures::new(involved_bytes, found));
}
if eof {
return Err(Error::Eof);
}
Ok(Captures::new(Vec::new(), Vec::new()))
}
async fn is_empty(&mut self) -> io::Result<bool> {
match futures_lite::future::poll_once(self.read(&mut [])).await {
Some(Ok(0)) => Ok(true),
Some(Ok(_)) => Ok(false),
Some(Err(err)) => Err(err),
None => Ok(true),
}
}
async fn try_fill(&mut self) -> Result<bool, Error> {
match futures_lite::future::poll_once(self.stream.fill()).await {
Some(Ok(n)) => Ok(n == 0),
Some(Err(err)) => Err(err.into()),
None => Ok(false),
}
}
}
/// Writes bypass the read buffer and go straight to the raw stream.
impl<S: AsyncWrite + Unpin> AsyncWrite for Stream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let raw: &mut S = self.get_mut().stream.get_mut();
        Pin::new(raw).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let raw: &mut S = self.get_mut().stream.get_mut();
        Pin::new(raw).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let raw: &mut S = self.get_mut().stream.get_mut();
        Pin::new(raw).poll_close(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let raw: &mut S = self.get_mut().stream.get_mut();
        Pin::new(raw).poll_write_vectored(cx, bufs)
    }
}
impl<S: AsyncRead + Unpin> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
impl<S: AsyncRead + Unpin> AsyncBufRead for Stream<S> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
Pin::new(&mut self.get_mut().stream).poll_fill_buf(cx)
}
fn consume(mut self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut self.stream).consume(amt);
}
}
/// A raw stream paired with the bytes that were read from it (or handed back
/// via `keep`) and are still waiting to be consumed.
#[derive(Debug)]
struct BufferedStream<S> {
    /// The wrapped raw stream.
    stream: S,
    /// Pending bytes, oldest first.
    buffer: Vec<u8>,
    /// Number of pending bytes; updated alongside `buffer`.
    length: usize,
}

impl<S> BufferedStream<S> {
    /// Wraps `stream` with nothing buffered.
    fn new(stream: S) -> Self {
        Self {
            stream,
            buffer: Vec::new(),
            length: 0,
        }
    }

    /// Appends `buf` to the end of the pending bytes.
    fn keep(&mut self, buf: &[u8]) {
        let added = buf.len();
        self.buffer.extend_from_slice(buf);
        self.length += added;
    }

    /// Returns the pending bytes.
    fn buffer(&self) -> &[u8] {
        let pending = ..self.length;
        &self.buffer[pending]
    }

    /// Mutably borrows the wrapped raw stream.
    fn get_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
impl<S: AsyncRead + Unpin> BufferedStream<S> {
    /// Reads up to 128 bytes from the raw stream and appends them to the
    /// pending bytes.
    ///
    /// Returns how many bytes were read; `0` is what the caller treats as
    /// end of stream.
    async fn fill(&mut self) -> io::Result<usize> {
        const CHUNK_SIZE: usize = 128;

        let mut chunk = [0u8; CHUNK_SIZE];
        match self.stream.read(&mut chunk).await {
            Ok(received) => {
                self.keep(&chunk[..received]);
                Ok(received)
            }
            Err(err) => Err(err),
        }
    }
}
/// Reads are answered from the pending bytes, refilling from the raw stream
/// only when nothing is pending.
impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut pending: &[u8] = match self.as_mut().poll_fill_buf(cx) {
            Poll::Ready(Ok(bytes)) => bytes,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending => return Poll::Pending,
        };
        // Copy as much as fits, then drop exactly that many pending bytes.
        let copied = io::Read::read(&mut pending, buf)?;
        self.consume(copied);
        Poll::Ready(Ok(copied))
    }

    fn poll_read_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let mut pending: &[u8] = match self.as_mut().poll_fill_buf(cx) {
            Poll::Ready(Ok(bytes)) => bytes,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending => return Poll::Pending,
        };
        let copied = io::Read::read_vectored(&mut pending, bufs)?;
        self.consume(copied);
        Poll::Ready(Ok(copied))
    }
}
impl<S: AsyncRead + Unpin> AsyncBufRead for BufferedStream<S> {
    /// Returns the pending bytes, first reading up to 128 bytes from the raw
    /// stream when nothing is pending.
    ///
    /// An empty slice is returned when that read reports zero bytes.
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        if self.buffer.is_empty() {
            let mut chunk = [0; 128];
            let n = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut chunk))?;
            self.keep(&chunk[..n]);
        }
        // `Pin::get_mut` keeps the full lifetime of the pin for the returned slice.
        let buf = self.get_mut().buffer();
        Poll::Ready(Ok(buf))
    }

    /// Drops the first `amt` pending bytes.
    ///
    /// `amt` is clamped to the number of pending bytes, so an oversized
    /// request empties the buffer instead of panicking.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        let amt = amt.min(self.length);
        let _ = self.buffer.drain(..amt);
        self.length -= amt;
    }
}
#[cfg(test)]
mod tests {
    use futures_lite::AsyncWriteExt;
    use crate::Eof;
    use super::*;
    use futures_lite::future::block_on;

    /// Builds a `Stream` over an in-memory cursor holding `bytes`.
    fn cursor_stream(bytes: &[u8]) -> Stream<futures_lite::io::Cursor<Vec<u8>>> {
        Stream::new(futures_lite::io::Cursor::new(bytes.to_vec()))
    }

    /// Builds a `Stream` over a `NoEofReader` with a 100 ms expect timeout.
    fn short_timeout_stream() -> Stream<NoEofReader> {
        let mut stream = Stream::new(NoEofReader::default());
        stream.set_expect_timeout(Some(Duration::from_millis(100)));
        stream
    }

    /// Asserts that `found` is the match of "World" preceded by "Hello ".
    fn assert_world_after_hello(found: Captures) {
        assert_eq!(b"Hello ", found.before());
        assert_eq!(vec![b"World"], found.matches().collect::<Vec<_>>());
    }

    /// Asserts that `found` matched the whole "Hello World" input.
    fn assert_whole_input_matched(found: Captures) {
        assert_eq!(b"", found.before());
        assert_eq!(vec![b"Hello World"], found.matches().collect::<Vec<_>>());
    }

    #[test]
    fn test_expect_lazy() {
        block_on(async {
            let mut stream = cursor_stream(b"Hello World");
            let found = stream.expect_lazy("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    #[test]
    fn test_expect_lazy_eof() {
        block_on(async {
            let mut stream = cursor_stream(b"Hello World");
            let found = stream.expect_lazy(Eof).await.unwrap();
            assert_whole_input_matched(found);
        });

        block_on(async {
            let mut stream = cursor_stream(b"");
            let err = stream.expect_lazy("").await.unwrap_err();
            assert!(matches!(err, Error::Eof));
        });
    }

    #[test]
    fn test_expect_lazy_timeout() {
        block_on(async {
            let mut stream = short_timeout_stream();

            stream.write_all(b"Hello").await.unwrap();
            let err = stream.expect_lazy("Hello World").await.unwrap_err();
            assert!(matches!(err, Error::ExpectTimeout));

            // Data read before the timeout must still be available afterwards.
            stream.write_all(b" World").await.unwrap();
            let found = stream.expect_lazy("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    #[test]
    fn test_expect_gready() {
        block_on(async {
            let mut stream = cursor_stream(b"Hello World");
            let found = stream.expect_gready("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    #[test]
    fn test_expect_gready_eof() {
        block_on(async {
            let mut stream = cursor_stream(b"Hello World");
            let found = stream.expect_gready(Eof).await.unwrap();
            assert_whole_input_matched(found);
        });

        block_on(async {
            let mut stream = cursor_stream(b"");
            let err = stream.expect_gready("").await.unwrap_err();
            assert!(matches!(err, Error::Eof));
        });
    }

    #[test]
    fn test_expect_gready_timeout() {
        block_on(async {
            let mut stream = short_timeout_stream();

            stream.write_all(b"Hello").await.unwrap();
            let err = stream.expect_gready("Hello World").await.unwrap_err();
            assert!(matches!(err, Error::ExpectTimeout));

            // Data read before the timeout must still be available afterwards.
            stream.write_all(b" World").await.unwrap();
            let found = stream.expect_gready("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    #[test]
    fn test_check() {
        block_on(async {
            let mut stream = cursor_stream(b"Hello World");
            let found = stream.check("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    #[test]
    fn test_is_matched() {
        block_on(async {
            let mut stream = Stream::new(NoEofReader::default());
            stream.write_all(b"Hello World").await.unwrap();

            assert!(stream.is_matched("World").await.unwrap());
            assert!(!stream.is_matched("*****").await.unwrap());

            // `is_matched` must not have consumed anything.
            let found = stream.check("World").await.unwrap();
            assert_world_after_hello(found);
        });
    }

    /// In-memory duplex stream: written bytes can be read back, and reading
    /// with nothing stored returns `Poll::Pending` instead of signalling EOF.
    #[derive(Debug, Default)]
    struct NoEofReader {
        data: Vec<u8>,
    }

    impl AsyncWrite for NoEofReader {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = buf.len();
            self.data.extend_from_slice(buf);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for NoEofReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            if self.data.is_empty() {
                return Poll::Pending;
            }
            // Hand over as many stored bytes as fit into `buf`.
            let n = buf.len().min(self.data.len());
            buf[..n].copy_from_slice(&self.data[..n]);
            drop(self.data.drain(..n));
            Poll::Ready(Ok(n))
        }
    }
}