// gel_stream/common/tokio_stream.rs

//! Tokio-backed TCP and Unix socket support: host name resolution, connecting
//! to resolved targets, and listening for incoming connections.
2
3use std::net::{IpAddr, ToSocketAddrs};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8#[cfg(unix)]
9use tokio::net::{UnixListener, UnixStream};
10
11use super::target::{LocalAddress, ResolvedTarget};
12
13pub(crate) struct Resolver {
14    #[cfg(feature = "hickory")]
15    resolver: hickory_resolver::TokioResolver,
16}
17
18#[allow(unused)]
19async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<IpAddr> {
20    let res = tokio::task::spawn_blocking(move || format!("{}:0", host).to_socket_addrs())
21        .await
22        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
23    res.into_iter()
24        .next()
25        .ok_or(std::io::Error::new(
26            std::io::ErrorKind::NotFound,
27            "No address found",
28        ))
29        .map(|addr| addr.ip())
30}
31
32impl Resolver {
33    pub fn new() -> Result<Self, std::io::Error> {
34        Ok(Self {
35            #[cfg(feature = "hickory")]
36            resolver: hickory_resolver::Resolver::builder_tokio()?.build(),
37        })
38    }
39
40    pub async fn resolve_remote(&self, host: String) -> std::io::Result<IpAddr> {
41        #[cfg(feature = "hickory")]
42        {
43            let addr = self.resolver.lookup_ip(host).await?.iter().next().unwrap();
44            Ok(addr)
45        }
46        #[cfg(not(feature = "hickory"))]
47        {
48            resolve_host_to_socket_addrs(host).await
49        }
50    }
51}
52
53impl ResolvedTarget {
54    #[cfg(feature = "client")]
55    /// Connects to the socket address and returns a [`TokioStream`].
56    pub async fn connect(&self) -> std::io::Result<TokioStream> {
57        match self {
58            ResolvedTarget::SocketAddr(addr) => {
59                let stream = TcpStream::connect(addr).await?;
60                Ok(TokioStream::Tcp(stream))
61            }
62            #[cfg(unix)]
63            ResolvedTarget::UnixSocketAddr(path) => {
64                let stm = std::os::unix::net::UnixStream::connect_addr(path)?;
65                stm.set_nonblocking(true)?;
66                let stream = UnixStream::from_std(stm)?;
67                Ok(TokioStream::Unix(stream))
68            }
69        }
70    }
71
72    #[cfg(feature = "server")]
73    pub async fn listen(
74        &self,
75    ) -> std::io::Result<
76        impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
77    > {
78        self.listen_raw().await
79    }
80
81    /// Listens for incoming connections on the socket address and returns a
82    /// [`futures::Stream`] of [`TokioStream`]s and the incoming address.
83    #[cfg(feature = "server")]
84    pub(crate) async fn listen_raw(&self) -> std::io::Result<TokioListenerStream> {
85        match self {
86            ResolvedTarget::SocketAddr(addr) => {
87                let listener = TcpListener::bind(addr).await?;
88                Ok(TokioListenerStream::Tcp(listener))
89            }
90            #[cfg(unix)]
91            ResolvedTarget::UnixSocketAddr(path) => {
92                let listener = std::os::unix::net::UnixListener::bind_addr(path)?;
93                listener.set_nonblocking(true)?;
94                let listener = tokio::net::UnixListener::from_std(listener)?;
95                Ok(TokioListenerStream::Unix(listener))
96            }
97        }
98    }
99}
100
101pub(crate) enum TokioListenerStream {
102    Tcp(TcpListener),
103    #[cfg(unix)]
104    Unix(UnixListener),
105}
106
107impl LocalAddress for TokioListenerStream {
108    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
109        match self {
110            TokioListenerStream::Tcp(listener) => {
111                listener.local_addr().map(ResolvedTarget::SocketAddr)
112            }
113            #[cfg(unix)]
114            TokioListenerStream::Unix(listener) => listener
115                .local_addr()
116                .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into())),
117        }
118    }
119}
120
121impl futures::Stream for TokioListenerStream {
122    type Item = std::io::Result<(TokioStream, ResolvedTarget)>;
123
124    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125        match self.get_mut() {
126            TokioListenerStream::Tcp(listener) => {
127                let (stream, addr) = ready!(listener.poll_accept(cx))?;
128                let stream = TokioStream::Tcp(stream);
129                let target = ResolvedTarget::SocketAddr(addr);
130                Poll::Ready(Some(Ok((stream, target))))
131            }
132            #[cfg(unix)]
133            TokioListenerStream::Unix(listener) => {
134                let (stream, addr) = ready!(listener.poll_accept(cx))?;
135                let stream = TokioStream::Unix(stream);
136                let target = ResolvedTarget::UnixSocketAddr(addr.into());
137                Poll::Ready(Some(Ok((stream, target))))
138            }
139        }
140    }
141}
142
143/// Represents a connected Tokio stream, either TCP or Unix
144pub enum TokioStream {
145    /// TCP stream
146    Tcp(TcpStream),
147    /// Unix stream (only available on Unix systems)
148    #[cfg(unix)]
149    Unix(UnixStream),
150}
151
152impl TokioStream {
153    #[cfg(feature = "keepalive")]
154    pub fn set_keepalive(&self, keepalive: Option<std::time::Duration>) -> std::io::Result<()> {
155        use socket2::*;
156        match self {
157            TokioStream::Tcp(stream) => {
158                let sock = socket2::SockRef::from(&stream);
159                if let Some(keepalive) = keepalive {
160                    sock.set_tcp_keepalive(
161                        &TcpKeepalive::new()
162                            .with_interval(keepalive)
163                            .with_time(keepalive),
164                    )
165                } else {
166                    sock.set_keepalive(false)
167                }
168            }
169            #[cfg(unix)]
170            TokioStream::Unix(_) => Err(std::io::Error::new(
171                std::io::ErrorKind::Unsupported,
172                "Unix sockets do not support keepalive",
173            )),
174        }
175    }
176}
177
178impl AsyncRead for TokioStream {
179    #[inline(always)]
180    fn poll_read(
181        self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        buf: &mut tokio::io::ReadBuf<'_>,
184    ) -> Poll<std::io::Result<()>> {
185        match self.get_mut() {
186            TokioStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
187            #[cfg(unix)]
188            TokioStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
189        }
190    }
191}
192
193impl AsyncWrite for TokioStream {
194    #[inline(always)]
195    fn poll_write(
196        self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198        buf: &[u8],
199    ) -> Poll<Result<usize, std::io::Error>> {
200        match self.get_mut() {
201            TokioStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
202            #[cfg(unix)]
203            TokioStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
204        }
205    }
206
207    #[inline(always)]
208    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
209        match self.get_mut() {
210            TokioStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
211            #[cfg(unix)]
212            TokioStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
213        }
214    }
215
216    #[inline(always)]
217    fn poll_shutdown(
218        self: Pin<&mut Self>,
219        cx: &mut Context<'_>,
220    ) -> Poll<Result<(), std::io::Error>> {
221        match self.get_mut() {
222            TokioStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
223            #[cfg(unix)]
224            TokioStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
225        }
226    }
227
228    #[inline(always)]
229    fn is_write_vectored(&self) -> bool {
230        match self {
231            TokioStream::Tcp(stream) => stream.is_write_vectored(),
232            #[cfg(unix)]
233            TokioStream::Unix(stream) => stream.is_write_vectored(),
234        }
235    }
236
237    #[inline(always)]
238    fn poll_write_vectored(
239        self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241        bufs: &[std::io::IoSlice<'_>],
242    ) -> Poll<Result<usize, std::io::Error>> {
243        match self.get_mut() {
244            TokioStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
245            #[cfg(unix)]
246            TokioStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
247        }
248    }
249}