gel_stream/common/
tokio_stream.rs

//! This module provides functionality to connect to, and listen on, Tokio TCP and Unix sockets.
2
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5use tokio::net::{TcpListener, TcpStream};
6#[cfg(unix)]
7use tokio::net::{UnixListener, UnixStream};
8
9use crate::{AsHandle, PeekableStream, PeerCred, RemoteAddress, StreamMetadata, Transport};
10
11use super::target::{LocalAddress, ResolvedTarget};
12
13impl ResolvedTarget {
14    #[cfg(feature = "client")]
15    /// Connects to the socket address and returns a [`TokioStream`].
16    pub async fn connect(&self) -> std::io::Result<TokioStream> {
17        match self {
18            ResolvedTarget::SocketAddr(addr) => {
19                let stream = TcpStream::connect(addr).await?;
20                Ok(TokioStream::Tcp(stream))
21            }
22            #[cfg(unix)]
23            ResolvedTarget::UnixSocketAddr(path) => {
24                let stm = std::os::unix::net::UnixStream::connect_addr(path)?;
25                stm.set_nonblocking(true)?;
26                let stream = UnixStream::from_std(stm)?;
27                Ok(TokioStream::Unix(stream))
28            }
29        }
30    }
31
32    #[cfg(feature = "server")]
33    pub async fn listen(
34        &self,
35    ) -> std::io::Result<
36        impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
37    > {
38        self.listen_raw(None).await
39    }
40
41    #[cfg(feature = "server")]
42    pub async fn listen_backlog(
43        &self,
44        backlog: usize,
45    ) -> std::io::Result<
46        impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
47    > {
48        if !self.is_tcp() {
49            return Err(std::io::Error::new(
50                std::io::ErrorKind::InvalidInput,
51                "Unix sockets do not support a connectionbacklog",
52            ));
53        }
54        let backlog = u32::try_from(backlog)
55            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
56        self.listen_raw(Some(backlog)).await
57    }
58
59    /// Listens for incoming connections on the socket address and returns a
60    /// [`futures::Stream`] of [`TokioStream`]s and the incoming address.
61    #[cfg(feature = "server")]
62    pub(crate) async fn listen_raw(
63        &self,
64        backlog: Option<u32>,
65    ) -> std::io::Result<TokioListenerStream> {
66        use std::net::SocketAddr;
67
68        use tokio::net::TcpSocket;
69
70        use crate::DEFAULT_TCP_BACKLOG;
71
72        match self {
73            ResolvedTarget::SocketAddr(addr) => {
74                let backlog = backlog.unwrap_or(DEFAULT_TCP_BACKLOG);
75                let socket = match addr {
76                    SocketAddr::V4(..) => TcpSocket::new_v4()?,
77                    SocketAddr::V6(..) => TcpSocket::new_v6()?,
78                };
79                socket.bind(*addr)?;
80                let listener = socket.listen(backlog)?;
81
82                Ok(TokioListenerStream::Tcp(listener))
83            }
84            #[cfg(unix)]
85            ResolvedTarget::UnixSocketAddr(path) => {
86                let listener = std::os::unix::net::UnixListener::bind_addr(path)?;
87                listener.set_nonblocking(true)?;
88                let listener = tokio::net::UnixListener::from_std(listener)?;
89                Ok(TokioListenerStream::Unix(listener))
90            }
91        }
92    }
93}
94
/// Listener produced by `ResolvedTarget::listen_raw`. It wraps either a tokio TCP
/// listener or (on Unix) a tokio Unix-domain listener.
///
/// Used as a [`futures::Stream`], it yields each accepted connection as a
/// [`TokioStream`] paired with the peer's [`ResolvedTarget`].
pub(crate) enum TokioListenerStream {
    /// A bound TCP listener.
    Tcp(TcpListener),
    /// A bound Unix-domain listener (only available on Unix systems).
    #[cfg(unix)]
    Unix(UnixListener),
}
100
101impl LocalAddress for TokioListenerStream {
102    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
103        match self {
104            TokioListenerStream::Tcp(listener) => {
105                listener.local_addr().map(ResolvedTarget::SocketAddr)
106            }
107            #[cfg(unix)]
108            TokioListenerStream::Unix(listener) => listener
109                .local_addr()
110                .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into())),
111        }
112    }
113}
114
115impl futures::Stream for TokioListenerStream {
116    type Item = std::io::Result<(TokioStream, ResolvedTarget)>;
117
118    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        match self.get_mut() {
120            TokioListenerStream::Tcp(listener) => {
121                let (stream, addr) = ready!(listener.poll_accept(cx))?;
122                let stream = TokioStream::Tcp(stream);
123                let target = ResolvedTarget::SocketAddr(addr);
124                Poll::Ready(Some(Ok((stream, target))))
125            }
126            #[cfg(unix)]
127            TokioListenerStream::Unix(listener) => {
128                let (stream, addr) = ready!(listener.poll_accept(cx))?;
129                let stream = TokioStream::Unix(stream);
130                let target = ResolvedTarget::UnixSocketAddr(addr.into());
131                Poll::Ready(Some(Ok((stream, target))))
132            }
133        }
134    }
135}
136
/// Represents a connected Tokio stream, either TCP or Unix.
///
/// `AsyncRead`, `AsyncWrite` and the socket-descriptor impls come from the
/// `derive_io` derives below. The `#[read]`, `#[write]` and `#[descriptor]` field
/// markers presumably select the wrapped stream each derive delegates to (confirm
/// against `derive_io`).
#[derive(derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
pub enum TokioStream {
    /// TCP stream
    Tcp(
        #[read]
        #[write]
        #[descriptor]
        TcpStream,
    ),
    /// Unix stream (only available on Unix systems)
    #[cfg(unix)]
    Unix(
        #[read]
        #[write]
        #[descriptor]
        UnixStream,
    ),
}
156
157impl TokioStream {
158    #[cfg(feature = "keepalive")]
159    pub fn set_keepalive(&self, keepalive: Option<std::time::Duration>) -> std::io::Result<()> {
160        use socket2::*;
161        match self {
162            TokioStream::Tcp(stream) => {
163                let sock = socket2::SockRef::from(&stream);
164                if let Some(keepalive) = keepalive {
165                    sock.set_tcp_keepalive(
166                        &TcpKeepalive::new()
167                            .with_interval(keepalive)
168                            .with_time(keepalive),
169                    )
170                } else {
171                    sock.set_keepalive(false)
172                }
173            }
174            #[cfg(unix)]
175            TokioStream::Unix(_) => Err(std::io::Error::new(
176                std::io::ErrorKind::Unsupported,
177                "Unix sockets do not support keepalive",
178            )),
179        }
180    }
181}
182
183impl AsHandle for TokioStream {
184    #[cfg(windows)]
185    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
186        <Self as std::os::windows::io::AsSocket>::as_socket(self)
187    }
188
189    #[cfg(unix)]
190    fn as_fd(&self) -> std::os::fd::BorrowedFd {
191        <Self as std::os::fd::AsFd>::as_fd(self)
192    }
193}
194
195impl PeekableStream for TokioStream {
196    fn poll_peek(
197        self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        buf: &mut tokio::io::ReadBuf<'_>,
200    ) -> Poll<std::io::Result<usize>> {
201        match self.get_mut() {
202            TokioStream::Tcp(stream) => Pin::new(stream).poll_peek(cx, buf),
203            #[cfg(unix)]
204            TokioStream::Unix(stream) => loop {
205                ready!(stream.poll_read_ready(cx))?;
206                let sock = socket2::SockRef::from(&*stream);
207                break match sock.recv_with_flags(unsafe { buf.unfilled_mut() }, libc::MSG_PEEK) {
208                    Ok(n) => Poll::Ready(Ok(n)),
209                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
210                        continue;
211                    }
212                    Err(e) => Poll::Ready(Err(e)),
213                };
214            },
215        }
216    }
217}
218
219impl LocalAddress for TokioStream {
220    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
221        match self {
222            TokioStream::Tcp(stream) => <TcpStream as LocalAddress>::local_address(stream),
223            #[cfg(unix)]
224            TokioStream::Unix(stream) => <UnixStream as LocalAddress>::local_address(stream),
225        }
226    }
227}
228
229impl RemoteAddress for TokioStream {
230    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
231        match self {
232            TokioStream::Tcp(stream) => <TcpStream as RemoteAddress>::remote_address(stream),
233            #[cfg(unix)]
234            TokioStream::Unix(stream) => <UnixStream as RemoteAddress>::remote_address(stream),
235        }
236    }
237}
238
239impl PeerCred for TokioStream {
240    #[cfg(all(unix, feature = "tokio"))]
241    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
242        match self {
243            TokioStream::Unix(unix) => unix.peer_cred(),
244            TokioStream::Tcp(_) => Err(std::io::Error::new(
245                std::io::ErrorKind::Unsupported,
246                "TCP sockets do not support peer credentials",
247            )),
248        }
249    }
250}
251
252impl StreamMetadata for TokioStream {
253    fn transport(&self) -> Transport {
254        match self {
255            TokioStream::Tcp(_) => Transport::Tcp,
256            #[cfg(unix)]
257            TokioStream::Unix(_) => Transport::Unix,
258        }
259    }
260}
261
262impl LocalAddress for TcpStream {
263    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
264        self.local_addr().map(ResolvedTarget::SocketAddr)
265    }
266}
267
268impl RemoteAddress for TcpStream {
269    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
270        self.peer_addr().map(ResolvedTarget::SocketAddr)
271    }
272}
273
274impl PeerCred for TcpStream {
275    #[cfg(all(unix, feature = "tokio"))]
276    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
277        Err(std::io::Error::new(
278            std::io::ErrorKind::Unsupported,
279            "TCP sockets do not support peer credentials",
280        ))
281    }
282}
283
284impl StreamMetadata for TcpStream {
285    fn transport(&self) -> Transport {
286        Transport::Tcp
287    }
288}
289
290impl AsHandle for TcpStream {
291    #[cfg(windows)]
292    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
293        <Self as std::os::windows::io::AsSocket>::as_socket(self)
294    }
295
296    #[cfg(unix)]
297    fn as_fd(&self) -> std::os::fd::BorrowedFd {
298        <Self as std::os::fd::AsFd>::as_fd(self)
299    }
300}
301
302#[cfg(unix)]
303impl LocalAddress for UnixStream {
304    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
305        self.local_addr()
306            .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into()))
307    }
308}
309
310#[cfg(unix)]
311impl RemoteAddress for UnixStream {
312    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
313        self.peer_addr()
314            .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into()))
315    }
316}
317
318#[cfg(unix)]
319impl PeerCred for UnixStream {
320    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
321        self.peer_cred()
322    }
323}
324
325#[cfg(unix)]
326impl StreamMetadata for UnixStream {
327    fn transport(&self) -> Transport {
328        Transport::Unix
329    }
330}
331
332#[cfg(unix)]
333impl AsHandle for UnixStream {
334    fn as_fd(&self) -> std::os::fd::BorrowedFd {
335        <Self as std::os::fd::AsFd>::as_fd(self)
336    }
337}