use std::sync::Arc;
use rustls::ClientConfig;
use rustls_pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use crate::Error;
/// Byte stream underlying an LDAP connection.
///
/// `AsyncRead` and `AsyncWrite` are implemented for this enum by forwarding
/// each call to whichever variant is active, so callers can treat plain and
/// TLS connections uniformly.
pub enum LdapStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// Client-side TLS session layered over TCP, as returned by
    /// `upgrade_to_tls`. Boxed — presumably to keep the enum's size close to
    /// that of `Plain`; confirm before changing.
    Tls(Box<TlsStream<TcpStream>>),
}
impl AsyncRead for LdapStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
Self::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_read(cx, buf),
}
}
}
/// Forwards every write-side operation to the active transport.
///
/// Besides the required methods, the vectored-write pair is forwarded too.
/// tokio's default `poll_write_vectored` writes only the first non-empty
/// buffer, and the default `is_write_vectored` reports `false`. Without
/// forwarding, the inner stream's scatter/gather support would be hidden
/// behind this wrapper.
impl AsyncWrite for LdapStream {
    /// Writes bytes from `buf` to the inner stream.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.get_mut() {
            Self::Plain(s) => std::pin::Pin::new(s).poll_write(cx, buf),
            Self::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_write(cx, buf),
        }
    }

    /// Vectored write, delegated so the inner stream can consume several
    /// buffers in one call when it supports that.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.get_mut() {
            Self::Plain(s) => std::pin::Pin::new(s).poll_write_vectored(cx, bufs),
            Self::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the active transport has an efficient vectored write.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Plain(s) => s.is_write_vectored(),
            Self::Tls(s) => (**s).is_write_vectored(),
        }
    }

    /// Flushes any buffered data in the inner stream.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(s) => std::pin::Pin::new(s).poll_flush(cx),
            Self::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_flush(cx),
        }
    }

    /// Shuts down the write side of the inner stream.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(s) => std::pin::Pin::new(s).poll_shutdown(cx),
            Self::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_shutdown(cx),
        }
    }
}
/// Builds a client TLS configuration that trusts the root certificates
/// shipped in the `webpki_roots` crate and presents no client certificate.
///
/// # Panics
///
/// NOTE(review): depending on the rustls version and enabled features,
/// `ClientConfig::builder()` may panic if no default crypto provider can be
/// determined — confirm against the rustls version in use.
pub fn default_tls_config() -> ClientConfig {
    let trusted: rustls::RootCertStore = webpki_roots::TLS_SERVER_ROOTS
        .iter()
        .cloned()
        .collect();
    let builder = ClientConfig::builder().with_root_certificates(trusted);
    builder.with_no_client_auth()
}
/// Builds a client TLS configuration that accepts ANY server certificate.
///
/// The installed verifier never inspects the certificate chain, the server
/// name, or the TLS 1.2/1.3 handshake signatures; every check reports
/// success. Connections made with this config are therefore open to
/// man-in-the-middle attacks. It is only compiled when the
/// `danger-disable-verify` feature is enabled. No client certificate is
/// presented.
#[cfg(feature = "danger-disable-verify")]
pub fn danger_no_verify_tls_config() -> ClientConfig {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::CertificateDer;
    use rustls::DigitallySignedStruct;

    /// Signature schemes reported to rustls as "supported". Nothing is
    /// actually verified, so this list only determines what is advertised.
    const ADVERTISED_SCHEMES: [rustls::SignatureScheme; 11] = [
        rustls::SignatureScheme::RSA_PKCS1_SHA256,
        rustls::SignatureScheme::RSA_PKCS1_SHA384,
        rustls::SignatureScheme::RSA_PKCS1_SHA512,
        rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
        rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
        rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
        rustls::SignatureScheme::RSA_PSS_SHA256,
        rustls::SignatureScheme::RSA_PSS_SHA384,
        rustls::SignatureScheme::RSA_PSS_SHA512,
        rustls::SignatureScheme::ED25519,
        rustls::SignatureScheme::ED448,
    ];

    /// Verifier that unconditionally trusts the peer.
    #[derive(Debug)]
    struct AcceptAnyCert;

    impl ServerCertVerifier for AcceptAnyCert {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls_pki_types::UnixTime,
        ) -> Result<ServerCertVerified, rustls::Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            ADVERTISED_SCHEMES.to_vec()
        }
    }

    ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(AcceptAnyCert))
        .with_no_client_auth()
}
pub async fn upgrade_to_tls(
stream: TcpStream,
server_name: ServerName<'static>,
tls_config: Arc<ClientConfig>,
timeout: std::time::Duration,
) -> Result<LdapStream, Error> {
let connector = TlsConnector::from(tls_config);
let tls_stream =
match tokio::time::timeout(timeout, connector.connect(server_name, stream)).await {
Ok(Ok(s)) => s,
Ok(Err(e)) => return Err(Error::Io(e)),
Err(_) => return Err(Error::Timeout),
};
Ok(LdapStream::Tls(Box::new(tls_stream)))
}