#![deny(unsafe_code)]
#![deny(missing_docs)]
use std::{future::Future, pin::Pin, sync::Arc};
use rustls::{pki_types::ServerName, ClientConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
use tokio_rustls::{client::TlsStream, Connect, TlsConnector};
use x509_cert::{
    der::{
        oid::db::rfc5912::{
            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, SHA_1_WITH_RSA_ENCRYPTION,
            SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
        },
        Decode,
    },
    spki::ObjectIdentifier,
    Certificate,
};
/// Hash backend used to compute the `tls-server-end-point` channel binding
/// (RFC 5929) over the server's end-entity certificate.
///
/// The crate ships feature-gated implementations (`aws-lc-rs`, `graviola`,
/// `ring`, `rustcrypto`); any other hash provider can be plugged in by
/// implementing this trait.
pub trait DigestImplementation {
    /// Hashes `bytes` with the hash function selected by `algorithm` and
    /// returns the raw digest bytes.
    ///
    /// [`DigestAlgorithm::Sha1`] is requested for certificates signed with
    /// SHA-1. RFC 5929 §4.1 requires the binding to use SHA-256 in that case,
    /// and all bundled backends do so.
    fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8>;
}
#[cfg(feature = "aws-lc-rs")]
pub use aws_lc_rs_backend::AwsLcRsDigest;
#[cfg(feature = "graviola")]
pub use graviola_backend::GraviolaDigest;
#[cfg(feature = "ring")]
pub use ring_backend::RingDigest;
#[cfg(feature = "rustcrypto")]
pub use rustcrypto_backend::RustcryptoDigest;
#[cfg(feature = "aws-lc-rs")]
mod aws_lc_rs_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// [`DigestImplementation`] backed by the `aws-lc-rs` crate.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct AwsLcRsDigest;

    impl DigestImplementation for AwsLcRsDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            let digest_alg = match algorithm {
                // RFC 5929 §4.1: SHA-1-signed certificates are bound with SHA-256.
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
                DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
                DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
            };
            aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
        }
    }
}
#[cfg(feature = "graviola")]
mod graviola_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// [`DigestImplementation`] backed by the `graviola` crate.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct GraviolaDigest;

    impl DigestImplementation for GraviolaDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            use graviola::hashing::Hash;
            match algorithm {
                // RFC 5929 §4.1: SHA-1-signed certificates are bound with SHA-256.
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    graviola::hashing::Sha256::hash(bytes).as_ref().to_vec()
                }
                DigestAlgorithm::Sha384 => graviola::hashing::Sha384::hash(bytes).as_ref().to_vec(),
                DigestAlgorithm::Sha512 => graviola::hashing::Sha512::hash(bytes).as_ref().to_vec(),
            }
        }
    }
}
#[cfg(feature = "ring")]
mod ring_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// [`DigestImplementation`] backed by the `ring` crate.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RingDigest;

    impl DigestImplementation for RingDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            let digest_alg = match algorithm {
                // RFC 5929 §4.1: SHA-1-signed certificates are bound with SHA-256.
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &ring::digest::SHA256,
                DigestAlgorithm::Sha384 => &ring::digest::SHA384,
                DigestAlgorithm::Sha512 => &ring::digest::SHA512,
            };
            ring::digest::digest(digest_alg, bytes).as_ref().into()
        }
    }
}
#[cfg(feature = "rustcrypto")]
mod rustcrypto_backend {
    use super::{DigestAlgorithm, DigestImplementation};
    use sha2::{Digest, Sha256, Sha384, Sha512};

    /// [`DigestImplementation`] backed by the RustCrypto `sha2` crate.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RustcryptoDigest;

    impl DigestImplementation for RustcryptoDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            match algorithm {
                // RFC 5929 §4.1: SHA-1-signed certificates are bound with SHA-256.
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    Sha256::digest(bytes).as_slice().into()
                }
                DigestAlgorithm::Sha384 => Sha384::digest(bytes).as_slice().into(),
                DigestAlgorithm::Sha512 => Sha512::digest(bytes).as_slice().into(),
            }
        }
    }
}
/// A `tokio-postgres` [`MakeTlsConnect`] implementation backed by `rustls`.
///
/// The `D` parameter selects the hash backend (see [`DigestImplementation`])
/// used to compute the `tls-server-end-point` channel binding for SCRAM
/// authentication.
#[derive(Clone)]
pub struct MakeRustlsConnect<D> {
    /// Shared client configuration; each per-connection connector receives a
    /// clone of this `Arc`.
    config: Arc<ClientConfig>,
    /// Hash backend; cloned into each per-connection connector.
    digest_impl: D,
}
impl<D> MakeRustlsConnect<D>
where
    D: DigestImplementation,
{
    /// Creates a connector factory from a `rustls` client configuration and a
    /// digest backend.
    ///
    /// The configuration is moved into an [`Arc`] so that every connection
    /// created by this factory shares it without copying.
    pub fn new(config: ClientConfig, digest_impl: D) -> Self {
        Self {
            config: Arc::new(config),
            digest_impl,
        }
    }
}
impl<D, S> MakeTlsConnect<S> for MakeRustlsConnect<D>
where
D: DigestImplementation + Clone + Unpin,
S: AsyncRead + AsyncWrite + Unpin + Send,
{
type Stream = RustlsStream<D, S>;
type TlsConnect = RustlsConnect<D>;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(domain).map(|dns_name| RustlsConnect {
dns_name: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
digest_impl: self.digest_impl.clone(),
})
}
}
/// Per-connection TLS connector produced by [`MakeRustlsConnect`].
#[doc(hidden)]
pub struct RustlsConnect<D> {
    /// Server name passed to `TlsConnector::connect`.
    dns_name: ServerName<'static>,
    /// tokio-rustls connector built from the shared client configuration.
    connector: TlsConnector,
    /// Hash backend handed on to the resulting stream.
    digest_impl: D,
}
/// Future returned by `RustlsConnect::connect`.
///
/// Resolves to a [`RustlsStream`] once the wrapped tokio-rustls `Connect`
/// future completes.
#[doc(hidden)]
pub struct ConnectFuture<D, S> {
    /// In-flight tokio-rustls connection future.
    connect: Connect<S>,
    /// Hash backend to attach to the finished stream.
    digest_impl: D,
}
impl<D, S> Future for ConnectFuture<D, S>
where
    D: DigestImplementation + Clone + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = std::io::Result<RustlsStream<D, S>>;
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        let res = std::task::ready!(Pin::new(&mut this.connect).poll(cx));
        // `poll` only has `&mut self`, so the backend must be cloned rather
        // than moved out.
        std::task::Poll::Ready(res.map(|io| RustlsStream {
            io,
            digest_impl: this.digest_impl.clone(),
        }))
    }
}
impl<D, S> TlsConnect<S> for RustlsConnect<D>
where
    D: DigestImplementation + Clone + Unpin,
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    type Stream = RustlsStream<D, S>;
    type Error = std::io::Error;
    type Future = ConnectFuture<D, S>;
    /// Starts the TLS connection over `stream`.
    ///
    /// `self` is consumed, so its fields are moved into the future instead of
    /// being cloned.
    fn connect(self, stream: S) -> Self::Future {
        let Self {
            dns_name,
            connector,
            digest_impl,
        } = self;
        ConnectFuture {
            connect: connector.connect(dns_name, stream),
            digest_impl,
        }
    }
}
enum SignatureAlgorithm {
Sha1Rsa,
Sha256Rsa,
Sha384Rsa,
Sha512Rsa,
EcdsaSha256,
EcdsaSha384,
}
impl SignatureAlgorithm {
fn try_from_identifier(oid: &ObjectIdentifier) -> Option<Self> {
if oid == &SHA_1_WITH_RSA_ENCRYPTION {
Some(Self::Sha1Rsa)
} else if oid == &SHA_256_WITH_RSA_ENCRYPTION {
Some(Self::Sha256Rsa)
} else if oid == &SHA_384_WITH_RSA_ENCRYPTION {
Some(Self::Sha384Rsa)
} else if oid == &SHA_512_WITH_RSA_ENCRYPTION {
Some(Self::Sha512Rsa)
} else if oid == &ECDSA_WITH_SHA_256 {
Some(Self::EcdsaSha256)
} else if oid == &ECDSA_WITH_SHA_384 {
Some(Self::EcdsaSha384)
} else {
None
}
}
fn digest_algorithm(self) -> DigestAlgorithm {
match self {
Self::Sha1Rsa => DigestAlgorithm::Sha1,
Self::Sha256Rsa | Self::EcdsaSha256 => DigestAlgorithm::Sha256,
Self::Sha384Rsa | Self::EcdsaSha384 => DigestAlgorithm::Sha384,
Self::Sha512Rsa => DigestAlgorithm::Sha512,
}
}
}
/// Hash function requested from a [`DigestImplementation`] when computing
/// the `tls-server-end-point` channel binding.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DigestAlgorithm {
    /// The certificate was signed using SHA-1.
    ///
    /// RFC 5929 §4.1 requires SHA-256 to be used for the binding in this
    /// case; the bundled backends hash with SHA-256.
    Sha1,
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
/// TLS stream produced by [`RustlsConnect`].
///
/// It also reports the `tls-server-end-point` channel binding derived from
/// the server's first presented certificate.
#[doc(hidden)]
pub struct RustlsStream<D, S> {
    /// Underlying tokio-rustls client stream.
    io: TlsStream<S>,
    /// Hash backend used for the channel-binding digest.
    digest_impl: D,
}
impl<D, S> tokio_postgres::tls::TlsStream for RustlsStream<D, S>
where
    D: DigestImplementation + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Hashes the first peer certificate with the algorithm implied by its
    /// signature algorithm.
    ///
    /// Returns `ChannelBinding::none()` in three cases: no certificate was
    /// presented, the certificate does not parse, or its signature algorithm
    /// is not recognised.
    fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
        let (_, connection) = self.io.get_ref();
        let end_entity = match connection.peer_certificates().and_then(|chain| chain.first()) {
            Some(der) => der,
            None => return ChannelBinding::none(),
        };
        let parsed = match Certificate::from_der(end_entity) {
            Ok(cert) => cert,
            Err(_) => return ChannelBinding::none(),
        };
        match SignatureAlgorithm::try_from_identifier(&parsed.signature_algorithm.oid) {
            Some(sig_alg) => ChannelBinding::tls_server_end_point(
                self.digest_impl
                    .digest(sig_alg.digest_algorithm(), end_entity),
            ),
            None => ChannelBinding::none(),
        }
    }
}
impl<D, S> AsyncRead for RustlsStream<D, S>
where
    D: Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Delegates reading to the inner TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // `Self: Unpin` here, so the pin can be unwrapped safely.
        let tls = &mut Pin::into_inner(self).io;
        Pin::new(tls).poll_read(cx, buf)
    }
}
impl<D, S> AsyncWrite for RustlsStream<D, S>
where
D: Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().io).poll_write(cx, buf)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().io).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().io).poll_shutdown(cx)
}
fn is_write_vectored(&self) -> bool {
self.io.is_write_vectored()
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().io).poll_write_vectored(cx, bufs)
}
}