opengauss_native_tls/
lib.rs1#![warn(rust_2018_idioms, clippy::all, missing_docs)]
49
50use std::future::Future;
51use std::io;
52use std::pin::Pin;
53use std::task::{Context, Poll};
54use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
55use tokio_opengauss::tls;
56#[cfg(feature = "runtime")]
57use tokio_opengauss::tls::MakeTlsConnect;
58use tokio_opengauss::tls::{ChannelBinding, TlsConnect};
59
60#[cfg(test)]
61mod test;
62
/// A `MakeTlsConnect` implementation backed by the `native-tls` crate.
///
/// Wraps a `native_tls::TlsConnector` that is cloned for every connector
/// handed out. Only compiled when the `runtime` Cargo feature is enabled.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
69
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new connector factory that wraps the given
    /// `native_tls::TlsConnector`.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        MakeTlsConnector(connector)
    }
}
77
// Lets `MakeTlsConnector` produce a per-domain `TlsConnector` for any
// transport `S` that satisfies the bounds below.
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    // Clones the wrapped `native_tls::TlsConnector` and binds the clone to
    // `domain`. This implementation always returns `Ok`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, native_tls::Error> {
        let connector = self.0.clone();
        Ok(TlsConnector::new(connector, domain))
    }
}
91
92pub struct TlsConnector {
94 connector: tokio_native_tls::TlsConnector,
95 domain: String,
96}
97
98impl TlsConnector {
99 pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
101 TlsConnector {
102 connector: tokio_native_tls::TlsConnector::from(connector),
103 domain: domain.to_string(),
104 }
105 }
106}
107
108impl<S> TlsConnect<S> for TlsConnector
109where
110 S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
111{
112 type Stream = TlsStream<S>;
113 type Error = native_tls::Error;
114 #[allow(clippy::type_complexity)]
115 type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
116
117 fn connect(self, stream: S) -> Self::Future {
118 let stream = BufReader::with_capacity(8192, stream);
119 let future = async move {
120 let stream = self.connector.connect(&self.domain, stream).await?;
121
122 Ok(TlsStream(stream))
123 };
124
125 Box::pin(future)
126 }
127}
128
129pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
131
132impl<S> AsyncRead for TlsStream<S>
133where
134 S: AsyncRead + AsyncWrite + Unpin,
135{
136 fn poll_read(
137 mut self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &mut ReadBuf<'_>,
140 ) -> Poll<io::Result<()>> {
141 Pin::new(&mut self.0).poll_read(cx, buf)
142 }
143}
144
145impl<S> AsyncWrite for TlsStream<S>
146where
147 S: AsyncRead + AsyncWrite + Unpin,
148{
149 fn poll_write(
150 mut self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &[u8],
153 ) -> Poll<io::Result<usize>> {
154 Pin::new(&mut self.0).poll_write(cx, buf)
155 }
156
157 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
158 Pin::new(&mut self.0).poll_flush(cx)
159 }
160
161 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162 Pin::new(&mut self.0).poll_shutdown(cx)
163 }
164}
165
166impl<S> tls::TlsStream for TlsStream<S>
167where
168 S: AsyncRead + AsyncWrite + Unpin,
169{
170 fn channel_binding(&self) -> ChannelBinding {
171 match self.0.get_ref().tls_server_end_point().ok().flatten() {
172 Some(buf) => ChannelBinding::tls_server_end_point(buf),
173 None => ChannelBinding::none(),
174 }
175 }
176}