trillium-testing 0.10.0

testing library for trillium applications
Documentation
use async_dup::Arc;
use futures_lite::{AsyncRead, AsyncWrite};
use std::{
    fmt::{Debug, Display},
    future::Future,
    io,
    net::{IpAddr, Shutdown, SocketAddr},
    pin::Pin,
    sync::RwLock,
    task::{Context, Poll, Waker},
};
use trillium::TypeSet;
use trillium_macros::{AsyncRead, AsyncWrite};

/// a readable and writable transport for testing
///
/// each transport holds two shared [`CloseableCursor`] buffers: one it reads from and one it
/// writes to. [`TestTransport::new`] builds a pair whose buffers are crossed, so that bytes
/// written to one transport become readable from the other.
#[derive(Default, Clone, Debug, AsyncRead, AsyncWrite, fieldwork::Fieldwork)]
pub struct TestTransport {
    /// the read side of this transport: the buffer that this transport's reads are served from
    #[async_read]
    #[field(get = read_side)]
    read: Arc<CloseableCursor>,

    /// the write side of this transport: the buffer that this transport's writes are appended to
    #[async_write]
    #[field(get = write_side)]
    write: Arc<CloseableCursor>,

    /// State that can be shared with the other side of this transport
    #[field(vis = "pub(crate)", get)]
    state: Arc<RwLock<TypeSet>>,

    /// optional peer ip for this transport; when set, `peer_addr` reports it with port 0
    #[field(get, set, option_set_some)]
    peer_ip: Option<IpAddr>,
}

impl trillium::Transport for TestTransport {
    /// reports the configured peer ip (if any) as a socket address.
    ///
    /// a test transport has no real port, so the port is always `0`. this never returns an
    /// error; `Ok(None)` means that no peer ip has been set.
    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
        Ok(self.peer_ip.map(|ip| SocketAddr::new(ip, 0)))
    }
}

impl TestTransport {
    /// constructs a new test transport pair, representing two ends of
    /// a connection. either of them can be written to, and the
    /// content will be readable from the other. either of them can
    /// also be closed.
    pub fn new() -> (TestTransport, TestTransport) {
        let a = Arc::new(CloseableCursor::default());
        let b = Arc::new(CloseableCursor::default());
        let state: Arc<RwLock<TypeSet>> = Default::default();

        (
            TestTransport {
                read: a.clone(),
                write: b.clone(),
                state: state.clone(),
                peer_ip: None,
            },
            TestTransport {
                read: b,
                write: a,
                state,
                peer_ip: None,
            },
        )
    }

    /// close this transport, representing a disconnection
    pub fn close(&mut self) {
        self.read.close();
        self.write.close();
    }

    /// Shuts down the read, write, or both halves of this connection.
    // This function will cause all pending and future I/O on the specified portions to return
    // immediately with an appropriate value (see the documentation of Shutdown).
    pub fn shutdown(&self, how: Shutdown) {
        match how {
            Shutdown::Read => self.read.close(),
            Shutdown::Write => self.write.close(),
            Shutdown::Both => {
                self.read.close();
                self.write.close();
            }
        }
    }

    /// take an owned snapshot of the received data
    pub fn snapshot(&self) -> Vec<u8> {
        self.read.snapshot()
    }

    /// synchronously append the supplied bytes to the write side of this transport, notifying the
    /// read side of the other end
    pub fn write_all(&self, bytes: impl AsRef<[u8]>) {
        io::Write::write_all(&mut &*self.write, bytes.as_ref()).unwrap();
    }

    /// waits until there is content and then reads that content to a vec until there is no
    /// further content immediately available
    pub async fn read_available(&self) -> Vec<u8> {
        self.read.read_available().await
    }

    /// waits until there is content and then reads that content to a string until there is no
    /// further content immediately available
    pub async fn read_available_string(&self) -> String {
        self.read.read_available_string().await
    }
}

impl Drop for TestTransport {
    fn drop(&mut self) {
        self.close();
    }
}

/// the lock-protected state behind a [`CloseableCursor`]
#[derive(Default)]
struct CloseableCursorInner {
    /// every byte written so far; reads advance `cursor` rather than removing bytes
    data: Vec<u8>,
    /// index into `data` of the next byte to be read
    cursor: usize,
    /// waker stored by the most recent read that had to wait; taken and woken on write or close
    waker: Option<Waker>,
    /// set once the cursor has been closed
    closed: bool,
}

/// an in-memory byte buffer with a read position that can be written to, read from, and closed
/// through a shared reference
///
/// all state lives behind an `RwLock`, and the read and write traits are implemented for
/// `&CloseableCursor` as well as for the owned type.
#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);

impl CloseableCursor {
    /// the total number of bytes written so far, whether or not they have been read
    pub fn len(&self) -> usize {
        let guard = self.0.read().unwrap();
        guard.data.len()
    }

    /// the current read position, as an offset from the start of the content
    pub fn cursor(&self) -> usize {
        let guard = self.0.read().unwrap();
        guard.cursor
    }

    /// true when nothing has been written yet
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// take an owned copy of all of the data, including bytes that were already read
    pub fn snapshot(&self) -> Vec<u8> {
        let guard = self.0.read().unwrap();
        guard.data.clone()
    }

    /// have we read to the end of the available content
    pub fn current(&self) -> bool {
        let guard = self.0.read().unwrap();
        guard.cursor == guard.data.len()
    }

    /// close this cursor, waking any pending polls
    ///
    /// the stored waker, if there is one, is removed and woken so that a waiting reader can
    /// observe the closed state.
    pub fn close(&self) {
        let mut guard = self.0.write().unwrap();
        guard.closed = true;
        let parked_reader = guard.waker.take();
        if let Some(reader) = parked_reader {
            reader.wake();
        }
    }

    /// read any available bytes
    ///
    /// waits until at least one byte is available (or the cursor is closed), then returns
    /// everything that can be read without waiting again.
    pub async fn read_available(&self) -> Vec<u8> {
        let outcome = ReadAvailable(self).await;
        outcome.unwrap()
    }

    /// read any available bytes as a string
    ///
    /// # Panics
    ///
    /// panics if the bytes are not valid utf8.
    pub async fn read_available_string(&self) -> String {
        let bytes = self.read_available().await;
        String::from_utf8(bytes).unwrap()
    }
}

/// a future that resolves to whatever the wrapped reader can provide without waiting, once at
/// least one byte has been read or the reader reports end-of-file
struct ReadAvailable<T>(T);

impl<T: AsyncRead + Unpin> Future for ReadAvailable<T> {
    type Output = io::Result<Vec<u8>>;

    /// polls the reader repeatedly, collecting bytes until it reports end-of-file or would
    /// have to wait.
    ///
    /// if the reader is pending before any byte has been read, this future is pending too.
    /// once at least one byte has been collected, a pending reader ends the read instead. a
    /// read error is returned as-is, discarding anything collected during this poll.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut collected: Vec<u8> = Vec::new();
        let mut filled = 0;

        loop {
            // when the buffer is completely filled, grow it and expose the whole capacity as
            // zeroed space for the next read
            if filled == collected.len() {
                collected.reserve(32);
                let capacity = collected.capacity();
                collected.resize(capacity, 0);
            }

            let outcome = Pin::new(&mut self.0).poll_read(cx, &mut collected[filled..]);
            match outcome {
                Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
                Poll::Pending if filled == 0 => return Poll::Pending,
                // end-of-file, or nothing more available right now after a partial read
                Poll::Ready(Ok(0)) | Poll::Pending => break,
                Poll::Ready(Ok(count)) => filled += count,
            }
        }

        // drop the zeroed padding past the bytes that were actually read
        collected.truncate(filled);
        Poll::Ready(Ok(collected))
    }
}

impl Display for CloseableCursor {
    /// writes the entire content as text, replacing invalid utf8 sequences with the unicode
    /// replacement character.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = self.0.read().unwrap();
        let text = String::from_utf8_lossy(&guard.data);
        f.write_str(&text)
    }
}

impl Debug for CloseableCursor {
    /// formats the cursor as a struct with its content, closed flag, and read position.
    ///
    /// content that is not valid utf8 is shown as the placeholder `"not utf8"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = self.0.read().unwrap();
        let content = std::str::from_utf8(&guard.data).unwrap_or("not utf8");

        let mut output = f.debug_struct("CloseableCursor");
        output.field("data", &content);
        output.field("closed", &guard.closed);
        output.field("cursor", &guard.cursor);
        output.finish()
    }
}

impl AsyncRead for CloseableCursor {
    /// reads through the implementation for `&CloseableCursor`, which does the actual work.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut shared: &CloseableCursor = &*self;
        <&CloseableCursor as AsyncRead>::poll_read(Pin::new(&mut shared), cx, buf)
    }
}

impl AsyncRead for &CloseableCursor {
    /// copies as many unread bytes as fit into `buf` and advances the read position.
    ///
    /// when there is nothing left to read, a closed cursor reports end-of-file (`Ok(0)`). an
    /// open cursor stores the current task's waker and returns `Poll::Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut guard = self.0.write().unwrap();
        let start = guard.cursor;
        let end_of_data = guard.data.len();

        if start >= end_of_data {
            if guard.closed {
                return Poll::Ready(Ok(0));
            }

            // remember who to wake when more data is written or the cursor is closed
            guard.waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        let count = buf.len().min(end_of_data - start);
        let stop = start + count;
        buf[..count].copy_from_slice(&guard.data[start..stop]);
        guard.cursor = stop;
        Poll::Ready(Ok(count))
    }
}

impl AsyncWrite for &CloseableCursor {
    /// appends `buf` to the content and wakes a waiting reader, if there is one.
    ///
    /// a closed cursor accepts nothing and reports `Ok(0)`. otherwise the whole of `buf` is
    /// always accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut guard = self.0.write().unwrap();
        if guard.closed {
            return Poll::Ready(Ok(0));
        }

        guard.data.extend_from_slice(buf);
        if let Some(reader) = guard.waker.take() {
            reader.wake();
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// writes land directly in memory, so there is never anything to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// closes the cursor, waking a waiting reader, and completes immediately.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let cursor: &CloseableCursor = *self;
        cursor.close();
        Poll::Ready(Ok(()))
    }
}

impl io::Write for CloseableCursor {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        io::Write::write(&mut &*self, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl io::Write for &CloseableCursor {
    /// synchronously appends `buf` to the content and wakes a waiting reader, if there is one.
    ///
    /// a closed cursor accepts nothing and reports `Ok(0)`. otherwise the whole of `buf` is
    /// always accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut guard = self.0.write().unwrap();
        if guard.closed {
            return Ok(0);
        }

        guard.data.extend_from_slice(buf);
        if let Some(reader) = guard.waker.take() {
            reader.wake();
        }
        Ok(buf.len())
    }

    /// writes land directly in memory, so there is never anything to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}