gel_stream/common/
tokio_stream.rs1use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5use tokio::net::{TcpListener, TcpStream};
6#[cfg(unix)]
7use tokio::net::{UnixListener, UnixStream};
8
9use crate::{AsHandle, PeekableStream, PeerCred, RemoteAddress, StreamMetadata, Transport};
10
11use super::target::{LocalAddress, ResolvedTarget};
12
13impl ResolvedTarget {
14 #[cfg(feature = "client")]
15 pub async fn connect(&self) -> std::io::Result<TokioStream> {
17 match self {
18 ResolvedTarget::SocketAddr(addr) => {
19 let stream = TcpStream::connect(addr).await?;
20 Ok(TokioStream::Tcp(stream))
21 }
22 #[cfg(unix)]
23 ResolvedTarget::UnixSocketAddr(path) => {
24 let stm = std::os::unix::net::UnixStream::connect_addr(path)?;
25 stm.set_nonblocking(true)?;
26 let stream = UnixStream::from_std(stm)?;
27 Ok(TokioStream::Unix(stream))
28 }
29 }
30 }
31
32 #[cfg(feature = "server")]
33 pub async fn listen(
34 &self,
35 ) -> std::io::Result<
36 impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
37 > {
38 self.listen_raw(None).await
39 }
40
41 #[cfg(feature = "server")]
42 pub async fn listen_backlog(
43 &self,
44 backlog: usize,
45 ) -> std::io::Result<
46 impl futures::Stream<Item = std::io::Result<(TokioStream, ResolvedTarget)>> + LocalAddress,
47 > {
48 if !self.is_tcp() {
49 return Err(std::io::Error::new(
50 std::io::ErrorKind::InvalidInput,
51 "Unix sockets do not support a connectionbacklog",
52 ));
53 }
54 let backlog = u32::try_from(backlog)
55 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
56 self.listen_raw(Some(backlog)).await
57 }
58
59 #[cfg(feature = "server")]
62 pub(crate) async fn listen_raw(
63 &self,
64 backlog: Option<u32>,
65 ) -> std::io::Result<TokioListenerStream> {
66 use std::net::SocketAddr;
67
68 use tokio::net::TcpSocket;
69
70 use crate::DEFAULT_TCP_BACKLOG;
71
72 match self {
73 ResolvedTarget::SocketAddr(addr) => {
74 let backlog = backlog.unwrap_or(DEFAULT_TCP_BACKLOG);
75 let socket = match addr {
76 SocketAddr::V4(..) => TcpSocket::new_v4()?,
77 SocketAddr::V6(..) => TcpSocket::new_v6()?,
78 };
79 socket.bind(*addr)?;
80 let listener = socket.listen(backlog)?;
81
82 Ok(TokioListenerStream::Tcp(listener))
83 }
84 #[cfg(unix)]
85 ResolvedTarget::UnixSocketAddr(path) => {
86 let listener = std::os::unix::net::UnixListener::bind_addr(path)?;
87 listener.set_nonblocking(true)?;
88 let listener = tokio::net::UnixListener::from_std(listener)?;
89 Ok(TokioListenerStream::Unix(listener))
90 }
91 }
92 }
93}
94
95pub(crate) enum TokioListenerStream {
96 Tcp(TcpListener),
97 #[cfg(unix)]
98 Unix(UnixListener),
99}
100
101impl LocalAddress for TokioListenerStream {
102 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
103 match self {
104 TokioListenerStream::Tcp(listener) => {
105 listener.local_addr().map(ResolvedTarget::SocketAddr)
106 }
107 #[cfg(unix)]
108 TokioListenerStream::Unix(listener) => listener
109 .local_addr()
110 .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into())),
111 }
112 }
113}
114
115impl futures::Stream for TokioListenerStream {
116 type Item = std::io::Result<(TokioStream, ResolvedTarget)>;
117
118 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 match self.get_mut() {
120 TokioListenerStream::Tcp(listener) => {
121 let (stream, addr) = ready!(listener.poll_accept(cx))?;
122 let stream = TokioStream::Tcp(stream);
123 let target = ResolvedTarget::SocketAddr(addr);
124 Poll::Ready(Some(Ok((stream, target))))
125 }
126 #[cfg(unix)]
127 TokioListenerStream::Unix(listener) => {
128 let (stream, addr) = ready!(listener.poll_accept(cx))?;
129 let stream = TokioStream::Unix(stream);
130 let target = ResolvedTarget::UnixSocketAddr(addr.into());
131 Poll::Ready(Some(Ok((stream, target))))
132 }
133 }
134 }
135}
136
137#[derive(derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
139pub enum TokioStream {
140 Tcp(
142 #[read]
143 #[write]
144 #[descriptor]
145 TcpStream,
146 ),
147 #[cfg(unix)]
149 Unix(
150 #[read]
151 #[write]
152 #[descriptor]
153 UnixStream,
154 ),
155}
156
157impl TokioStream {
158 #[cfg(feature = "keepalive")]
159 pub fn set_keepalive(&self, keepalive: Option<std::time::Duration>) -> std::io::Result<()> {
160 use socket2::*;
161 match self {
162 TokioStream::Tcp(stream) => {
163 let sock = socket2::SockRef::from(&stream);
164 if let Some(keepalive) = keepalive {
165 sock.set_tcp_keepalive(
166 &TcpKeepalive::new()
167 .with_interval(keepalive)
168 .with_time(keepalive),
169 )
170 } else {
171 sock.set_keepalive(false)
172 }
173 }
174 #[cfg(unix)]
175 TokioStream::Unix(_) => Err(std::io::Error::new(
176 std::io::ErrorKind::Unsupported,
177 "Unix sockets do not support keepalive",
178 )),
179 }
180 }
181}
182
183impl AsHandle for TokioStream {
184 #[cfg(windows)]
185 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
186 <Self as std::os::windows::io::AsSocket>::as_socket(self)
187 }
188
189 #[cfg(unix)]
190 fn as_fd(&self) -> std::os::fd::BorrowedFd {
191 <Self as std::os::fd::AsFd>::as_fd(self)
192 }
193}
194
195impl PeekableStream for TokioStream {
196 fn poll_peek(
197 self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &mut tokio::io::ReadBuf<'_>,
200 ) -> Poll<std::io::Result<usize>> {
201 match self.get_mut() {
202 TokioStream::Tcp(stream) => Pin::new(stream).poll_peek(cx, buf),
203 #[cfg(unix)]
204 TokioStream::Unix(stream) => loop {
205 ready!(stream.poll_read_ready(cx))?;
206 let sock = socket2::SockRef::from(&*stream);
207 break match sock.recv_with_flags(unsafe { buf.unfilled_mut() }, libc::MSG_PEEK) {
208 Ok(n) => Poll::Ready(Ok(n)),
209 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
210 continue;
211 }
212 Err(e) => Poll::Ready(Err(e)),
213 };
214 },
215 }
216 }
217}
218
219impl LocalAddress for TokioStream {
220 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
221 match self {
222 TokioStream::Tcp(stream) => <TcpStream as LocalAddress>::local_address(stream),
223 #[cfg(unix)]
224 TokioStream::Unix(stream) => <UnixStream as LocalAddress>::local_address(stream),
225 }
226 }
227}
228
229impl RemoteAddress for TokioStream {
230 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
231 match self {
232 TokioStream::Tcp(stream) => <TcpStream as RemoteAddress>::remote_address(stream),
233 #[cfg(unix)]
234 TokioStream::Unix(stream) => <UnixStream as RemoteAddress>::remote_address(stream),
235 }
236 }
237}
238
239impl PeerCred for TokioStream {
240 #[cfg(all(unix, feature = "tokio"))]
241 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
242 match self {
243 TokioStream::Unix(unix) => unix.peer_cred(),
244 TokioStream::Tcp(_) => Err(std::io::Error::new(
245 std::io::ErrorKind::Unsupported,
246 "TCP sockets do not support peer credentials",
247 )),
248 }
249 }
250}
251
252impl StreamMetadata for TokioStream {
253 fn transport(&self) -> Transport {
254 match self {
255 TokioStream::Tcp(_) => Transport::Tcp,
256 #[cfg(unix)]
257 TokioStream::Unix(_) => Transport::Unix,
258 }
259 }
260}
261
262impl LocalAddress for TcpStream {
263 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
264 self.local_addr().map(ResolvedTarget::SocketAddr)
265 }
266}
267
268impl RemoteAddress for TcpStream {
269 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
270 self.peer_addr().map(ResolvedTarget::SocketAddr)
271 }
272}
273
274impl PeerCred for TcpStream {
275 #[cfg(all(unix, feature = "tokio"))]
276 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
277 Err(std::io::Error::new(
278 std::io::ErrorKind::Unsupported,
279 "TCP sockets do not support peer credentials",
280 ))
281 }
282}
283
284impl StreamMetadata for TcpStream {
285 fn transport(&self) -> Transport {
286 Transport::Tcp
287 }
288}
289
290impl AsHandle for TcpStream {
291 #[cfg(windows)]
292 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
293 <Self as std::os::windows::io::AsSocket>::as_socket(self)
294 }
295
296 #[cfg(unix)]
297 fn as_fd(&self) -> std::os::fd::BorrowedFd {
298 <Self as std::os::fd::AsFd>::as_fd(self)
299 }
300}
301
302#[cfg(unix)]
303impl LocalAddress for UnixStream {
304 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
305 self.local_addr()
306 .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into()))
307 }
308}
309
310#[cfg(unix)]
311impl RemoteAddress for UnixStream {
312 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
313 self.peer_addr()
314 .map(|addr| ResolvedTarget::UnixSocketAddr(addr.into()))
315 }
316}
317
318#[cfg(unix)]
319impl PeerCred for UnixStream {
320 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
321 self.peer_cred()
322 }
323}
324
325#[cfg(unix)]
326impl StreamMetadata for UnixStream {
327 fn transport(&self) -> Transport {
328 Transport::Unix
329 }
330}
331
332#[cfg(unix)]
333impl AsHandle for UnixStream {
334 fn as_fd(&self) -> std::os::fd::BorrowedFd {
335 <Self as std::os::fd::AsFd>::as_fd(self)
336 }
337}