Skip to main content

msg_transport/
lib.rs

1#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![cfg_attr(not(test), warn(unused_crate_dependencies))]
4
5use std::{
6    fmt::Debug,
7    hash::Hash,
8    io::{self, IoSlice},
9    marker::PhantomData,
10    net::SocketAddr,
11    path::PathBuf,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15    time::{Duration, Instant},
16};
17
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use futures::{Future, FutureExt};
21use tokio::io::{AsyncRead, AsyncWrite};
22
23pub mod ipc;
24#[cfg(feature = "quic")]
25pub mod quic;
26pub mod tcp;
27#[cfg(feature = "tcp-tls")]
28pub mod tcp_tls;
29
/// Marker trait for any type a transport can use as an endpoint address.
///
/// All requirements are expressed as bounds on `Self`, so any `T: Address` can be cloned,
/// debug-printed, sent and shared across threads, moved out of a pin, hashed, compared for
/// equality, and held for `'static`.
pub trait Address
where
    Self: Clone + Debug + Send + Sync + Unpin + Hash + Eq + 'static,
{
}

/// Socket (IP + port) addresses, used by the TCP and QUIC transports.
impl Address for SocketAddr {}

/// File-system paths, used by the IPC transport.
impl Address for PathBuf {}
38
39/// A wrapper around an `Io` object that records and provides transport-specific metrics.
40/// The link with the transport-specific metrics is achieved by the `S` type parameter, which
41/// must implement the `TryFrom<&Io>` trait.
42pub struct MeteredIo<Io, S, A>
43where
44    Io: AsyncRead + AsyncWrite + PeerAddress<A>,
45    A: Address,
46{
47    /// The inner IO object.
48    inner: Io,
49    /// The sender for the stats.
50    stats: Arc<ArcSwap<S>>,
51    /// The next time the stats should be refreshed.
52    next_refresh: Instant,
53    /// The interval at which the stats should be refreshed.
54    refresh_interval: Duration,
55
56    _marker: PhantomData<A>,
57}
58
59impl<Io, S, A> AsyncRead for MeteredIo<Io, S, A>
60where
61    Io: AsyncRead + AsyncWrite + PeerAddress<A> + Unpin,
62    A: Address,
63    S: for<'a> TryFrom<&'a Io, Error: Debug>,
64{
65    fn poll_read(
66        self: Pin<&mut Self>,
67        cx: &mut Context<'_>,
68        buf: &mut tokio::io::ReadBuf<'_>,
69    ) -> Poll<io::Result<()>> {
70        let this = self.get_mut();
71
72        this.maybe_refresh();
73
74        Pin::new(&mut this.inner).poll_read(cx, buf)
75    }
76}
77
78impl<Io, S, A> AsyncWrite for MeteredIo<Io, S, A>
79where
80    Io: AsyncRead + AsyncWrite + PeerAddress<A> + Unpin,
81    A: Address,
82    S: for<'a> TryFrom<&'a Io, Error: Debug>,
83{
84    fn poll_write(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87        buf: &[u8],
88    ) -> Poll<io::Result<usize>> {
89        let this = self.get_mut();
90
91        this.maybe_refresh();
92
93        Pin::new(&mut this.inner).poll_write(cx, buf)
94    }
95
96    fn poll_write_vectored(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        bufs: &[IoSlice<'_>],
100    ) -> Poll<io::Result<usize>> {
101        let this = self.get_mut();
102
103        this.maybe_refresh();
104
105        Pin::new(&mut this.inner).poll_write_vectored(cx, bufs)
106    }
107
108    fn is_write_vectored(&self) -> bool {
109        self.inner.is_write_vectored()
110    }
111
112    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113        let this = self.get_mut();
114
115        this.maybe_refresh();
116
117        Pin::new(&mut this.inner).poll_flush(cx)
118    }
119
120    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
121        let this = self.get_mut();
122
123        this.maybe_refresh();
124
125        Pin::new(&mut this.inner).poll_shutdown(cx)
126    }
127}
128
129impl<Io, M, A> PeerAddress<A> for MeteredIo<Io, M, A>
130where
131    Io: AsyncRead + AsyncWrite + PeerAddress<A>,
132    A: Address,
133{
134    fn peer_addr(&self) -> Result<A, io::Error> {
135        self.inner.peer_addr()
136    }
137}
138
139impl<Io, S, A> MeteredIo<Io, S, A>
140where
141    Io: AsyncRead + AsyncWrite + PeerAddress<A>,
142    A: Address,
143    S: for<'a> TryFrom<&'a Io, Error: Debug>,
144{
145    /// Creates a new `MeteredIo` wrapper around the given `Io` object, and initializes default
146    /// stats. The `sender` is used to send the latest stats to the caller.
147    ///
148    /// TODO: Specify configuration options.
149    pub fn new(inner: Io, stats: Arc<ArcSwap<S>>) -> Self {
150        Self {
151            inner,
152            stats,
153            _marker: PhantomData,
154            next_refresh: Instant::now(),
155            refresh_interval: Duration::from_secs(2),
156        }
157    }
158
159    #[inline]
160    fn maybe_refresh(&mut self) {
161        let now = Instant::now();
162        if self.next_refresh <= now {
163            match S::try_from(&self.inner) {
164                Ok(stats) => {
165                    self.stats.store(Arc::new(stats));
166                }
167                Err(e) => tracing::error!(errror = ?e, "failed to gather transport stats"),
168            }
169
170            self.next_refresh = now + self.refresh_interval;
171        }
172    }
173}
174
175/// A transport provides connection-oriented communication between two peers through
176/// ordered and reliable streams of bytes.
177///
178/// It provides an interface to manage both inbound and outbound connections.
179#[async_trait]
180pub trait Transport<A: Address>: Send + Sync + Unpin + 'static {
181    /// The result of a successful connection.
182    ///
183    /// The output type is transport-specific, and can be a handle to directly write to the
184    /// connection, or it can be a substream multiplexer in the case of stream protocols.
185    type Io: AsyncRead + AsyncWrite + PeerAddress<A> + Send + Unpin;
186
187    /// The statistics for the transport (specifically its underlying IO object).
188    type Stats: Default + Debug + Send + Sync + for<'a> TryFrom<&'a Self::Io, Error: Debug>;
189
190    /// An error that occurred when setting up the connection.
191    type Error: std::error::Error + From<io::Error> + Send + Sync;
192
193    /// A pending output for an outbound connection, obtained when calling [`Transport::connect`].
194    type Connect: Future<Output = Result<Self::Io, Self::Error>> + Send;
195
196    /// A pending output for an inbound connection, obtained when calling
197    /// [`Transport::poll_accept`].
198    type Accept: Future<Output = Result<Self::Io, Self::Error>> + Send + Unpin;
199
200    /// Control-plane messages that modify the runtime behavior of the transport.
201    type Control: Send + Sync + Unpin;
202
203    /// Returns the local address this transport is bound to (if it is bound).
204    fn local_addr(&self) -> Option<A>;
205
206    /// Binds to the given address.
207    async fn bind(&mut self, addr: A) -> Result<(), Self::Error>;
208
209    /// Connects to the given address, returning a future representing a
210    /// pending outbound connection.
211    fn connect(&mut self, addr: A) -> Self::Connect;
212
213    /// Poll for incoming connections. If an inbound connection is received, a future representing
214    /// a pending inbound connection is returned. The future will resolve to [`Transport::Accept`].
215    fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Accept>;
216
217    /// Applies a control-plane message to the transport. It is expected to update internal state
218    /// only and should not perform long-running operations.
219    fn on_control(&mut self, _ctrl: Self::Control) {}
220}
221
222/// Extension trait for transports that provides additional methods.
223pub trait TransportExt<A: Address>: Transport<A> {
224    /// Async-friendly interface for accepting inbound connections.
225    fn accept(&mut self) -> Acceptor<'_, Self, A>
226    where
227        Self: Sized + Unpin,
228    {
229        Acceptor::new(self)
230    }
231}
232
233/// An `await`-friendly interface for accepting inbound connections.
234///
235/// This struct is used to accept inbound connections from a transport. It is
236/// created using the [`TransportExt::accept`] method.
237pub struct Acceptor<'a, T, A>
238where
239    T: Transport<A>,
240    A: Address,
241{
242    inner: &'a mut T,
243    /// The pending [`Transport::Accept`] future.
244    pending: Option<T::Accept>,
245    _marker: PhantomData<A>,
246}
247
248impl<'a, T, A> Acceptor<'a, T, A>
249where
250    T: Transport<A>,
251    A: Address,
252{
253    /// Creates a new `Acceptor` for the given transport.
254    fn new(inner: &'a mut T) -> Self {
255        Self { inner, pending: None, _marker: PhantomData }
256    }
257}
258
259impl<T, A> Future for Acceptor<'_, T, A>
260where
261    T: Transport<A> + Unpin,
262    A: Address,
263{
264    type Output = Result<T::Io, T::Error>;
265
266    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267        let this = self.get_mut();
268
269        loop {
270            // If there's a pending accept future, poll it to completion
271            if let Some(pending) = this.pending.as_mut() {
272                match pending.poll_unpin(cx) {
273                    Poll::Ready(res) => {
274                        this.pending = None;
275                        return Poll::Ready(res);
276                    }
277                    Poll::Pending => return Poll::Pending,
278                }
279            }
280
281            // Otherwise, poll the transport for a new accept future
282            match Pin::new(&mut *this.inner).poll_accept(cx) {
283                Poll::Ready(accept) => {
284                    this.pending = Some(accept);
285                    continue;
286                }
287                Poll::Pending => return Poll::Pending,
288            }
289        }
290    }
291}
292
293/// Trait for connection types that can return their peer address.
294pub trait PeerAddress<A: Address> {
295    fn peer_addr(&self) -> Result<A, io::Error>;
296}