1use async_trait::async_trait;
2#[cfg(test)]
3use mockall::automock;
4use pin_project_lite::pin_project;
5use std::{io, net::SocketAddr};
6use tokio::{
7 io::{AsyncRead, AsyncWrite},
8 net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream},
9};
10
11#[cfg(not(test))]
12use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
13
/// Abstraction over a listening socket that hands out [`SocketListener::Stream`]
/// connections.
///
/// In test builds mockall's `automock` generates a mock implementation of this
/// trait, with the associated `Stream` type fixed to [`TcpStream`].
#[cfg_attr(test, automock(type Stream=TcpStream;))]
#[async_trait]
pub trait SocketListener: Sized + Send + Sync {
    /// Connection type yielded by [`SocketListener::accept`].
    type Stream: SocketStream + 'static;

    /// Creates a listener bound to `addr`.
    ///
    /// # Errors
    /// Returns the I/O error produced while binding, if any.
    async fn bind(addr: &str) -> io::Result<Self>
    where
        Self: Sized;

    /// Waits for the next inbound connection.
    ///
    /// Returns the connection together with a `SocketAddr` (for the
    /// [`TcpListener`] implementation this is the remote peer's address).
    async fn accept(&self) -> io::Result<(Self::Stream, SocketAddr)>;
}
29
/// Abstraction over a connected, bidirectional socket.
///
/// In test builds mockall's `automock` generates a mock implementation of this
/// trait.
#[cfg_attr(test, automock)]
#[async_trait]
pub trait SocketStream: Sized + Send + Sync {
    /// Opens a new connection to `addr`.
    ///
    /// # Errors
    /// Returns the I/O error produced while connecting, if any.
    async fn connect(addr: &str) -> io::Result<Self>
    where
        Self: Sized + 'static;

    /// Consumes the stream and returns independently owned read and write halves.
    fn into_split(self) -> (ReadStream, WriteStream);
}
43
44#[derive(Debug)]
46pub struct TcpListener {
47 inner: TokioTcpListener,
48}
49
50#[derive(Debug)]
52pub struct TcpStream {
53 inner: TokioTcpStream,
54}
55
56#[async_trait]
57impl SocketListener for TcpListener {
58 type Stream = TcpStream;
59
60 async fn bind(addr: &str) -> io::Result<TcpListener>
61 where
62 Self: Sized,
63 {
64 Ok(TcpListener {
65 inner: TokioTcpListener::bind(addr).await?,
66 })
67 }
68
69 async fn accept(&self) -> io::Result<(Self::Stream, SocketAddr)> {
70 let (stream, addr) = self.inner.accept().await?;
71 let wrapper = TcpStream { inner: stream };
72 Ok((wrapper, addr))
73 }
74}
75#[async_trait]
76impl SocketStream for TcpStream {
77 async fn connect(addr: &str) -> io::Result<Self>
78 where
79 Self: Sized,
80 {
81 let inner = TokioTcpStream::connect(addr).await?;
82 Ok(TcpStream { inner })
83 }
84
85 #[cfg(not(test))]
86 fn into_split(self) -> (ReadStream, WriteStream) {
87 let (read_half, write_half) = self.inner.into_split();
88 (ReadStream::new(read_half), WriteStream::new(write_half))
89 }
90
91 #[cfg(test)]
92 fn into_split(self) -> (ReadStream, WriteStream) {
93 unimplemented!("must mock")
94 }
95}
96
97#[cfg(not(test))]
98type ReadHalf = OwnedReadHalf;
99#[cfg(test)]
100type ReadHalf = tokio_test::io::Mock;
101
102#[cfg(not(test))]
103type WriteHalf = OwnedWriteHalf;
104#[cfg(test)]
105type WriteHalf = tokio_test::io::Mock;
106
pin_project! {
    /// Readable half of a split socket stream.
    ///
    /// Implements `AsyncRead` by forwarding to the pinned inner half.
    #[derive(Debug)]
    pub struct ReadStream {
        // Structurally pinned so `poll_read` can be forwarded through a projection.
        #[pin]
        inner: ReadHalf,
    }
}

pin_project! {
    /// Writable half of a split socket stream.
    ///
    /// Implements `AsyncWrite` by forwarding to the pinned inner half.
    #[derive(Debug)]
    pub struct WriteStream {
        // Structurally pinned so the `poll_*` write methods can be forwarded.
        #[pin]
        inner: WriteHalf,
    }
}
125
// NOTE: `automock` on these inherent impls makes mockall generate mock types
// whose constructors mirror these signatures; keep the explicit return types.
#[cfg_attr(test, automock)]
impl ReadStream {
    /// Wraps an owned read half (a `tokio_test` mock in test builds).
    pub(crate) fn new(inner: ReadHalf) -> ReadStream {
        ReadStream { inner }
    }
}

#[cfg_attr(test, automock)]
impl WriteStream {
    /// Wraps an owned write half (a `tokio_test` mock in test builds).
    pub(crate) fn new(inner: WriteHalf) -> WriteStream {
        WriteStream { inner }
    }
}
139
140impl AsyncRead for ReadStream {
141 fn poll_read(
142 self: std::pin::Pin<&mut Self>,
143 cx: &mut std::task::Context<'_>,
144 buf: &mut tokio::io::ReadBuf<'_>,
145 ) -> std::task::Poll<io::Result<()>> {
146 self.project().inner.poll_read(cx, buf)
147 }
148}
149
150impl AsyncWrite for WriteStream {
151 fn poll_write(
152 self: std::pin::Pin<&mut Self>,
153 cx: &mut std::task::Context<'_>,
154 buf: &[u8],
155 ) -> std::task::Poll<Result<usize, io::Error>> {
156 self.project().inner.poll_write(cx, buf)
157 }
158
159 fn poll_flush(
160 self: std::pin::Pin<&mut Self>,
161 cx: &mut std::task::Context<'_>,
162 ) -> std::task::Poll<Result<(), io::Error>> {
163 self.project().inner.poll_flush(cx)
164 }
165
166 fn poll_shutdown(
167 self: std::pin::Pin<&mut Self>,
168 cx: &mut std::task::Context<'_>,
169 ) -> std::task::Poll<Result<(), io::Error>> {
170 self.project().inner.poll_shutdown(cx)
171 }
172}
173
#[cfg(test)]
mod tests {
    use tokio_test::assert_ok;

    use super::*;

    /// End-to-end check that a `TcpStream` can connect to a `TcpListener` and
    /// that the listener reports the client's address as the peer.
    #[tokio::test]
    async fn test_tcp_stream() {
        // Port 0 lets the OS choose a free port, so the test cannot collide with
        // other processes or parallel test runs the way a fixed port could.
        let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
        let addr = assert_ok!(listener.inner.local_addr());

        // The listener is already bound, so no readiness handshake is needed
        // before connecting; the server task only has to accept.
        let server = tokio::spawn(async move {
            let (_stream, peer) = listener.accept().await?;
            Ok::<_, io::Error>(peer)
        });

        let stream = assert_ok!(TcpStream::connect(&addr.to_string()).await);

        // Await the server so accept failures or panics fail this test instead
        // of being silently dropped with the JoinHandle.
        let peer = assert_ok!(assert_ok!(server.await));
        assert_eq!(peer, assert_ok!(stream.inner.local_addr()));
    }
}