cloudpub_common/unix_tcp/
stream.rs

1/*
2 * Copyright (c) 2023, networkException <git@nwex.de>
3 *
4 * SPDX-License-Identifier: BSD-2-Clause OR MIT
5 */
6
7#[cfg(unix)]
8use tokio::net::UnixStream;
9
10use std::io;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpStream;
16
17use crate::unix_tcp::{NamedSocketAddr, SocketAddr};
18
/// A connected byte stream backed either by a TCP connection or, on Unix
/// targets, by a Unix domain socket connection.
///
/// The `AsyncRead` and `AsyncWrite` impls for this type forward every call
/// to the wrapped stream, so callers can treat both transports uniformly.
#[derive(Debug)]
pub enum Stream {
    /// A TCP connection.
    Tcp(TcpStream),
    /// A Unix domain socket connection; this variant only exists on Unix targets.
    #[cfg(unix)]
    Unix(UnixStream),
}
25
26impl From<TcpStream> for Stream {
27    fn from(tcp_stream: TcpStream) -> Self {
28        Stream::Tcp(tcp_stream)
29    }
30}
31
32#[cfg(unix)]
33impl From<UnixStream> for Stream {
34    fn from(unix_stream: UnixStream) -> Self {
35        Stream::Unix(unix_stream)
36    }
37}
38
39impl Stream {
40    pub async fn connect(named_socket_addr: &NamedSocketAddr) -> io::Result<Self> {
41        match named_socket_addr {
42            NamedSocketAddr::Inet(inet_socket_addr) => {
43                TcpStream::connect(inet_socket_addr).await.map(Stream::Tcp)
44            }
45            #[cfg(unix)]
46            NamedSocketAddr::Unix(path) => UnixStream::connect(path).await.map(Stream::Unix),
47        }
48    }
49
50    pub fn local_addr(&self) -> io::Result<SocketAddr> {
51        match self {
52            Stream::Tcp(tcp_stream) => tcp_stream.local_addr().map(SocketAddr::Inet),
53            #[cfg(unix)]
54            Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
55        }
56    }
57
58    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
59        match self {
60            Stream::Tcp(tcp_stream) => tcp_stream.peer_addr().map(SocketAddr::Inet),
61            #[cfg(unix)]
62            Stream::Unix(unix_stream) => Ok(SocketAddr::Unix(unix_stream.local_addr()?.into())),
63        }
64    }
65}
66
67impl AsyncRead for Stream {
68    fn poll_read(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &mut ReadBuf<'_>,
72    ) -> Poll<io::Result<()>> {
73        match Pin::into_inner(self) {
74            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_read(cx, buf),
75            #[cfg(unix)]
76            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_read(cx, buf),
77        }
78    }
79}
80
81impl AsyncWrite for Stream {
82    fn poll_write(
83        self: Pin<&mut Self>,
84        cx: &mut Context<'_>,
85        buf: &[u8],
86    ) -> Poll<io::Result<usize>> {
87        match Pin::into_inner(self) {
88            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write(cx, buf),
89            #[cfg(unix)]
90            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write(cx, buf),
91        }
92    }
93
94    fn poll_write_vectored(
95        self: Pin<&mut Self>,
96        cx: &mut Context<'_>,
97        bufs: &[io::IoSlice<'_>],
98    ) -> Poll<io::Result<usize>> {
99        match Pin::into_inner(self) {
100            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_write_vectored(cx, bufs),
101            #[cfg(unix)]
102            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_write_vectored(cx, bufs),
103        }
104    }
105
106    fn is_write_vectored(&self) -> bool {
107        match self {
108            Stream::Tcp(tcp_stream) => tcp_stream.is_write_vectored(),
109            #[cfg(unix)]
110            Stream::Unix(unix_stream) => unix_stream.is_write_vectored(),
111        }
112    }
113
114    #[inline]
115    fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
116        match Pin::into_inner(self) {
117            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_flush(context),
118            #[cfg(unix)]
119            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_flush(context),
120        }
121    }
122
123    fn poll_shutdown(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<io::Result<()>> {
124        match Pin::into_inner(self) {
125            Stream::Tcp(tcp_stream) => Pin::new(tcp_stream).poll_shutdown(context),
126            #[cfg(unix)]
127            Stream::Unix(unix_stream) => Pin::new(unix_stream).poll_shutdown(context),
128        }
129    }
130}