1#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![cfg_attr(not(test), warn(unused_crate_dependencies))]
4
5use std::{
6 fmt::Debug,
7 hash::Hash,
8 io::{self, IoSlice},
9 marker::PhantomData,
10 net::SocketAddr,
11 path::PathBuf,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15 time::{Duration, Instant},
16};
17
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use futures::{Future, FutureExt};
21use tokio::io::{AsyncRead, AsyncWrite};
22
// Public submodules of this crate. `quic` and `tcp_tls` are compiled only
// when the cargo feature of the matching name (`quic`, `tcp-tls`) is enabled.
pub mod ipc;
#[cfg(feature = "quic")]
pub mod quic;
pub mod tcp;
#[cfg(feature = "tcp-tls")]
pub mod tcp_tls;
29
/// Marker trait for the address types accepted by [`Transport`].
///
/// It declares no methods; it only bundles the bounds (`Clone`, `Debug`, `Send`, `Sync`,
/// `Unpin`, `Hash`, `Eq`, `'static`) that every address type must satisfy.
pub trait Address: Clone + Debug + Send + Sync + Unpin + Hash + Eq + 'static {}

/// Socket addresses can be used as transport addresses.
impl Address for SocketAddr {}

/// Filesystem paths can be used as transport addresses.
impl Address for PathBuf {}
38
/// Wrapper around an I/O object that periodically publishes a statistics snapshot of it.
///
/// All reads and writes are forwarded to the wrapped `Io`. Before forwarding, the
/// `AsyncRead`/`AsyncWrite` poll methods check whether a refresh is due and, if so, build a
/// new `S` from a reference to the inner I/O object and store it into the shared `stats` handle.
pub struct MeteredIo<Io, S, A>
where
    Io: AsyncRead + AsyncWrite + PeerAddress<A>,
    A: Address,
{
    /// The wrapped I/O object that every operation is delegated to.
    inner: Io,
    /// Shared, atomically swappable handle that receives the most recent stats snapshot.
    stats: Arc<ArcSwap<S>>,
    /// Point in time at or after which the next snapshot is taken.
    next_refresh: Instant,
    /// Delay added to the current time after each refresh attempt to compute `next_refresh`.
    refresh_interval: Duration,

    /// Carries the address type `A` without storing a value of it.
    _marker: PhantomData<A>,
}
58
59impl<Io, S, A> AsyncRead for MeteredIo<Io, S, A>
60where
61 Io: AsyncRead + AsyncWrite + PeerAddress<A> + Unpin,
62 A: Address,
63 S: for<'a> TryFrom<&'a Io, Error: Debug>,
64{
65 fn poll_read(
66 self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 buf: &mut tokio::io::ReadBuf<'_>,
69 ) -> Poll<io::Result<()>> {
70 let this = self.get_mut();
71
72 this.maybe_refresh();
73
74 Pin::new(&mut this.inner).poll_read(cx, buf)
75 }
76}
77
78impl<Io, S, A> AsyncWrite for MeteredIo<Io, S, A>
79where
80 Io: AsyncRead + AsyncWrite + PeerAddress<A> + Unpin,
81 A: Address,
82 S: for<'a> TryFrom<&'a Io, Error: Debug>,
83{
84 fn poll_write(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 buf: &[u8],
88 ) -> Poll<io::Result<usize>> {
89 let this = self.get_mut();
90
91 this.maybe_refresh();
92
93 Pin::new(&mut this.inner).poll_write(cx, buf)
94 }
95
96 fn poll_write_vectored(
97 self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 bufs: &[IoSlice<'_>],
100 ) -> Poll<io::Result<usize>> {
101 let this = self.get_mut();
102
103 this.maybe_refresh();
104
105 Pin::new(&mut this.inner).poll_write_vectored(cx, bufs)
106 }
107
108 fn is_write_vectored(&self) -> bool {
109 self.inner.is_write_vectored()
110 }
111
112 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113 let this = self.get_mut();
114
115 this.maybe_refresh();
116
117 Pin::new(&mut this.inner).poll_flush(cx)
118 }
119
120 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
121 let this = self.get_mut();
122
123 this.maybe_refresh();
124
125 Pin::new(&mut this.inner).poll_shutdown(cx)
126 }
127}
128
129impl<Io, M, A> PeerAddress<A> for MeteredIo<Io, M, A>
130where
131 Io: AsyncRead + AsyncWrite + PeerAddress<A>,
132 A: Address,
133{
134 fn peer_addr(&self) -> Result<A, io::Error> {
135 self.inner.peer_addr()
136 }
137}
138
139impl<Io, S, A> MeteredIo<Io, S, A>
140where
141 Io: AsyncRead + AsyncWrite + PeerAddress<A>,
142 A: Address,
143 S: for<'a> TryFrom<&'a Io, Error: Debug>,
144{
145 pub fn new(inner: Io, stats: Arc<ArcSwap<S>>) -> Self {
150 Self {
151 inner,
152 stats,
153 _marker: PhantomData,
154 next_refresh: Instant::now(),
155 refresh_interval: Duration::from_secs(2),
156 }
157 }
158
159 #[inline]
160 fn maybe_refresh(&mut self) {
161 let now = Instant::now();
162 if self.next_refresh <= now {
163 match S::try_from(&self.inner) {
164 Ok(stats) => {
165 self.stats.store(Arc::new(stats));
166 }
167 Err(e) => tracing::error!(errror = ?e, "failed to gather transport stats"),
168 }
169
170 self.next_refresh = now + self.refresh_interval;
171 }
172 }
173}
174
/// Interface implemented by every transport, generic over its address type `A`.
///
/// A transport can be bound to an address, can initiate outbound connections with `connect`,
/// and can be polled for inbound connections with `poll_accept`. Both directions ultimately
/// produce a value of the associated `Io` type.
#[async_trait]
pub trait Transport<A: Address>: Send + Sync + Unpin + 'static {
    /// The I/O object produced by the `Connect` and `Accept` futures.
    type Io: AsyncRead + AsyncWrite + PeerAddress<A> + Send + Unpin;

    /// Statistics type that can be built (fallibly) from a reference to an `Io`.
    type Stats: Default + Debug + Send + Sync + for<'a> TryFrom<&'a Self::Io, Error: Debug>;

    /// Error type of this transport; must be constructible from an `io::Error`.
    type Error: std::error::Error + From<io::Error> + Send + Sync;

    /// Future returned by `connect`, resolving to an `Io` or an error.
    type Connect: Future<Output = Result<Self::Io, Self::Error>> + Send;

    /// Value yielded by `poll_accept` when ready; itself a future resolving to an `Io` or an
    /// error.
    type Accept: Future<Output = Result<Self::Io, Self::Error>> + Send + Unpin;

    /// Type of the control messages handed to `on_control`.
    type Control: Send + Sync + Unpin;

    /// Returns the local address of this transport, if there is one.
    // NOTE(review): presumably `None` means the transport has not been bound yet — confirm
    // against the implementations.
    fn local_addr(&self) -> Option<A>;

    /// Binds this transport to `addr`.
    async fn bind(&mut self, addr: A) -> Result<(), Self::Error>;

    /// Starts an outbound connection to `addr` and returns the future that completes it.
    fn connect(&mut self, addr: A) -> Self::Connect;

    /// Polls for an inbound connection.
    ///
    /// When ready, yields an `Accept` future that must in turn be driven to obtain the `Io`.
    fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Accept>;

    /// Handles a control message. The default implementation ignores it.
    fn on_control(&mut self, _ctrl: Self::Control) {}
}
221
/// Extension trait adding convenience methods on top of [`Transport`].
pub trait TransportExt<A: Address>: Transport<A> {
    /// Returns an [`Acceptor`] that mutably borrows this transport.
    ///
    /// The returned value is a future; awaiting it yields the `Io` (or error) of one inbound
    /// connection.
    fn accept(&mut self) -> Acceptor<'_, Self, A>
    where
        Self: Sized + Unpin,
    {
        Acceptor::new(self)
    }
}
232
/// Future that accepts one inbound connection from a mutably borrowed transport.
pub struct Acceptor<'a, T, A>
where
    T: Transport<A>,
    A: Address,
{
    /// The transport that is polled for inbound connections.
    inner: &'a mut T,
    /// Accept future obtained from the transport that has not completed yet, if any.
    pending: Option<T::Accept>,
    /// Carries the address type `A` without storing a value of it.
    _marker: PhantomData<A>,
}
247
impl<'a, T, A> Acceptor<'a, T, A>
where
    T: Transport<A>,
    A: Address,
{
    /// Creates an acceptor over `inner` with no accept future in flight.
    fn new(inner: &'a mut T) -> Self {
        Self { inner, pending: None, _marker: PhantomData }
    }
}
258
259impl<T, A> Future for Acceptor<'_, T, A>
260where
261 T: Transport<A> + Unpin,
262 A: Address,
263{
264 type Output = Result<T::Io, T::Error>;
265
266 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 let this = self.get_mut();
268
269 loop {
270 if let Some(pending) = this.pending.as_mut() {
272 match pending.poll_unpin(cx) {
273 Poll::Ready(res) => {
274 this.pending = None;
275 return Poll::Ready(res);
276 }
277 Poll::Pending => return Poll::Pending,
278 }
279 }
280
281 match Pin::new(&mut *this.inner).poll_accept(cx) {
283 Poll::Ready(accept) => {
284 this.pending = Some(accept);
285 continue;
286 }
287 Poll::Pending => return Poll::Pending,
288 }
289 }
290 }
291}
292
/// Implemented by I/O objects that can report a peer address of type `A`.
pub trait PeerAddress<A: Address> {
    /// Returns the peer address, or an `io::Error` if it cannot be obtained.
    fn peer_addr(&self) -> Result<A, io::Error>;
}