noxious/socket.rs

1use async_trait::async_trait;
2#[cfg(test)]
3use mockall::automock;
4use pin_project_lite::pin_project;
5use std::{io, net::SocketAddr};
6use tokio::{
7    io::{AsyncRead, AsyncWrite},
8    net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream},
9};
10
11#[cfg(not(test))]
12use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
13
14/// The TcpListener interface we need to mock
15#[cfg_attr(test, automock(type Stream=TcpStream;))]
16#[async_trait]
17pub trait SocketListener: Sized + Send + Sync {
18    /// The associated listener interface to be mocked too
19    type Stream: SocketStream + 'static;
20
21    /// Creates a new SocketListener, which will be bound to the specified address.
22    async fn bind(addr: &str) -> io::Result<Self>
23    where
24        Self: Sized;
25
26    /// Accepts a new incoming connection from this listener.
27    async fn accept(&self) -> io::Result<(Self::Stream, SocketAddr)>;
28}
29
30/// The TcpStream interface we need to mock
31#[cfg_attr(test, automock)]
32#[async_trait]
33pub trait SocketStream: Sized + Send + Sync {
34    /// Opens a TCP connection to a remote host.
35    async fn connect(addr: &str) -> io::Result<Self>
36    where
37        Self: Sized + 'static;
38
39    /// Splits the inner `TcpStream` into a read half and a write half, which
40    /// can be used to read and write the stream concurrently.
41    fn into_split(self) -> (ReadStream, WriteStream);
42}
43
44/// A simple wrapper around Tokio TcpListener to make it mockable
45#[derive(Debug)]
46pub struct TcpListener {
47    inner: TokioTcpListener,
48}
49
50/// A simple wrapper around Tokio TcpStream to make it mockable
51#[derive(Debug)]
52pub struct TcpStream {
53    inner: TokioTcpStream,
54}
55
56#[async_trait]
57impl SocketListener for TcpListener {
58    type Stream = TcpStream;
59
60    async fn bind(addr: &str) -> io::Result<TcpListener>
61    where
62        Self: Sized,
63    {
64        Ok(TcpListener {
65            inner: TokioTcpListener::bind(addr).await?,
66        })
67    }
68
69    async fn accept(&self) -> io::Result<(Self::Stream, SocketAddr)> {
70        let (stream, addr) = self.inner.accept().await?;
71        let wrapper = TcpStream { inner: stream };
72        Ok((wrapper, addr))
73    }
74}
75#[async_trait]
76impl SocketStream for TcpStream {
77    async fn connect(addr: &str) -> io::Result<Self>
78    where
79        Self: Sized,
80    {
81        let inner = TokioTcpStream::connect(addr).await?;
82        Ok(TcpStream { inner })
83    }
84
85    #[cfg(not(test))]
86    fn into_split(self) -> (ReadStream, WriteStream) {
87        let (read_half, write_half) = self.inner.into_split();
88        (ReadStream::new(read_half), WriteStream::new(write_half))
89    }
90
91    #[cfg(test)]
92    fn into_split(self) -> (ReadStream, WriteStream) {
93        unimplemented!("must mock")
94    }
95}
96
97#[cfg(not(test))]
98type ReadHalf = OwnedReadHalf;
99#[cfg(test)]
100type ReadHalf = tokio_test::io::Mock;
101
102#[cfg(not(test))]
103type WriteHalf = OwnedWriteHalf;
104#[cfg(test)]
105type WriteHalf = tokio_test::io::Mock;
106
107pin_project! {
108    /// Wrapper for OwnedReadHalf for mocking
109    #[derive(Debug)]
110    pub struct ReadStream {
111        #[pin]
112        inner: ReadHalf,
113    }
114}
115
116pin_project! {
117    /// Wrapper for OwnedWriteHalf for mocking
118    #[derive(Debug)]
119    pub struct WriteStream {
120        #[pin]
121        inner: WriteHalf,
122
123    }
124}
125
126#[cfg_attr(test, automock)]
127impl ReadStream {
128    pub(crate) fn new(inner: ReadHalf) -> ReadStream {
129        ReadStream { inner }
130    }
131}
132
133#[cfg_attr(test, automock)]
134impl WriteStream {
135    pub(crate) fn new(inner: WriteHalf) -> WriteStream {
136        WriteStream { inner }
137    }
138}
139
140impl AsyncRead for ReadStream {
141    fn poll_read(
142        self: std::pin::Pin<&mut Self>,
143        cx: &mut std::task::Context<'_>,
144        buf: &mut tokio::io::ReadBuf<'_>,
145    ) -> std::task::Poll<io::Result<()>> {
146        self.project().inner.poll_read(cx, buf)
147    }
148}
149
150impl AsyncWrite for WriteStream {
151    fn poll_write(
152        self: std::pin::Pin<&mut Self>,
153        cx: &mut std::task::Context<'_>,
154        buf: &[u8],
155    ) -> std::task::Poll<Result<usize, io::Error>> {
156        self.project().inner.poll_write(cx, buf)
157    }
158
159    fn poll_flush(
160        self: std::pin::Pin<&mut Self>,
161        cx: &mut std::task::Context<'_>,
162    ) -> std::task::Poll<Result<(), io::Error>> {
163        self.project().inner.poll_flush(cx)
164    }
165
166    fn poll_shutdown(
167        self: std::pin::Pin<&mut Self>,
168        cx: &mut std::task::Context<'_>,
169    ) -> std::task::Poll<Result<(), io::Error>> {
170        self.project().inner.poll_shutdown(cx)
171    }
172}
173
#[cfg(test)]
mod tests {
    use tokio_test::assert_ok;

    use super::*;

    /// Round-trips a real TCP connection through the wrappers (`bind`,
    /// `accept`, `connect`).
    ///
    /// `into_split` is not exercised: under `cfg(test)` it is stubbed with
    /// `unimplemented!` and would panic.
    #[tokio::test]
    async fn test_tcp_stream() {
        // Bind before spawning, so the listener is ready before `connect` runs
        // and no readiness channel is needed. Port 0 lets the OS choose a free
        // port instead of a hard-coded one that may already be taken.
        let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
        let addr = assert_ok!(listener.inner.local_addr());

        let server = tokio::spawn(async move { listener.accept().await });

        let _stream = assert_ok!(TcpStream::connect(&addr.to_string()).await);

        // Join the server task so that a failed `accept`, or a panic inside the
        // task, fails this test instead of being silently dropped.
        let (_accepted, peer) = assert_ok!(assert_ok!(server.await));
        assert!(peer.ip().is_loopback());
    }
}