gel_stream/common/
tokio_stream.rs1use std::net::{IpAddr, ToSocketAddrs};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8#[cfg(unix)]
9use tokio::net::{UnixListener, UnixStream};
10
11use super::target::{LocalAddress, ResolvedTarget};
12
/// Resolves remote host names to a single [`IpAddr`].
///
/// With the `hickory` feature enabled this wraps a Tokio-backed
/// `hickory_resolver` resolver. Without it the struct carries no state, and
/// lookups go through the standard library's `ToSocketAddrs`.
pub(crate) struct Resolver {
    /// DNS resolver; only present when the `hickory` feature is enabled.
    #[cfg(feature = "hickory")]
    resolver: hickory_resolver::TokioResolver,
}
17
18#[allow(unused)]
19async fn resolve_host_to_socket_addrs(host: String) -> std::io::Result<IpAddr> {
20 let res = tokio::task::spawn_blocking(move || format!("{}:0", host).to_socket_addrs())
21 .await
22 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Interrupted, e.to_string()))??;
23 res.into_iter()
24 .next()
25 .ok_or(std::io::Error::new(
26 std::io::ErrorKind::NotFound,
27 "No address found",
28 ))
29 .map(|addr| addr.ip())
30}
31
32impl Resolver {
33 pub fn new() -> Result<Self, std::io::Error> {
34 Ok(Self {
35 #[cfg(feature = "hickory")]
36 resolver: hickory_resolver::Resolver::builder_tokio()?.build(),
37 })
38 }
39
40 pub async fn resolve_remote(&self, host: String) -> std::io::Result<IpAddr> {
41 #[cfg(feature = "hickory")]
42 {
43 let addr = self.resolver.lookup_ip(host).await?.iter().next().unwrap();
44 Ok(addr)
45 }
46 #[cfg(not(feature = "hickory"))]
47 {
48 resolve_host_to_socket_addrs(host).await
49 }
50 }
51}
52
53impl ResolvedTarget {
54 #[cfg(feature = "client")]
55 pub async fn connect(&self) -> std::io::Result<TokioStream> {
57 match self {
58 ResolvedTarget::SocketAddr(addr) => {
59 let stream = TcpStream::connect(addr).await?;
60 Ok(TokioStream::Tcp(stream))
61 }
62 #[cfg(unix)]
63 ResolvedTarget::UnixSocketAddr(path) => {
64 let stm = std::os::unix::net::UnixStream::connect_addr(path)?;
65 stm.set_nonblocking(true)?;
66 let stream = UnixStream::from_std(stm)?;
67 Ok(TokioStream::Unix(stream))
68 }
69 }
70 }
71
72 #[cfg(feature = "server")]
73 pub async fn listen(
74 &self,
75 ) -> std::io::Result<
76 impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
77 > {
78 self.listen_raw().await
79 }
80
81 #[cfg(feature = "server")]
84 pub(crate) async fn listen_raw(&self) -> std::io::Result<TokioListenerStream> {
85 match self {
86 ResolvedTarget::SocketAddr(addr) => {
87 let listener = TcpListener::bind(addr).await?;
88 Ok(TokioListenerStream::Tcp(listener))
89 }
90 #[cfg(unix)]
91 ResolvedTarget::UnixSocketAddr(path) => {
92 let listener = std::os::unix::net::UnixListener::bind_addr(path)?;
93 listener.set_nonblocking(true)?;
94 let listener = tokio::net::UnixListener::from_std(listener)?;
95 Ok(TokioListenerStream::Unix(listener))
96 }
97 }
98 }
99}
100
101pub(crate) enum TokioListenerStream {
102 Tcp(TcpListener),
103 #[cfg(unix)]
104 Unix(UnixListener),
105}
106
107impl LocalAddress for TokioListenerStream {
108 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
109 match self {
110 TokioListenerStream::Tcp(listener) => {
111 listener.local_addr().map(ResolvedTarget::SocketAddr)
112 }
113 #[cfg(unix)]
114 TokioListenerStream::Unix(listener) => listener
115 .local_addr()
116 .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into())),
117 }
118 }
119}
120
121impl futures::Stream for TokioListenerStream {
122 type Item = std::io::Result<(TokioStream, ResolvedTarget)>;
123
124 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125 match self.get_mut() {
126 TokioListenerStream::Tcp(listener) => {
127 let (stream, addr) = ready!(listener.poll_accept(cx))?;
128 let stream = TokioStream::Tcp(stream);
129 let target = ResolvedTarget::SocketAddr(addr);
130 Poll::Ready(Some(Ok((stream, target))))
131 }
132 #[cfg(unix)]
133 TokioListenerStream::Unix(listener) => {
134 let (stream, addr) = ready!(listener.poll_accept(cx))?;
135 let stream = TokioStream::Unix(stream);
136 let target = ResolvedTarget::UnixSocketAddr(addr.into());
137 Poll::Ready(Some(Ok((stream, target))))
138 }
139 }
140 }
141}
142
143pub enum TokioStream {
145 Tcp(TcpStream),
147 #[cfg(unix)]
149 Unix(UnixStream),
150}
151
152impl TokioStream {
153 #[cfg(feature = "keepalive")]
154 pub fn set_keepalive(&self, keepalive: Option<std::time::Duration>) -> std::io::Result<()> {
155 use socket2::*;
156 match self {
157 TokioStream::Tcp(stream) => {
158 let sock = socket2::SockRef::from(&stream);
159 if let Some(keepalive) = keepalive {
160 sock.set_tcp_keepalive(
161 &TcpKeepalive::new()
162 .with_interval(keepalive)
163 .with_time(keepalive),
164 )
165 } else {
166 sock.set_keepalive(false)
167 }
168 }
169 #[cfg(unix)]
170 TokioStream::Unix(_) => Err(std::io::Error::new(
171 std::io::ErrorKind::Unsupported,
172 "Unix sockets do not support keepalive",
173 )),
174 }
175 }
176}
177
178impl AsyncRead for TokioStream {
179 #[inline(always)]
180 fn poll_read(
181 self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 buf: &mut tokio::io::ReadBuf<'_>,
184 ) -> Poll<std::io::Result<()>> {
185 match self.get_mut() {
186 TokioStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
187 #[cfg(unix)]
188 TokioStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
189 }
190 }
191}
192
193impl AsyncWrite for TokioStream {
194 #[inline(always)]
195 fn poll_write(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 buf: &[u8],
199 ) -> Poll<Result<usize, std::io::Error>> {
200 match self.get_mut() {
201 TokioStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
202 #[cfg(unix)]
203 TokioStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
204 }
205 }
206
207 #[inline(always)]
208 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
209 match self.get_mut() {
210 TokioStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
211 #[cfg(unix)]
212 TokioStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
213 }
214 }
215
216 #[inline(always)]
217 fn poll_shutdown(
218 self: Pin<&mut Self>,
219 cx: &mut Context<'_>,
220 ) -> Poll<Result<(), std::io::Error>> {
221 match self.get_mut() {
222 TokioStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
223 #[cfg(unix)]
224 TokioStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
225 }
226 }
227
228 #[inline(always)]
229 fn is_write_vectored(&self) -> bool {
230 match self {
231 TokioStream::Tcp(stream) => stream.is_write_vectored(),
232 #[cfg(unix)]
233 TokioStream::Unix(stream) => stream.is_write_vectored(),
234 }
235 }
236
237 #[inline(always)]
238 fn poll_write_vectored(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 bufs: &[std::io::IoSlice<'_>],
242 ) -> Poll<Result<usize, std::io::Error>> {
243 match self.get_mut() {
244 TokioStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
245 #[cfg(unix)]
246 TokioStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
247 }
248 }
249}