// irontide_session/transport.rs

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpListener;
use tokio::net::TcpStream;

21type AcceptFuture<'a> =
27 Pin<Box<dyn Future<Output = io::Result<(BoxedStream, SocketAddr)>> + Send + 'a>>;
28
29type BindFn = Box<
31 dyn Fn(
32 SocketAddr,
33 ) -> Pin<Box<dyn Future<Output = io::Result<Box<dyn TransportListener>>> + Send>>
34 + Send
35 + Sync,
36>;
37
38type ConnectFn = Box<
40 dyn Fn(SocketAddr) -> Pin<Box<dyn Future<Output = io::Result<BoxedStream>> + Send>>
41 + Send
42 + Sync,
43>;
44
45pub struct BoxedStream {
55 inner: Pin<Box<dyn StreamRw + Send>>,
56}
57
58impl std::fmt::Debug for BoxedStream {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("BoxedStream").finish_non_exhaustive()
61 }
62}
63
64trait StreamRw: AsyncRead + AsyncWrite + Unpin {}
66impl<T: AsyncRead + AsyncWrite + Unpin> StreamRw for T {}
67
68impl BoxedStream {
69 pub fn new<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(stream: S) -> Self {
71 Self {
72 inner: Box::pin(stream),
73 }
74 }
75}
76
77impl AsyncRead for BoxedStream {
78 fn poll_read(
79 mut self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 buf: &mut ReadBuf<'_>,
82 ) -> Poll<io::Result<()>> {
83 self.inner.as_mut().poll_read(cx, buf)
84 }
85}
86
87impl AsyncWrite for BoxedStream {
88 fn poll_write(
89 mut self: Pin<&mut Self>,
90 cx: &mut Context<'_>,
91 buf: &[u8],
92 ) -> Poll<io::Result<usize>> {
93 self.inner.as_mut().poll_write(cx, buf)
94 }
95
96 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97 self.inner.as_mut().poll_flush(cx)
98 }
99
100 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101 self.inner.as_mut().poll_shutdown(cx)
102 }
103}
104
105impl Unpin for BoxedStream {}
106
107pub trait TransportListener: Send + Sync {
118 fn accept(&mut self) -> AcceptFuture<'_>;
120
121 fn local_addr(&self) -> io::Result<SocketAddr>;
123}
124
125pub struct TokioListener(pub TcpListener);
131
132impl TransportListener for TokioListener {
133 fn accept(&mut self) -> AcceptFuture<'_> {
134 Box::pin(async move {
135 let (stream, addr) = self.0.accept().await?;
136 #[allow(deprecated)]
138 let _ = stream.set_linger(Some(std::time::Duration::ZERO));
139 Ok((BoxedStream::new(stream), addr))
140 })
141 }
142
143 fn local_addr(&self) -> io::Result<SocketAddr> {
144 self.0.local_addr()
145 }
146}
147
148pub struct NetworkFactory {
159 bind_tcp: BindFn,
160 connect_tcp: ConnectFn,
161 is_simulated: bool,
162}
163
164impl NetworkFactory {
165 pub fn new(bind_tcp: BindFn, connect_tcp: ConnectFn, is_simulated: bool) -> Self {
169 Self {
170 bind_tcp,
171 connect_tcp,
172 is_simulated,
173 }
174 }
175
176 pub fn tokio() -> Self {
178 Self {
179 bind_tcp: Box::new(|addr| {
180 Box::pin(async move {
181 let listener = TcpListener::bind(addr).await?;
182 Ok(Box::new(TokioListener(listener)) as Box<dyn TransportListener>)
183 })
184 }),
185 connect_tcp: Box::new(|addr| {
186 Box::pin(async move {
187 let stream = TcpStream::connect(addr).await?;
188 #[allow(deprecated)]
193 let _ = stream.set_linger(Some(std::time::Duration::ZERO));
194 Ok(BoxedStream::new(stream))
195 })
196 }),
197 is_simulated: false,
198 }
199 }
200
201 pub async fn bind_tcp(&self, addr: SocketAddr) -> io::Result<Box<dyn TransportListener>> {
203 (self.bind_tcp)(addr).await
204 }
205
206 pub async fn connect_tcp(&self, addr: SocketAddr) -> io::Result<BoxedStream> {
208 (self.connect_tcp)(addr).await
209 }
210
211 pub fn is_simulated(&self) -> bool {
213 self.is_simulated
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Loopback address with port 0, so the OS assigns a free port.
    fn ephemeral_loopback() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 0))
    }

    /// Error returned by the stub bind/connect closures.
    fn unsupported() -> io::Error {
        io::Error::new(io::ErrorKind::Unsupported, "stub")
    }

    #[test]
    fn tokio_factory_creation() {
        // Plain `#[test]`: constructing the Tokio factory needs no runtime.
        let _factory = NetworkFactory::tokio();
    }

    #[test]
    fn tokio_factory_is_not_simulated() {
        assert!(!NetworkFactory::tokio().is_simulated());
    }

    #[tokio::test]
    async fn tokio_bind_and_accept() {
        // Binding to port 0 must yield a concrete, OS-assigned port.
        let listener = NetworkFactory::tokio()
            .bind_tcp(ephemeral_loopback())
            .await
            .unwrap();
        assert_ne!(listener.local_addr().unwrap().port(), 0);
    }

    #[tokio::test]
    async fn tokio_connect_to_listener() {
        let factory = NetworkFactory::tokio();
        let mut listener = factory.bind_tcp(ephemeral_loopback()).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // Accept on a separate task while the client connects and writes.
        let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() });

        let mut client = factory.connect_tcp(server_addr).await.unwrap();
        client.write_all(b"hello").await.unwrap();

        let (mut server_side, peer) = accept_task.await.unwrap();
        assert_eq!(peer.ip(), std::net::IpAddr::from([127, 0, 0, 1]));

        let mut received = [0u8; 5];
        server_side.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");
    }

    #[test]
    fn custom_factory_is_simulated() {
        let stub = NetworkFactory::new(
            Box::new(|_addr| Box::pin(async move { Err(unsupported()) })),
            Box::new(|_addr| Box::pin(async move { Err(unsupported()) })),
            true,
        );
        assert!(stub.is_simulated());
    }
}