Skip to main content

protwrap/tokio/
client.rs

//! Helpers for working with end-points that initiate connection requests (the client side).
2
3pub mod connector;
4
5use std::{
6  future::Future,
7  pin::Pin,
8  task::{Context, Poll}
9};
10
11use tokio::{
12  io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, Result},
13  net::TcpStream
14};
15
16use futures::future::BoxFuture;
17
18#[cfg(unix)]
19use tokio::net::UnixStream;
20
21#[cfg(feature = "tls")]
22use tokio_rustls::client::TlsStream;
23
24pub use connector::{Connector, TcpConnInfo};
25
26#[cfg(unix)]
27pub use connector::UdsConnInfo;
28
29#[cfg(feature = "tls")]
30pub use connector::TlsTcpConnInfo;
31
32
33/// Representation of a stream acting as a client end-point (actively
34/// established connection).
35#[allow(clippy::large_enum_variant)]
36pub enum Stream {
37  /// TCP-based client stream.
38  Tcp(TcpStream),
39
40  /// Unix local domain client stream.
41  #[cfg(unix)]
42  Uds(UnixStream),
43
44  /// TLS, based on TCP, client stream.
45  #[cfg(feature = "tls")]
46  TlsTcp(TlsStream<TcpStream>)
47}
48
49impl Stream {
50  /// # Errors
51  /// Returns `self` if variant isn't `tcp`.
52  #[allow(clippy::result_large_err)]
53  pub fn try_into_tcp(self) -> std::result::Result<TcpStream, Self> {
54    if let Self::Tcp(strm) = self {
55      Ok(strm)
56    } else {
57      Err(self)
58    }
59  }
60
61  /// # Errors
62  /// Returns `self` if variant isn't `uds`.
63  #[cfg(unix)]
64  #[allow(clippy::result_large_err)]
65  pub fn try_into_uds(self) -> std::result::Result<UnixStream, Self> {
66    if let Self::Uds(strm) = self {
67      Ok(strm)
68    } else {
69      Err(self)
70    }
71  }
72
73  /// # Errors
74  /// Returns `self` if variant isn't `tlstcp`.
75  #[cfg(unix)]
76  #[allow(clippy::result_large_err)]
77  pub fn try_into_tlstcp(
78    self
79  ) -> std::result::Result<TlsStream<TcpStream>, Self> {
80    if let Self::TlsTcp(strm) = self {
81      Ok(strm)
82    } else {
83      Err(self)
84    }
85  }
86}
87
88impl Stream {
89  #[inline]
90  pub const fn reqflush(&self) -> bool {
91    match self {
92      Self::Tcp(_) => false,
93      #[cfg(unix)]
94      Self::Uds(_) => false,
95      #[cfg(feature = "tls")]
96      Self::TlsTcp(_) => true
97    }
98  }
99
100  pub fn ciphersuite(&self) -> Option<String> {
101    match self {
102      #[cfg(feature = "tls")]
103      Self::TlsTcp(strm) => {
104        let (_, conn) = strm.get_ref();
105        let ciphersuite = conn.negotiated_cipher_suite()?;
106        Some(format!("{:?}", ciphersuite.suite()))
107      }
108      _ => None
109    }
110  }
111}
112
/// Forward a pinned poll method (e.g. `poll_read`) to whichever concrete
/// stream the `Stream` variant wraps.
///
/// Invoked as `delegate_call!(self.method(arg, ...))` from inside a method
/// whose receiver is `self: Pin<&mut Self>`.
///
/// Every wrapped stream type is `Unpin`, which makes `Stream` itself `Unpin`.
/// The pinned receiver can therefore be unwrapped with the safe
/// `Pin::get_mut`, and the inner stream re-pinned with the safe `Pin::new`.
/// No unchecked pin projection is needed.  If a non-`Unpin` variant is ever
/// added, this fails to compile rather than becoming unsound.
macro_rules! delegate_call {
  ($self:ident.$method:ident($($args:ident),+)) => {
    match Pin::get_mut($self) {
      Self::Tcp(s) => Pin::new(s).$method($($args),+),
      #[cfg(unix)]
      Self::Uds(s) => Pin::new(s).$method($($args),+),
      #[cfg(feature = "tls")]
      Self::TlsTcp(s) => Pin::new(s).$method($($args),+),
    }
  }
}
126
127impl AsyncRead for Stream {
128  fn poll_read(
129    self: Pin<&mut Self>,
130    cx: &mut Context<'_>,
131    buf: &mut ReadBuf<'_>
132  ) -> Poll<Result<()>> {
133    delegate_call!(self.poll_read(cx, buf))
134  }
135}
136
137impl AsyncWrite for Stream {
138  fn poll_write(
139    self: Pin<&mut Self>,
140    cx: &mut Context<'_>,
141    buf: &[u8]
142  ) -> Poll<Result<usize>> {
143    delegate_call!(self.poll_write(cx, buf))
144  }
145
146  fn poll_flush(
147    self: Pin<&mut Self>,
148    cx: &mut Context<'_>
149  ) -> Poll<tokio::io::Result<()>> {
150    delegate_call!(self.poll_flush(cx))
151  }
152
153  fn poll_shutdown(
154    self: Pin<&mut Self>,
155    cx: &mut Context<'_>
156  ) -> Poll<tokio::io::Result<()>> {
157    delegate_call!(self.poll_shutdown(cx))
158  }
159}
160
161
162/// Wrapper which forces shutdown.
163///
164/// rustls is picky about wanting `close_notify` being sent before closing the
165/// write side of the connection.  Because we don't have `AsyncDrop` (yet?),
166/// this wrapper can be used instead to make it a little more difficult to
167/// forget to perform the cleanup.
168///
169/// # Errors
170pub async fn with_conn<F, T>(mut strm: Stream, f: F) -> Result<T>
171where
172  F: for<'a> FnOnce(&'a mut Stream) -> BoxFuture<'a, Result<T>>,
173  T: Send
174{
175  // Run application closure
176  let res = f(&mut strm).await;
177
178  // Flush connection
179  strm.flush().await?;
180
181  // Explicit shutdown.  This should trigger `close_notify`.
182  strm.shutdown().await?;
183
184  res
185}
186
187/// Wrapper which forces shutdown.
188///
189/// This serves the same role as [`with_conn()`], but passes ownership of the
190/// connection to the application's closure.  The closure must return the
191/// connection ownership when done.
192///
193/// # Errors
194pub async fn with_conn_owned<F, Fut, T>(strm: Stream, f: F) -> Result<T>
195where
196  F: FnOnce(Stream) -> Fut,
197  Fut: Future<Output = (Stream, Result<T>)>,
198  T: Send
199{
200  // Run application closure
201  let (mut strm, res) = f(strm).await;
202
203  // Flush connection
204  strm.flush().await?;
205
206  // Explicit shutdown.  This should trigger `close_notify`.
207  strm.shutdown().await?;
208
209  res
210}
211
212// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 :