Skip to main content

irontide_session/
transport.rs

1//! Network transport abstraction layer.
2//!
3//! Provides [`NetworkFactory`] — a factory for creating TCP listeners and
4//! connections using either real tokio sockets (production) or pluggable
5//! in-memory channels (testing/simulation).
6//!
7//! The key abstraction is [`TransportListener`], an object-safe trait for
8//! accepting inbound connections, and [`BoxedStream`], a type-erased
9//! async read/write stream.
10
11use std::future::Future;
12use std::io;
13use std::net::SocketAddr;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use tokio::net::TcpListener;
19use tokio::net::TcpStream;
20
21// ---------------------------------------------------------------------------
22// Type aliases — tame clippy::type_complexity
23// ---------------------------------------------------------------------------
24
25/// Boxed future returned by [`TransportListener::accept`].
26type AcceptFuture<'a> =
27    Pin<Box<dyn Future<Output = io::Result<(BoxedStream, SocketAddr)>> + Send + 'a>>;
28
29/// Closure type for [`NetworkFactory`]'s bind operation.
30type BindFn = Box<
31    dyn Fn(
32            SocketAddr,
33        ) -> Pin<Box<dyn Future<Output = io::Result<Box<dyn TransportListener>>> + Send>>
34        + Send
35        + Sync,
36>;
37
38/// Closure type for [`NetworkFactory`]'s connect operation.
39type ConnectFn = Box<
40    dyn Fn(SocketAddr) -> Pin<Box<dyn Future<Output = io::Result<BoxedStream>> + Send>>
41        + Send
42        + Sync,
43>;
44
45// ---------------------------------------------------------------------------
46// BoxedStream
47// ---------------------------------------------------------------------------
48
49/// A type-erased bidirectional async stream.
50///
51/// Wraps any `AsyncRead + AsyncWrite + Unpin + Send` type behind a single
52/// trait object. This avoids the Rust limitation that `dyn` can only name
53/// one non-auto trait.
54pub struct BoxedStream {
55    inner: Pin<Box<dyn StreamRw + Send>>,
56}
57
58impl std::fmt::Debug for BoxedStream {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("BoxedStream").finish_non_exhaustive()
61    }
62}
63
64/// Combined read/write supertrait for dyn compatibility.
65trait StreamRw: AsyncRead + AsyncWrite + Unpin {}
66impl<T: AsyncRead + AsyncWrite + Unpin> StreamRw for T {}
67
68impl BoxedStream {
69    /// Wrap any async read/write stream into a [`BoxedStream`].
70    pub fn new<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(stream: S) -> Self {
71        Self {
72            inner: Box::pin(stream),
73        }
74    }
75}
76
77impl AsyncRead for BoxedStream {
78    fn poll_read(
79        mut self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81        buf: &mut ReadBuf<'_>,
82    ) -> Poll<io::Result<()>> {
83        self.inner.as_mut().poll_read(cx, buf)
84    }
85}
86
87impl AsyncWrite for BoxedStream {
88    fn poll_write(
89        mut self: Pin<&mut Self>,
90        cx: &mut Context<'_>,
91        buf: &[u8],
92    ) -> Poll<io::Result<usize>> {
93        self.inner.as_mut().poll_write(cx, buf)
94    }
95
96    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97        self.inner.as_mut().poll_flush(cx)
98    }
99
100    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101        self.inner.as_mut().poll_shutdown(cx)
102    }
103}
104
105impl Unpin for BoxedStream {}
106
107// ---------------------------------------------------------------------------
108// TransportListener
109// ---------------------------------------------------------------------------
110
111/// An object-safe listener that accepts inbound connections.
112///
113/// Implemented by [`TokioListener`] for real TCP sockets; simulation backends
114/// provide their own implementation backed by in-memory channels.
115///
116/// The `accept` method returns a boxed future for dyn compatibility.
117pub trait TransportListener: Send + Sync {
118    /// Accept the next inbound connection.
119    fn accept(&mut self) -> AcceptFuture<'_>;
120
121    /// Return the local address this listener is bound to.
122    fn local_addr(&self) -> io::Result<SocketAddr>;
123}
124
125// ---------------------------------------------------------------------------
126// TokioListener
127// ---------------------------------------------------------------------------
128
129/// A [`TransportListener`] backed by a real [`tokio::net::TcpListener`].
130pub struct TokioListener(pub TcpListener);
131
132impl TransportListener for TokioListener {
133    fn accept(&mut self) -> AcceptFuture<'_> {
134        Box::pin(async move {
135            let (stream, addr) = self.0.accept().await?;
136            // RST on close instead of FIN → skip TIME_WAIT (see connect_tcp).
137            #[allow(deprecated)]
138            let _ = stream.set_linger(Some(std::time::Duration::ZERO));
139            Ok((BoxedStream::new(stream), addr))
140        })
141    }
142
143    fn local_addr(&self) -> io::Result<SocketAddr> {
144        self.0.local_addr()
145    }
146}
147
148// ---------------------------------------------------------------------------
149// NetworkFactory
150// ---------------------------------------------------------------------------
151
152/// Factory for creating TCP listeners and outbound connections.
153///
154/// In production, use [`NetworkFactory::tokio()`] to get a factory that
155/// delegates to real tokio networking. For simulation/testing, construct
156/// via [`NetworkFactory::new()`] with custom closures that route through
157/// in-memory channels.
158pub struct NetworkFactory {
159    bind_tcp: BindFn,
160    connect_tcp: ConnectFn,
161    is_simulated: bool,
162}
163
164impl NetworkFactory {
165    /// Create a factory with custom bind/connect closures.
166    ///
167    /// This is the primary constructor for simulation backends.
168    pub fn new(bind_tcp: BindFn, connect_tcp: ConnectFn, is_simulated: bool) -> Self {
169        Self {
170            bind_tcp,
171            connect_tcp,
172            is_simulated,
173        }
174    }
175
176    /// Create a factory that uses real tokio TCP networking.
177    pub fn tokio() -> Self {
178        Self {
179            bind_tcp: Box::new(|addr| {
180                Box::pin(async move {
181                    let listener = TcpListener::bind(addr).await?;
182                    Ok(Box::new(TokioListener(listener)) as Box<dyn TransportListener>)
183                })
184            }),
185            connect_tcp: Box::new(|addr| {
186                Box::pin(async move {
187                    let stream = TcpStream::connect(addr).await?;
188                    // RST on close instead of FIN → skip TIME_WAIT.
189                    // Peer connections are ephemeral; TIME_WAIT accumulation
190                    // degrades performance across rapid reconnection cycles.
191                    // Safe: linger(0) sends RST immediately, never blocks.
192                    #[allow(deprecated)]
193                    let _ = stream.set_linger(Some(std::time::Duration::ZERO));
194                    Ok(BoxedStream::new(stream))
195                })
196            }),
197            is_simulated: false,
198        }
199    }
200
201    /// Bind a TCP listener on the given address.
202    pub async fn bind_tcp(&self, addr: SocketAddr) -> io::Result<Box<dyn TransportListener>> {
203        (self.bind_tcp)(addr).await
204    }
205
206    /// Open an outbound TCP connection to the given address.
207    pub async fn connect_tcp(&self, addr: SocketAddr) -> io::Result<BoxedStream> {
208        (self.connect_tcp)(addr).await
209    }
210
211    /// Returns `true` if this factory uses simulated networking.
212    pub fn is_simulated(&self) -> bool {
213        self.is_simulated
214    }
215}
216
217// ---------------------------------------------------------------------------
218// Tests
219// ---------------------------------------------------------------------------
220
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Loopback address with port 0, letting the OS choose a free port.
    fn loopback_any_port() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    #[test]
    fn tokio_factory_creation() {
        // Construction only boxes closures; it runs in a plain `#[test]`.
        drop(NetworkFactory::tokio());
    }

    #[test]
    fn tokio_factory_is_not_simulated() {
        let simulated = NetworkFactory::tokio().is_simulated();
        assert!(!simulated);
    }

    #[tokio::test]
    async fn tokio_bind_and_accept() {
        let net = NetworkFactory::tokio();
        let bound = net.bind_tcp(loopback_any_port()).await.unwrap();

        // Binding to port 0 must report the concrete port that was assigned.
        let port = bound.local_addr().unwrap().port();
        assert_ne!(port, 0);
    }

    #[tokio::test]
    async fn tokio_connect_to_listener() {
        let net = NetworkFactory::tokio();
        let mut server = net.bind_tcp(loopback_any_port()).await.unwrap();
        let server_addr = server.local_addr().unwrap();

        // Accept on a spawned task while this task connects and writes.
        let accepted = tokio::spawn(async move { server.accept().await.unwrap() });

        let mut outbound = net.connect_tcp(server_addr).await.unwrap();
        outbound.write_all(b"hello").await.unwrap();

        let (mut inbound, peer) = accepted.await.unwrap();
        let loopback: std::net::IpAddr = "127.0.0.1".parse().unwrap();
        assert_eq!(peer.ip(), loopback);

        // The payload written by the client must arrive intact on the server side.
        let mut received = [0u8; 5];
        inbound.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");
    }

    #[test]
    fn custom_factory_is_simulated() {
        // Both closures are stubs that always fail; only the flag is under test.
        let net = NetworkFactory::new(
            Box::new(|_addr| {
                Box::pin(async move { Err(io::Error::new(io::ErrorKind::Unsupported, "stub")) })
            }),
            Box::new(|_addr| {
                Box::pin(async move { Err(io::Error::new(io::ErrorKind::Unsupported, "stub")) })
            }),
            true,
        );
        assert!(net.is_simulated());
    }
}