1pub mod connector;
4
5use std::{
6 future::Future,
7 pin::Pin,
8 task::{Context, Poll}
9};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, Result},
13 net::TcpStream
14};
15
16use futures::future::BoxFuture;
17
18#[cfg(unix)]
19use tokio::net::UnixStream;
20
21#[cfg(feature = "tls")]
22use tokio_rustls::client::TlsStream;
23
24pub use connector::{Connector, TcpConnInfo};
25
26#[cfg(unix)]
27pub use connector::UdsConnInfo;
28
29#[cfg(feature = "tls")]
30pub use connector::TlsTcpConnInfo;
31
32
33#[allow(clippy::large_enum_variant)]
36pub enum Stream {
37 Tcp(TcpStream),
39
40 #[cfg(unix)]
42 Uds(UnixStream),
43
44 #[cfg(feature = "tls")]
46 TlsTcp(TlsStream<TcpStream>)
47}
48
49impl Stream {
50 #[allow(clippy::result_large_err)]
53 pub fn try_into_tcp(self) -> std::result::Result<TcpStream, Self> {
54 if let Self::Tcp(strm) = self {
55 Ok(strm)
56 } else {
57 Err(self)
58 }
59 }
60
61 #[cfg(unix)]
64 #[allow(clippy::result_large_err)]
65 pub fn try_into_uds(self) -> std::result::Result<UnixStream, Self> {
66 if let Self::Uds(strm) = self {
67 Ok(strm)
68 } else {
69 Err(self)
70 }
71 }
72
73 #[cfg(unix)]
76 #[allow(clippy::result_large_err)]
77 pub fn try_into_tlstcp(
78 self
79 ) -> std::result::Result<TlsStream<TcpStream>, Self> {
80 if let Self::TlsTcp(strm) = self {
81 Ok(strm)
82 } else {
83 Err(self)
84 }
85 }
86}
87
88impl Stream {
89 #[inline]
90 pub const fn reqflush(&self) -> bool {
91 match self {
92 Self::Tcp(_) => false,
93 #[cfg(unix)]
94 Self::Uds(_) => false,
95 #[cfg(feature = "tls")]
96 Self::TlsTcp(_) => true
97 }
98 }
99
100 pub fn ciphersuite(&self) -> Option<String> {
101 match self {
102 #[cfg(feature = "tls")]
103 Self::TlsTcp(strm) => {
104 let (_, conn) = strm.get_ref();
105 let ciphersuite = conn.negotiated_cipher_suite()?;
106 Some(format!("{:?}", ciphersuite.suite()))
107 }
108 _ => None
109 }
110 }
111}
112
/// Forward a pinned method call to whichever transport `Stream` currently
/// holds.
///
/// `delegate_call!(self.poll_read(cx, buf))` expands to a `match` over the
/// variants that re-pins the active inner stream and calls
/// `poll_read(cx, buf)` on it. It must be expanded inside an `impl` for
/// `Stream` (it names `Self::Tcp`, ...), with `self: Pin<&mut Stream>`.
//
// Every wrapped stream type is `Unpin`, which makes `Stream` itself `Unpin`.
// The pin can therefore be projected with the safe `Pin::get_mut` /
// `Pin::new` pair; no `unsafe` is needed. Paths are fully qualified so the
// expansion does not depend on the call site importing `Pin`.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {
        match ::std::pin::Pin::get_mut($self) {
            Self::Tcp(s) => ::std::pin::Pin::new(s).$method($($args),+),
            #[cfg(unix)]
            Self::Uds(s) => ::std::pin::Pin::new(s).$method($($args),+),
            #[cfg(feature = "tls")]
            Self::TlsTcp(s) => ::std::pin::Pin::new(s).$method($($args),+),
        }
    };
}
126
127impl AsyncRead for Stream {
128 fn poll_read(
129 self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 buf: &mut ReadBuf<'_>
132 ) -> Poll<Result<()>> {
133 delegate_call!(self.poll_read(cx, buf))
134 }
135}
136
137impl AsyncWrite for Stream {
138 fn poll_write(
139 self: Pin<&mut Self>,
140 cx: &mut Context<'_>,
141 buf: &[u8]
142 ) -> Poll<Result<usize>> {
143 delegate_call!(self.poll_write(cx, buf))
144 }
145
146 fn poll_flush(
147 self: Pin<&mut Self>,
148 cx: &mut Context<'_>
149 ) -> Poll<tokio::io::Result<()>> {
150 delegate_call!(self.poll_flush(cx))
151 }
152
153 fn poll_shutdown(
154 self: Pin<&mut Self>,
155 cx: &mut Context<'_>
156 ) -> Poll<tokio::io::Result<()>> {
157 delegate_call!(self.poll_shutdown(cx))
158 }
159}
160
161
162pub async fn with_conn<F, T>(mut strm: Stream, f: F) -> Result<T>
171where
172 F: for<'a> FnOnce(&'a mut Stream) -> BoxFuture<'a, Result<T>>,
173 T: Send
174{
175 let res = f(&mut strm).await;
177
178 strm.flush().await?;
180
181 strm.shutdown().await?;
183
184 res
185}
186
187pub async fn with_conn_owned<F, Fut, T>(strm: Stream, f: F) -> Result<T>
195where
196 F: FnOnce(Stream) -> Fut,
197 Fut: Future<Output = (Stream, Result<T>)>,
198 T: Send
199{
200 let (mut strm, res) = f(strm).await;
202
203 strm.flush().await?;
205
206 strm.shutdown().await?;
208
209 res
210}
211
212