opengauss_native_tls/
lib.rs

//! TLS support for `tokio-opengauss` and `opengauss` via `native-tls`.
2//!
3//! # Examples
4//!
5//! ```no_run
6//! use native_tls::{Certificate, TlsConnector};
7//! use opengauss_native_tls::MakeTlsConnector;
8//! use std::fs;
9//!
10//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
11//! let cert = fs::read("database_cert.pem")?;
12//! let cert = Certificate::from_pem(&cert)?;
13//! let connector = TlsConnector::builder()
14//!     .add_root_certificate(cert)
15//!     .build()?;
16//! let connector = MakeTlsConnector::new(connector);
17//!
18//! let connect_future = tokio_opengauss::connect(
19//!     "host=localhost user=postgres password=openGauss#2023 sslmode=require",
20//!     connector,
21//! );
22//!
23//! // ...
24//! # Ok(())
25//! # }
26//! ```
27//!
28//! ```no_run
29//! use native_tls::{Certificate, TlsConnector};
30//! use opengauss_native_tls::MakeTlsConnector;
31//! use std::fs;
32//!
33//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! let cert = fs::read("database_cert.pem")?;
35//! let cert = Certificate::from_pem(&cert)?;
36//! let connector = TlsConnector::builder()
37//!     .add_root_certificate(cert)
38//!     .build()?;
39//! let connector = MakeTlsConnector::new(connector);
40//!
41//! let client = opengauss::Client::connect(
42//!     "host=localhost user=postgres password=openGauss#2023 sslmode=require",
43//!     connector,
44//! )?;
45//! # Ok(())
46//! # }
47//! ```
48#![warn(rust_2018_idioms, clippy::all, missing_docs)]
49
50use std::future::Future;
51use std::io;
52use std::pin::Pin;
53use std::task::{Context, Poll};
54use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
55use tokio_opengauss::tls;
56#[cfg(feature = "runtime")]
57use tokio_opengauss::tls::MakeTlsConnect;
58use tokio_opengauss::tls::{ChannelBinding, TlsConnect};
59
60#[cfg(test)]
61mod test;
62
/// A [`MakeTlsConnect`] implementation backed by the `native-tls` crate.
///
/// The wrapped [`native_tls::TlsConnector`] holds the shared TLS configuration; a
/// per-connection connector is derived from it for every host that is contacted.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
69
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Wraps an already-configured `native-tls` connector.
    ///
    /// The result can be passed to `tokio_opengauss::connect` or
    /// `opengauss::Client::connect` as the TLS provider.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        Self(connector)
    }
}
77
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Produces a [`TlsConnector`] for `domain` from a clone of the shared
    /// `native-tls` configuration.
    ///
    /// This always returns `Ok`; the `Result` is required by the trait.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let shared_config = self.0.clone();
        Ok(TlsConnector::new(shared_config, domain))
    }
}
91
92/// A `TlsConnect` implementation using the `native-tls` crate.
93pub struct TlsConnector {
94    connector: tokio_native_tls::TlsConnector,
95    domain: String,
96}
97
98impl TlsConnector {
99    /// Creates a new connector configured to connect to the specified domain.
100    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
101        TlsConnector {
102            connector: tokio_native_tls::TlsConnector::from(connector),
103            domain: domain.to_string(),
104        }
105    }
106}
107
108impl<S> TlsConnect<S> for TlsConnector
109where
110    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
111{
112    type Stream = TlsStream<S>;
113    type Error = native_tls::Error;
114    #[allow(clippy::type_complexity)]
115    type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
116
117    fn connect(self, stream: S) -> Self::Future {
118        let stream = BufReader::with_capacity(8192, stream);
119        let future = async move {
120            let stream = self.connector.connect(&self.domain, stream).await?;
121
122            Ok(TlsStream(stream))
123        };
124
125        Box::pin(future)
126    }
127}
128
129/// The stream returned by `TlsConnector`.
130pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
131
132impl<S> AsyncRead for TlsStream<S>
133where
134    S: AsyncRead + AsyncWrite + Unpin,
135{
136    fn poll_read(
137        mut self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &mut ReadBuf<'_>,
140    ) -> Poll<io::Result<()>> {
141        Pin::new(&mut self.0).poll_read(cx, buf)
142    }
143}
144
145impl<S> AsyncWrite for TlsStream<S>
146where
147    S: AsyncRead + AsyncWrite + Unpin,
148{
149    fn poll_write(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &[u8],
153    ) -> Poll<io::Result<usize>> {
154        Pin::new(&mut self.0).poll_write(cx, buf)
155    }
156
157    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
158        Pin::new(&mut self.0).poll_flush(cx)
159    }
160
161    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162        Pin::new(&mut self.0).poll_shutdown(cx)
163    }
164}
165
166impl<S> tls::TlsStream for TlsStream<S>
167where
168    S: AsyncRead + AsyncWrite + Unpin,
169{
170    fn channel_binding(&self) -> ChannelBinding {
171        match self.0.get_ref().tls_server_end_point().ok().flatten() {
172            Some(buf) => ChannelBinding::tls_server_end_point(buf),
173            None => ChannelBinding::none(),
174        }
175    }
176}