#[cfg(feature = "rustls")]
pub mod rustls;
use crate::Socket;
#[cfg(feature = "rustls")]
use std::io::{BufRead, BufReader};
use std::ops::Deref;
use tokio::io::{AsyncRead, AsyncWrite};
/// Pluggable TLS connector.
///
/// Implementors must be `Send + Sync + 'static`. The trait is object safe:
/// `connect` returns a boxed future rather than being an `async fn`.
pub trait CustomTlsConnector: Send + Sync + 'static {
/// Takes ownership of `stream` and returns a boxed future that resolves to a
/// `Socket` or a crate error.
///
/// The future may borrow both `self` and `domain` for `'a`, and it is `Send`.
///
/// NOTE(review): presumably the returned `Socket` is the TLS-wrapped form of
/// `stream` and `domain` is the server name used for the handshake — confirm
/// against the call sites.
fn connect<'a>(
&'a self,
domain: &'a str,
stream: Socket,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<Socket>> + Send + 'a>>;
}
/// Implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` for `$type`
/// by forwarding every poll to the field named `$field`.
///
/// * `$type`  - the wrapper type; it must be `Unpin`, because the generated
///   code takes `&mut self.$field` through a `Pin<&mut Self>`.
/// * `$field` - a field of `$type` that itself implements both traits and is
///   `Unpin` (it is re-pinned with `Pin::new`).
///
/// The forwarding calls use fully-qualified trait paths. `macro_rules!` resolves
/// method-call syntax at the invocation site, so `.poll_read(..)` would only
/// compile when the caller happened to have `AsyncRead` / `AsyncWrite` imported.
/// The qualified form removes that hidden requirement.
///
/// ```ignore
/// struct MyTlsStream { inner: SomeTlsStream }
/// impl_tls_stream!(MyTlsStream, inner);
/// ```
#[macro_export]
macro_rules! impl_tls_stream {
    ($type:ty, $field:ident) => {
        impl tokio::io::AsyncRead for $type {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                tokio::io::AsyncRead::poll_read(std::pin::Pin::new(&mut self.$field), cx, buf)
            }
        }

        impl tokio::io::AsyncWrite for $type {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                tokio::io::AsyncWrite::poll_write(std::pin::Pin::new(&mut self.$field), cx, buf)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                tokio::io::AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.$field), cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                tokio::io::AsyncWrite::poll_shutdown(std::pin::Pin::new(&mut self.$field), cx)
            }
        }
    };
}
/// Stream type for a custom TLS backend.
///
/// Besides async reading and writing, implementors may report connection
/// metadata through the two accessors below. Both have default implementations
/// that return `None`, so overriding them is optional.
pub trait CustomTlsStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
/// Certificates associated with the peer, if the implementation exposes them.
///
/// The default implementation returns `None`.
fn peer_certificate(&self) -> Option<Vec<PeerCertificate>> {
None
}
/// ALPN protocol bytes, if the implementation exposes them.
///
/// The default implementation returns `None`.
/// NOTE(review): presumably this is the protocol negotiated during the
/// handshake — confirm with implementors.
fn alpn_protocol(&self) -> Option<Vec<u8>> {
None
}
}
/// Byte buffer holding one certificate reported for the remote peer.
///
/// NOTE(review): the encoding of `inner` is not established in this file
/// (presumably DER) — confirm with the code that builds these values.
#[derive(Clone, Debug)]
pub struct PeerCertificate {
    /// The certificate bytes.
    pub inner: Vec<u8>,
}

/// Lets a `PeerCertificate` be used wherever a `&Vec<u8>` (and, through
/// auto-deref, a `&[u8]`) is expected.
impl Deref for PeerCertificate {
    type Target = Vec<u8>;

    fn deref(&self) -> &Self::Target {
        let Self { inner } = self;
        inner
    }
}
/// A certificate kept in the encoding it was supplied in.
///
/// The bytes are stored as a [`Cert`], which records whether they were given
/// as DER or as PEM.
#[derive(Clone, Debug)]
pub struct Certificate {
// The caller-supplied bytes, tagged with their encoding.
original: Cert,
}
/// Certificate bytes tagged with the encoding they were supplied in.
#[derive(Clone, Debug)]
pub enum Cert {
    /// Bytes supplied as DER.
    Der(Vec<u8>),
    /// Bytes supplied as PEM.
    Pem(Vec<u8>),
}

impl Cert {
    /// Returns the bytes when this is the `Der` variant, otherwise `None`.
    pub fn as_der(&self) -> Option<&[u8]> {
        if let Cert::Der(bytes) = self {
            Some(bytes.as_slice())
        } else {
            None
        }
    }

    /// Returns the bytes when this is the `Pem` variant, otherwise `None`.
    pub fn as_pem(&self) -> Option<&[u8]> {
        if let Cert::Pem(bytes) = self {
            Some(bytes.as_slice())
        } else {
            None
        }
    }

    /// Returns an owned copy of the stored bytes, whichever variant holds them.
    pub fn to_bytes(&self) -> Vec<u8> {
        let bytes = match self {
            Cert::Der(der) => der,
            Cert::Pem(pem) => pem,
        };
        bytes.to_vec()
    }
}
impl Certificate {
    /// Returns the stored bytes together with their encoding tag.
    pub fn original(&self) -> &Cert {
        let Self { original } = self;
        original
    }

    /// Wraps a copy of `der` as a DER-tagged certificate.
    ///
    /// The bytes are not inspected here; this constructor always returns `Ok`.
    pub fn from_der(der: &[u8]) -> crate::Result<Certificate> {
        let original = Cert::Der(der.to_vec());
        Ok(Self { original })
    }

    /// Wraps a copy of `pem` as a PEM-tagged certificate.
    ///
    /// The bytes are not inspected here; this constructor always returns `Ok`.
    pub fn from_pem(pem: &[u8]) -> crate::Result<Certificate> {
        let original = Cert::Pem(pem.to_vec());
        Ok(Self { original })
    }

    /// Builds certificates from a PEM bundle.
    ///
    /// With the `rustls` feature the bundle is parsed and every certificate
    /// section becomes its own DER-tagged [`Certificate`]; malformed PEM yields
    /// a builder error. Without the feature the whole bundle is wrapped,
    /// unparsed, as a single PEM-tagged certificate.
    pub fn from_pem_bundle(pem_bundle: &[u8]) -> crate::Result<Vec<Certificate>> {
        #[cfg(feature = "rustls")]
        {
            let der_blobs = Self::read_pem_certs(&mut BufReader::new(pem_bundle))?;
            let mut bundle = Vec::with_capacity(der_blobs.len());
            for der in &der_blobs {
                bundle.push(Self::from_der(der)?);
            }
            Ok(bundle)
        }
        #[cfg(not(feature = "rustls"))]
        {
            Self::from_pem(pem_bundle).map(|cert| vec![cert])
        }
    }

    /// Adds this certificate to `root_cert_store`.
    ///
    /// A DER certificate is added as is. A PEM certificate is parsed first and
    /// every certificate section it contains is added in order. Failures from
    /// parsing or from the store are returned as builder errors; entries added
    /// before a failure stay in the store.
    #[cfg(feature = "rustls")]
    pub(crate) fn add_to_tls(
        self,
        root_cert_store: &mut tokio_rustls::rustls::RootCertStore,
    ) -> crate::Result<()> {
        // Normalise both encodings to a list of DER blobs, then add them uniformly.
        let der_certs = match self.original {
            Cert::Der(der) => vec![der],
            Cert::Pem(pem) => Self::read_pem_certs(&mut std::io::Cursor::new(pem))?,
        };
        for der in der_certs {
            root_cert_store
                .add(der.into())
                .map_err(crate::errors::builder)?;
        }
        Ok(())
    }

    /// Reads `reader` to the end and returns the contents of every PEM section
    /// of kind `Certificate`, in input order.
    ///
    /// Sections of any other kind are skipped. A read failure or a malformed
    /// PEM section produces an "invalid certificate encoding" builder error.
    #[cfg(feature = "rustls")]
    fn read_pem_certs(reader: &mut impl BufRead) -> crate::Result<Vec<Vec<u8>>> {
        use rustls_pki_types::pem::{PemObject, SectionKind};

        let mut raw = Vec::new();
        reader
            .read_to_end(&mut raw)
            .map_err(|_| crate::errors::builder("invalid certificate encoding"))?;

        // Collecting into a `Result` stops at the first malformed section.
        <(SectionKind, Vec<u8>) as PemObject>::pem_slice_iter(&raw)
            .filter_map(|section| match section {
                Ok((SectionKind::Certificate, der)) => Some(Ok(der)),
                Ok(_) => None,
                Err(_) => Some(Err(crate::errors::builder("invalid certificate encoding"))),
            })
            .collect()
    }
}
/// Client identity material (see `ClientCert` for what is stored per backend).
///
/// Built with `Identity::from_pem`.
// NOTE(review): `allow(dead_code)` is presumably needed because nothing in this
// file reads `inner` when the `rustls` feature is disabled — confirm.
#[allow(dead_code)]
#[derive(Clone)]
pub struct Identity {
// Backend-specific representation of the identity.
inner: ClientCert,
}
/// Backend-specific storage for a client identity.
///
/// Exactly one variant exists in any build, selected by the `rustls` feature.
enum ClientCert {
    /// Parsed private key and certificate chain (built with the `rustls` feature).
    #[cfg(feature = "rustls")]
    RustlsPem {
        key: rustls_pki_types::PrivateKeyDer<'static>,
        certs: Vec<rustls_pki_types::CertificateDer<'static>>,
    },
    /// Unparsed PEM bytes (built without the `rustls` feature).
    #[cfg(not(feature = "rustls"))]
    CustomPem { pem_data: Vec<u8> },
}

/// Written by hand rather than derived: the `rustls` variant copies its key
/// with `clone_key()` instead of `Clone::clone`.
impl Clone for ClientCert {
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "rustls")]
            Self::RustlsPem { key, certs } => {
                let key = key.clone_key();
                let certs = certs.clone();
                Self::RustlsPem { key, certs }
            }
            #[cfg(not(feature = "rustls"))]
            Self::CustomPem { pem_data } => {
                let pem_data = pem_data.clone();
                Self::CustomPem { pem_data }
            }
        }
    }
}
impl Identity {
    /// Builds an identity from PEM data.
    ///
    /// With the `rustls` feature the input is parsed section by section:
    /// `Certificate` sections form the certificate chain, and
    /// `RsaPrivateKey` / `PrivateKey` / `EcPrivateKey` sections are collected as
    /// private keys. Other section kinds are ignored. When several keys are
    /// present, the last one is used. A builder error is returned if the PEM is
    /// malformed, a key cannot be interpreted, or no key or no certificate is
    /// found.
    ///
    /// Without the feature the bytes are stored unparsed and the call always
    /// succeeds.
    pub fn from_pem(buf: &[u8]) -> crate::Result<Identity> {
        #[cfg(feature = "rustls")]
        {
            use rustls_pki_types::pem::{PemObject, SectionKind};
            use rustls_pki_types::{CertificateDer, PrivateKeyDer};

            // Both parse failures below report the same error.
            let malformed = || {
                crate::errors::builder(tokio_rustls::rustls::Error::General(String::from(
                    "Invalid identity PEM file",
                )))
            };

            let mut chain: Vec<CertificateDer<'static>> = Vec::new();
            let mut private_keys: Vec<PrivateKeyDer<'static>> = Vec::new();
            for section in <(SectionKind, Vec<u8>) as PemObject>::pem_slice_iter(buf) {
                let (kind, der) = section.map_err(|_| malformed())?;
                match kind {
                    SectionKind::Certificate => chain.push(CertificateDer::from(der)),
                    SectionKind::RsaPrivateKey
                    | SectionKind::PrivateKey
                    | SectionKind::EcPrivateKey => {
                        let key = PrivateKeyDer::try_from(der).map_err(|_| malformed())?;
                        private_keys.push(key);
                    }
                    // Any other section kind is not part of an identity.
                    _ => {}
                }
            }

            // `pop` takes the last key that appeared in the input.
            match private_keys.pop() {
                Some(key) if !chain.is_empty() => Ok(Identity {
                    inner: ClientCert::RustlsPem { key, certs: chain },
                }),
                _ => Err(crate::errors::builder(
                    tokio_rustls::rustls::Error::General(String::from(
                        "private key or certificate not found",
                    )),
                )),
            }
        }
        #[cfg(not(feature = "rustls"))]
        {
            let pem_data = buf.to_vec();
            Ok(Identity {
                inner: ClientCert::CustomPem { pem_data },
            })
        }
    }

    /// Finishes `config_builder` by installing this identity's certificate chain
    /// and private key for client authentication.
    ///
    /// An error from the builder is returned as a builder error.
    #[cfg(feature = "rustls")]
    pub(crate) fn add_to_tls(
        self,
        config_builder: tokio_rustls::rustls::ConfigBuilder<
            tokio_rustls::rustls::ClientConfig,
            tokio_rustls::rustls::client::WantsClientCert,
        >,
    ) -> crate::Result<tokio_rustls::rustls::ClientConfig> {
        // Irrefutable: `RustlsPem` is the only variant when `rustls` is enabled.
        let Identity {
            inner: ClientCert::RustlsPem { key, certs },
        } = self;
        config_builder
            .with_client_auth_cert(certs, key)
            .map_err(crate::errors::builder)
    }
}
/// A TLS protocol version.
///
/// Values are obtained from the associated constants. Ordering follows the
/// protocol sequence, so `Version::TLS_1_0 < Version::TLS_1_3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Version(InnerVersion);

/// Private representation of [`Version`].
///
/// The variants are declared oldest first; the derived `Ord` relies on that.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
enum InnerVersion {
    Tls1_0,
    Tls1_1,
    Tls1_2,
    Tls1_3,
}

impl Version {
    /// TLS 1.0
    pub const TLS_1_0: Self = Self(InnerVersion::Tls1_0);
    /// TLS 1.1
    pub const TLS_1_1: Self = Self(InnerVersion::Tls1_1);
    /// TLS 1.2
    pub const TLS_1_2: Self = Self(InnerVersion::Tls1_2);
    /// TLS 1.3
    pub const TLS_1_3: Self = Self(InnerVersion::Tls1_3);
}