use std::{
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Instant,
};
use agnostic::{
net::{Net, TcpListener, TcpStream},
Runtime,
};
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub use futures_rustls::{
client, pki_types::ServerName, rustls, server, TlsAcceptor, TlsConnector,
};
use memberlist_core::transport::{TimeoutableReadStream, TimeoutableWriteStream};
use rustls::{client::danger::ServerCertVerifier, SignatureScheme};
use super::{Listener, PromisedStream, StreamLayer};
/// A [`ServerCertVerifier`] that accepts every server certificate and every
/// handshake signature without checking anything.
///
/// # Security
///
/// Installing this verifier turns off server authentication completely: any
/// peer presenting any certificate is accepted. Only use it where that is
/// acceptable (e.g. tests) or where peers are authenticated by other means.
#[derive(Debug, Default)]
pub struct NoopCertificateVerifier;

impl NoopCertificateVerifier {
    /// Returns a new verifier already wrapped in an [`Arc`].
    pub fn new() -> Arc<Self> {
        let verifier = Self;
        Arc::new(verifier)
    }
}
impl ServerCertVerifier for NoopCertificateVerifier {
    /// Accepts the presented certificate chain without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;

        let verdict = ServerCertVerified::assertion();
        Ok(verdict)
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        let verdict = HandshakeSignatureValid::assertion();
        Ok(verdict)
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;

        let verdict = HandshakeSignatureValid::assertion();
        Ok(verdict)
    }

    /// Lists the signature schemes this verifier claims to handle, in the
    /// order they are reported to rustls.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        const SCHEMES: [SignatureScheme; 13] = [
            SignatureScheme::RSA_PKCS1_SHA1,
            SignatureScheme::ECDSA_SHA1_Legacy,
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ];
        SCHEMES.to_vec()
    }
}
/// Options used to construct the [`Tls`] stream layer.
///
/// Bundles the acceptor used for inbound connections, the connector used for
/// outbound connections, and the server name handed to the connector when
/// dialing. Reference getters and `with_*` builder-style setters are generated
/// by the `viewit` attribute.
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
pub struct TlsOptions {
    // Runs the server-side TLS handshake on accepted connections.
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Get the TLS acceptor."),),
        setter(attrs(doc = "Set the TLS acceptor. (Builder pattern)"),)
    )]
    acceptor: TlsAcceptor,
    // Runs the client-side TLS handshake on outgoing connections.
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Get the TLS connector."),),
        setter(attrs(doc = "Set the TLS connector. (Builder pattern)"),)
    )]
    connector: TlsConnector,
    // Name passed to the connector for every outgoing connection; a single
    // name is shared by all peers this layer dials.
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Get the server name."),),
        setter(attrs(doc = "Set the server name. (Builder pattern)"),)
    )]
    server_name: ServerName<'static>,
}
impl TlsOptions {
    /// Builds TLS options from their three components.
    ///
    /// - `server_name`: the name handed to `connector` for every outgoing connection.
    /// - `acceptor`: performs the server side of the handshake on accepted connections.
    /// - `connector`: performs the client side of the handshake when dialing.
    #[inline]
    pub const fn new(
        server_name: ServerName<'static>,
        acceptor: TlsAcceptor,
        connector: TlsConnector,
    ) -> Self {
        Self {
            server_name,
            connector,
            acceptor,
        }
    }
}
/// TLS stream layer: dials peers with a [`TlsConnector`] and serves inbound
/// connections with a [`TlsAcceptor`], on top of runtime `R`'s TCP types.
pub struct Tls<R> {
    /// Server name passed to the connector for every outgoing connection.
    domain: ServerName<'static>,
    /// Shared (via `Arc` clones) with every listener created by `bind`.
    acceptor: Arc<TlsAcceptor>,
    /// Performs the client-side handshake in `connect`.
    connector: TlsConnector,
    /// Ties the layer to runtime `R` without storing a value of it.
    _marker: std::marker::PhantomData<R>,
}

impl<R> Tls<R> {
    /// Assembles the layer, moving the acceptor behind an `Arc` so that
    /// listeners can share it.
    #[inline]
    fn new_in(domain: ServerName<'static>, acceptor: TlsAcceptor, connector: TlsConnector) -> Self {
        let shared_acceptor = Arc::new(acceptor);
        Self {
            _marker: std::marker::PhantomData,
            connector,
            acceptor: shared_acceptor,
            domain,
        }
    }
}
impl<R: Runtime> StreamLayer for Tls<R> {
type Listener = TlsListener<R>;
type Stream = TlsStream<R>;
type Options = TlsOptions;
#[inline]
async fn new(options: Self::Options) -> io::Result<Self> {
Ok(Self::new_in(
options.server_name,
options.acceptor,
options.connector,
))
}
async fn connect(&self, addr: SocketAddr) -> io::Result<Self::Stream> {
let conn = <<R::Net as Net>::TcpStream as TcpStream>::connect(addr).await?;
let local_addr = conn.local_addr()?;
let stream = self.connector.connect(self.domain.clone(), conn).await?;
Ok(TlsStream {
stream: TlsStreamKind::Client {
stream,
read_deadline: None,
write_deadline: None,
},
peer_addr: addr,
local_addr,
})
}
async fn bind(&self, addr: SocketAddr) -> io::Result<Self::Listener> {
let acceptor = self.acceptor.clone();
<<R::Net as Net>::TcpListener as TcpListener>::bind(addr)
.await
.and_then(|ln| {
ln.local_addr().map(|local_addr| TlsListener {
ln,
acceptor,
local_addr,
})
})
}
async fn cache_stream(&self, _addr: SocketAddr, mut stream: Self::Stream) {
R::spawn_detach(async move {
let _ = stream.flush().await;
let _ = stream.close().await;
R::sleep(std::time::Duration::from_millis(100)).await;
});
}
fn is_secure() -> bool {
true
}
}
/// Listener half of the TLS layer: a bound TCP listener plus the acceptor that
/// runs the server-side TLS handshake on each incoming connection.
pub struct TlsListener<R: Runtime> {
    ln: <R::Net as Net>::TcpListener,
    acceptor: Arc<TlsAcceptor>,
    /// Address the listener is bound to, captured at bind time.
    local_addr: SocketAddr,
}

impl<R: Runtime> Listener for TlsListener<R> {
    type Stream = TlsStream<R>;

    /// Accepts the next TCP connection and completes the server-side TLS
    /// handshake on it before returning.
    ///
    /// The handshake runs inline, so this future does not resolve until the
    /// peer has finished (or failed) the handshake.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying TCP `accept`, or the I/O error
    /// produced by the TLS handshake.
    async fn accept(&self) -> io::Result<(Self::Stream, std::net::SocketAddr)> {
        let (conn, addr) = self.ln.accept().await?;
        // Record the accepted socket's own local address, matching what the
        // client side (`connect`) records, instead of the listener's bind
        // address, which may be an unspecified address such as `0.0.0.0`.
        // Fall back to the bind address if the lookup fails so this never adds
        // a new error path.
        let local_addr = conn.local_addr().unwrap_or(self.local_addr);
        let stream = TlsAcceptor::accept(&self.acceptor, conn).await?;
        Ok((
            TlsStream {
                stream: TlsStreamKind::Server {
                    stream,
                    read_deadline: None,
                    write_deadline: None,
                },
                peer_addr: addr,
                local_addr,
            },
            addr,
        ))
    }

    /// Returns the address the listener is bound to.
    fn local_addr(&self) -> std::net::SocketAddr {
        self.local_addr
    }

    /// Shuts down the underlying TCP listener.
    async fn shutdown(&self) -> io::Result<()> {
        TcpListener::shutdown(&self.ln).await
    }
}
/// The two sides a [`TlsStream`] can wrap: a client-side stream (created when
/// dialing a peer) or a server-side stream (created when accepting one).
///
/// Each variant also stores the read/write deadlines that are exposed through
/// the `TimeoutableReadStream` / `TimeoutableWriteStream` impls on `TlsStream`.
// NOTE(review): the `AsyncRead`/`AsyncWrite` impls in this file never consult
// these deadlines; presumably callers read them back and apply the timeout
// themselves — confirm.
#[pin_project::pin_project]
enum TlsStreamKind<R: Runtime> {
    /// Stream produced by the client-side handshake when dialing a peer.
    Client {
        #[pin]
        stream: client::TlsStream<<R::Net as Net>::TcpStream>,
        read_deadline: Option<Instant>,
        write_deadline: Option<Instant>,
    },
    /// Stream produced by the server-side handshake on an accepted connection.
    Server {
        #[pin]
        stream: server::TlsStream<<R::Net as Net>::TcpStream>,
        read_deadline: Option<Instant>,
        write_deadline: Option<Instant>,
    },
}
impl<R: Runtime> AsyncRead for TlsStreamKind<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Client { stream, .. } => Pin::new(stream).poll_read(cx, buf),
Self::Server { stream, .. } => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl<R: Runtime> AsyncWrite for TlsStreamKind<R> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Client { stream, .. } => Pin::new(stream).poll_write(cx, buf),
Self::Server { stream, .. } => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Client { stream, .. } => Pin::new(stream).poll_flush(cx),
Self::Server { stream, .. } => Pin::new(stream).poll_flush(cx),
}
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Client { stream, .. } => Pin::new(stream).poll_close(cx),
Self::Server { stream, .. } => Pin::new(stream).poll_close(cx),
}
}
}
/// A TLS-wrapped TCP connection (client- or server-side) together with the
/// socket addresses recorded when it was established.
#[pin_project::pin_project]
pub struct TlsStream<R: Runtime> {
    // Structurally pinned; the `AsyncRead`/`AsyncWrite` impls reach it via `project()`.
    #[pin]
    stream: TlsStreamKind<R>,
    // Local address recorded when the stream was created.
    local_addr: SocketAddr,
    // Address of the remote peer.
    peer_addr: SocketAddr,
}
impl<R: Runtime> AsyncRead for TlsStream<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.project().stream.poll_read(cx, buf)
}
}
impl<R: Runtime> AsyncWrite for TlsStream<R> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
self.project().stream.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_close(cx)
}
}
impl<R: Runtime> TimeoutableReadStream for TlsStream<R> {
    /// Stores `deadline` as the read deadline, whichever side the stream is.
    /// Only the stored value changes; nothing here enforces it.
    fn set_read_deadline(&mut self, deadline: Option<Instant>) {
        match &mut self.stream {
            TlsStreamKind::Client { read_deadline: slot, .. }
            | TlsStreamKind::Server { read_deadline: slot, .. } => *slot = deadline,
        }
    }

    /// Returns the read deadline last stored (initially `None`).
    fn read_deadline(&self) -> Option<Instant> {
        match &self.stream {
            TlsStreamKind::Client { read_deadline: slot, .. }
            | TlsStreamKind::Server { read_deadline: slot, .. } => *slot,
        }
    }
}
impl<R: Runtime> TimeoutableWriteStream for TlsStream<R> {
    /// Stores `deadline` as the write deadline, whichever side the stream is.
    /// Only the stored value changes; nothing here enforces it.
    fn set_write_deadline(&mut self, deadline: Option<Instant>) {
        match &mut self.stream {
            TlsStreamKind::Client { write_deadline: slot, .. }
            | TlsStreamKind::Server { write_deadline: slot, .. } => *slot = deadline,
        }
    }

    /// Returns the write deadline last stored (initially `None`).
    fn write_deadline(&self) -> Option<Instant> {
        match &self.stream {
            TlsStreamKind::Client { write_deadline: slot, .. }
            | TlsStreamKind::Server { write_deadline: slot, .. } => *slot,
        }
    }
}
impl<R: Runtime> PromisedStream for TlsStream<R> {
    /// Local socket address recorded when this stream was created.
    #[inline]
    fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = self;
        *local_addr
    }

    /// Address of the remote peer.
    #[inline]
    fn peer_addr(&self) -> SocketAddr {
        let Self { peer_addr, .. } = self;
        *peer_addr
    }
}