use crate::proto::Protocol;
use crate::proto::{ALPN_H1, ALPN_H2};
use crate::Error;
use crate::Stream;
use crate::{AsyncRead, AsyncWrite};
use futures_util::future::poll_fn;
use futures_util::ready;
use rustls::Session;
use rustls::{ClientConfig, ClientSession};
use std::io;
use std::io::Read;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use webpki::DNSNameRef;
use webpki_roots::TLS_SERVER_ROOTS;
/// Performs a TLS client handshake over `stream` for the server `domain`.
///
/// The session trusts the `webpki_roots` `TLS_SERVER_ROOTS` anchors and
/// offers the ALPN protocols `ALPN_H2` then `ALPN_H1`. When
/// `tls_disable_verify` is true, server certificates are accepted without
/// verification.
///
/// Returns the encrypted stream together with the protocol chosen via ALPN.
///
/// # Errors
///
/// Fails if `domain` is not a valid ASCII DNS name or if the handshake fails.
pub(crate) async fn wrap_tls_client(
    stream: impl Stream,
    domain: &str,
    tls_disable_verify: bool,
) -> Result<(impl Stream, Protocol), Error> {
    let mut client_config = ClientConfig::new();
    client_config
        .root_store
        .add_server_trust_anchors(&TLS_SERVER_ROOTS);
    client_config.alpn_protocols = vec![ALPN_H2.to_owned(), ALPN_H1.to_owned()];
    if tls_disable_verify {
        // Explicit opt-in: swaps webpki verification for an accept-all verifier.
        client_config
            .dangerous()
            .set_certificate_verifier(Arc::new(DisabledCertVerified));
    }
    let client_config = Arc::new(client_config);

    let server_name = DNSNameRef::try_from_ascii_str(domain)?;
    let session = ClientSession::new(&client_config, server_name);
    let mut tls = TlsStream::new(stream, session);

    let handshake = poll_fn(|cx| Pin::new(&mut tls).poll_handshake(cx)).await;
    trace!("tls handshake: {:?}", handshake);
    handshake?;

    // Resolve the protocol before `tls` is moved into the return value.
    let protocol = Protocol::from_alpn(tls.tls.get_alpn_protocol());
    Ok((tls, protocol))
}
/// Server certificate verifier that accepts every certificate.
///
/// `wrap_tls_client` installs it only when `tls_disable_verify` is true. It
/// logs a warning naming the server for each accepted certificate. Using it
/// removes all protection against impersonation of the server.
struct DisabledCertVerified;

impl rustls::ServerCertVerifier for DisabledCertVerified {
    fn verify_server_cert(
        &self,
        _roots: &rustls::RootCertStore,
        _presented_certs: &[rustls::Certificate],
        dns_name: DNSNameRef,
        _ocsp_response: &[u8],
    ) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
        warn!("Ignoring TLS verification for {:?}", dns_name);
        Ok(rustls::ServerCertVerified::assertion())
    }
}
#[cfg(feature = "server")]
use rustls::ServerConfig;
/// Advertises the ALPN protocols `ALPN_H2` then `ALPN_H1` on a server
/// config, matching the preference order offered by the client side.
#[cfg(feature = "server")]
pub(crate) fn configure_tls_server(config: &mut ServerConfig) {
    let preferred = [ALPN_H2, ALPN_H1];
    config.alpn_protocols = preferred.iter().map(|&proto| proto.to_owned()).collect();
}
/// Accepts a TLS connection on `stream` using the shared server `config`.
///
/// Drives the handshake to completion and returns the encrypted stream along
/// with the protocol negotiated through ALPN.
///
/// # Errors
///
/// Returns an error if the handshake fails.
#[cfg(feature = "server")]
pub(crate) async fn wrap_tls_server(
    stream: impl Stream,
    config: Arc<ServerConfig>,
) -> Result<(impl Stream, Protocol), Error> {
    let session = rustls::ServerSession::new(&config);
    let mut tls = TlsStream::new(stream, session);

    let handshake = poll_fn(|cx| Pin::new(&mut tls).poll_handshake(cx)).await;
    trace!("tls handshake: {:?}", handshake);
    handshake?;

    // Resolve the protocol before `tls` is moved into the return value.
    let protocol = Protocol::from_alpn(tls.tls.get_alpn_protocol());
    Ok((tls, protocol))
}
/// Async adapter running a TLS session `E` over a byte stream `S`.
///
/// Ciphertext is staged in `read_buf`/`write_buf` so the synchronous rustls
/// `read_tls`/`write_tls` calls operate on memory rather than on `stream`.
/// Decrypted data waits in `plaintext` until a reader consumes it.
struct TlsStream<S, E> {
    /// Underlying transport carrying TLS records.
    stream: S,
    /// The TLS session (a rustls `ClientSession` or `ServerSession` in this file).
    tls: E,
    /// Ciphertext read from `stream` that the session has not consumed yet.
    read_buf: Vec<u8>,
    /// Ciphertext produced by the session that is not yet written to `stream`.
    write_buf: Vec<u8>,
    /// Set by `SyncStream::flush`; cleared after `stream` has been flushed.
    wants_flush: bool,
    /// Decrypted application data awaiting a reader.
    plaintext: Vec<u8>,
    /// Index of the first byte in `plaintext` not yet returned to a reader.
    plaintext_idx: usize,
}
impl<S: Stream, E: Session + Unpin + 'static> TlsStream<S, E> {
    /// Wraps `stream` with the session `tls`; no I/O happens until polled.
    pub fn new(stream: S, tls: E) -> Self {
        TlsStream {
            stream,
            tls,
            read_buf: Vec::new(),
            write_buf: Vec::new(),
            wants_flush: false,
            plaintext: Vec::new(),
            plaintext_idx: 0,
        }
    }

    /// Number of decrypted bytes buffered but not yet handed to a reader.
    fn plaintext_left(&self) -> usize {
        self.plaintext.len() - self.plaintext_idx
    }

    /// Drives the TLS session as far as it can go without blocking.
    ///
    /// Each round does the following:
    /// - writes all staged ciphertext;
    /// - flushes the transport if a flush was requested;
    /// - reads more ciphertext when it is needed (plaintext requested and none
    ///   buffered, or handshake in progress);
    /// - feeds that ciphertext to rustls and collects any decrypted plaintext.
    ///
    /// Rounds repeat while rustls consumes or produces TLS data.
    ///
    /// With `poll_for_read`, this returns `Pending` while no plaintext is
    /// buffered. If the transport has reached EOF, it returns `Ready(Ok(()))`
    /// with nothing buffered instead, so a reader observes end-of-stream.
    ///
    /// # Errors
    ///
    /// - Transport read/write/flush errors are propagated.
    /// - TLS protocol errors are reported as `InvalidData`.
    /// - EOF while the handshake is still in progress is reported as
    ///   `UnexpectedEof`.
    // `did_tls_read_or_write` is set from two independent branches below.
    #[allow(clippy::useless_let_if_seq)]
    fn poll_tls(&mut self, cx: &mut Context, poll_for_read: bool) -> Poll<io::Result<()>> {
        // Transport EOF seen during this call. Tracked so we never return
        // `Pending` without a registered waker (which would hang the task)
        // and do not keep re-polling a closed stream.
        let mut eof = false;
        loop {
            ready!(self.try_write_buf(cx))?;
            if self.wants_flush {
                ready!(Pin::new(&mut self.stream).poll_flush(cx))?;
                self.wants_flush = false;
            }
            let needs_ciphertext =
                (poll_for_read && self.plaintext_left() == 0) || self.tls.is_handshaking();
            if !eof && self.read_buf.is_empty() && needs_ciphertext {
                match self.try_read_buf(cx) {
                    // `read_buf` was empty before this read; if it still is,
                    // the transport returned 0 bytes, i.e. EOF.
                    Poll::Ready(Ok(())) => eof = self.read_buf.is_empty(),
                    // Surface transport errors instead of dropping them.
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    // The transport registered our waker; keep going so any
                    // pending TLS writes still make progress.
                    Poll::Pending => {}
                }
            }
            let mut did_tls_read_or_write = false;
            if self.tls.wants_read() && !self.read_buf.is_empty() {
                let mut sync = SyncStream::new(
                    &mut self.read_buf,
                    &mut self.write_buf,
                    &mut self.wants_flush,
                );
                let _ = ready!(blocking_to_poll(self.tls.read_tls(&mut sync), cx))?;
                self.tls
                    .process_new_packets()
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                if !self.tls.is_handshaking() {
                    // NOTE(review): depending on the rustls version, reading after a
                    // received close_notify may return a `ConnectionAborted` error
                    // here; confirm how that should surface to readers.
                    let _ = self.tls.read_to_end(&mut self.plaintext)?;
                }
                did_tls_read_or_write = true;
            }
            if self.tls.wants_write() {
                let mut sync = SyncStream::new(
                    &mut self.read_buf,
                    &mut self.write_buf,
                    &mut self.wants_flush,
                );
                let _ = ready!(blocking_to_poll(self.tls.write_tls(&mut sync), cx))?;
                did_tls_read_or_write = true;
            }
            if did_tls_read_or_write {
                continue;
            }
            if eof && self.tls.is_handshaking() {
                // No more ciphertext will arrive, so the handshake can never finish.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed during tls handshake",
                )));
            }
            if poll_for_read && self.plaintext_left() == 0 && !eof {
                return Poll::Pending;
            }
            return Poll::Ready(Ok(()));
        }
    }

    /// Writes all staged ciphertext in `write_buf` to the transport.
    ///
    /// Loops until the buffer is empty. A `Ready(Ok(()))` result therefore
    /// means every TLS record has been handed to the transport, which
    /// `poll_flush`/`poll_close` rely on.
    ///
    /// # Errors
    ///
    /// Propagates transport errors; returns `WriteZero` if the transport
    /// accepts no bytes.
    fn try_write_buf(&mut self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        while !self.write_buf.is_empty() {
            let amount = ready!(Pin::new(&mut self.stream).poll_write(cx, &self.write_buf))?;
            if amount == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write tls data to stream",
                )));
            }
            // Discard the written prefix in place, keeping the remainder.
            self.write_buf.drain(..amount);
        }
        Poll::Ready(Ok(()))
    }

    /// Reads up to 8 KiB of ciphertext from the transport and appends it to
    /// `read_buf`. A ready read that appends nothing indicates EOF.
    fn try_read_buf(&mut self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let mut tmp = [0; 8_192];
        let amount = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut tmp[..]))?;
        self.read_buf.extend_from_slice(&tmp[0..amount]);
        Ok(()).into()
    }

    /// Advances the handshake; resolves once rustls reports it complete.
    fn poll_handshake(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        if this.tls.is_handshaking() {
            Poll::Pending
        } else {
            Ok(()).into()
        }
    }
}
impl<S: Stream, E: Session + Unpin + 'static> AsyncRead for TlsStream<S, E> {
    /// Copies buffered decrypted data into `buf`.
    ///
    /// When no plaintext is buffered, the TLS session is driven first via
    /// `poll_tls`. Once every buffered byte has been handed out, the buffer
    /// is cleared so its storage can be reused.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if this.plaintext_left() == 0 {
            ready!(this.poll_tls(cx, true))?;
        }
        let unread = &this.plaintext[this.plaintext_idx..];
        let copied = unread.len().min(buf.len());
        buf[..copied].copy_from_slice(&unread[..copied]);
        this.plaintext_idx += copied;
        if this.plaintext_left() == 0 {
            this.plaintext.clear();
            this.plaintext_idx = 0;
        }
        Poll::Ready(Ok(copied))
    }
}
impl<S: Stream, E: Session + Unpin + 'static> AsyncWrite for TlsStream<S, E> {
    /// Encrypts as much of `buf` as the session accepts.
    ///
    /// Outstanding TLS work is driven first. The resulting records reach the
    /// transport on a later `poll_*` call such as `poll_flush`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        let amount = this.tls.write(buf)?;
        Ok(amount).into()
    }

    /// Vectored variant of `poll_write`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[io::IoSlice],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        let amount = this.tls.write_vectored(bufs)?;
        Ok(amount).into()
    }

    /// Flushes the session's buffered plaintext into TLS records and drives
    /// them out through the transport. Then flushes the transport itself so
    /// buffering transports also deliver them.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        this.tls.flush()?;
        ready!(this.poll_tls(cx, false))?;
        // `poll_tls` flushes the transport only when `wants_flush` was set, so
        // flush it explicitly to honour the `poll_flush` contract.
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Sends a TLS close_notify, drives it out, then closes the transport.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        // NOTE(review): if a later step returns Pending and this future is polled
        // again, `send_close_notify` runs again; confirm whether rustls suppresses
        // a duplicate close_notify alert.
        this.tls.send_close_notify();
        ready!(this.poll_tls(cx, false))?;
        Pin::new(&mut this.stream).poll_close(cx)
    }
}
/// Synchronous `io::Read`/`io::Write` view over `TlsStream`'s ciphertext
/// buffers, handed to the session's `read_tls`/`write_tls`.
///
/// - Reads drain `read_buf` and fail with `WouldBlock` when it is empty.
/// - Writes always append all of their input to `write_buf`.
/// - `flush` only records the request in `wants_flush`.
struct SyncStream<'a> {
    read_buf: &'a mut Vec<u8>,
    write_buf: &'a mut Vec<u8>,
    wants_flush: &'a mut bool,
}

impl<'a> SyncStream<'a> {
    fn new(
        read_buf: &'a mut Vec<u8>,
        write_buf: &'a mut Vec<u8>,
        wants_flush: &'a mut bool,
    ) -> Self {
        Self {
            read_buf,
            write_buf,
            wants_flush,
        }
    }
}

impl io::Read for SyncStream<'_> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.read_buf.is_empty() {
            return would_block();
        }
        let taken = self.read_buf.as_slice().read(buf)?;
        // Remove the consumed prefix; the remainder stays for the next read.
        self.read_buf.drain(..taken);
        Ok(taken)
    }
}

impl io::Write for SyncStream<'_> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_buf.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        *self.wants_flush = true;
        Ok(())
    }
}

/// `WouldBlock` error returned by `SyncStream::read` when no ciphertext is
/// buffered.
fn would_block() -> io::Result<usize> {
    Err(io::Error::new(io::ErrorKind::WouldBlock, "block"))
}
/// Converts a blocking-style `io::Result` into a `Poll`.
///
/// A `WouldBlock` error becomes `Poll::Pending` after waking the current task
/// immediately, so it is polled again. Every other result is passed through
/// as `Poll::Ready`.
fn blocking_to_poll<T>(result: io::Result<T>, cx: &mut Context) -> Poll<io::Result<T>> {
    match result {
        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
            // The synchronous call that produced `WouldBlock` had no access to
            // the waker, so schedule the re-poll ourselves.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
        other => Poll::Ready(other),
    }
}