use std::io;
use std::fs::File;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use log::error;
use futures::{pin_mut, ready, TryFuture};
use futures::future::Either;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{Accept, TlsAcceptor};
use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::server::TlsStream;
use crate::error::ExitError;
pub use tokio_rustls::rustls::ServerConfig;
/// Creates a rustls server configuration for the service named `service`.
///
/// The certificate chain is loaded from `cert_path` and the private key
/// from `key_path`. The configuration uses rustls' safe defaults and does
/// not request client authentication. `service` is only used to label the
/// error message.
///
/// # Errors
///
/// Logs an error and returns `ExitError::Generic` if either file cannot be
/// loaded or if rustls rejects the certificate/key pair.
pub fn create_server_config(
    service: &str, key_path: &Path, cert_path: &Path
) -> Result<ServerConfig, ExitError> {
    // Certificates are loaded before the key so that, if both are broken,
    // only the certificate problem is reported.
    let cert_chain = read_certs(cert_path)?;
    let private_key = read_key(key_path)?;

    let builder = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth();

    match builder.with_single_cert(cert_chain, private_key) {
        Ok(config) => Ok(config),
        Err(err) => {
            error!("Failed to create {} TLS server config: {}", service, err);
            Err(ExitError::Generic)
        }
    }
}
/// Reads the PEM-encoded certificates stored in the file at `cert_path`.
///
/// # Errors
///
/// Logs an error and returns `ExitError::Generic` if the file cannot be
/// opened, cannot be parsed, or contains no certificates at all.
fn read_certs(cert_path: &Path) -> Result<Vec<Certificate>, ExitError> {
    let file = File::open(cert_path).map_err(|err| {
        error!(
            "Failed to open TLS certificate file '{}': {}.",
            cert_path.display(), err
        );
        ExitError::Generic
    })?;

    let certs = rustls_pemfile::certs(
        &mut io::BufReader::new(file)
    ).map_err(|err| {
        error!(
            "Failed to read TLS certificate file '{}': {}.",
            cert_path.display(), err
        );
        ExitError::Generic
    })?;

    // An empty chain cannot authenticate the server. Reject it here with a
    // clear message, mirroring the "no usable keys" check in `read_key`,
    // rather than passing it on to the config builder.
    if certs.is_empty() {
        error!(
            "TLS certificate file '{}' does not contain any certificates.",
            cert_path.display()
        );
        return Err(ExitError::Generic)
    }

    Ok(certs.into_iter().map(Certificate).collect())
}
fn read_key(key_path: &Path) -> Result<PrivateKey, ExitError> {
use rustls_pemfile::Item::*;
let mut key_file = io::BufReader::new(
File::open(key_path).map_err(|err| {
error!(
"Failed to open TLS key file '{}': {}.",
key_path.display(), err
);
ExitError::Generic
})?
);
let mut key = None;
while let Some(item) =
rustls_pemfile::read_one(&mut key_file).transpose()
{
let item = item.map_err(|err| {
error!(
"Failed to read TLS key file '{}': {}.",
key_path.display(), err
);
ExitError::Generic
})?;
let bits = match item {
RSAKey(bits) | PKCS8Key(bits) => bits,
_ => continue
};
if key.is_some() {
error!(
"TLS key file '{}' contains multiple keys.",
key_path.display()
);
return Err(ExitError::Generic)
}
key = Some(PrivateKey(bits))
}
match key {
Some(key) => Ok(key),
None => {
error!(
"TLS key file '{}' does not contain any usable keys.",
key_path.display()
);
Err(ExitError::Generic)
}
}
}
// A server-side TLS stream over TCP whose handshake is driven lazily.
//
// The value starts out in `Accept`, holding the pending handshake future.
// Every I/O poll first calls `poll_accept`, which drives that future:
// - on success the state becomes `Stream`, holding the established TLS
//   stream;
// - on failure it becomes `Empty`, after which reads report end-of-file and
//   writes accept zero bytes.
pin_project! {
    #[project = TlsTcpStreamProj]
    enum TlsTcpStream {
        // Handshake still in progress.
        Accept { #[pin] fut: Accept<TcpStream> },
        // Handshake finished; despite the name `fut`, this holds the TLS
        // stream itself, not a future.
        Stream { #[pin] fut: TlsStream<TcpStream> },
        // Handshake failed; the handshake future has been dropped and there
        // is nothing left to do I/O on.
        Empty,
    }
}
impl TlsTcpStream {
fn new(sock: TcpStream, tls: &TlsAcceptor) -> Self {
Self::Accept { fut: tls.accept(sock) }
}
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Pin<&mut Self>, io::Error>> {
match self.as_mut().project() {
TlsTcpStreamProj::Accept { fut } => {
match ready!(fut.try_poll(cx)) {
Ok(fut) => {
self.set(Self::Stream { fut });
Poll::Ready(Ok(self))
}
Err(err) => {
self.set(Self::Empty);
Poll::Ready(Err(err))
}
}
}
_ => Poll::Ready(Ok(self)),
}
}
}
impl AsyncRead for TlsTcpStream {
    /// Reads through the TLS stream, finishing the handshake first if needed.
    ///
    /// If the handshake fails while being driven here, its error is
    /// returned. In the `Empty` state the read completes without filling
    /// `buf`, i.e. it signals end-of-file.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>
    ) -> Poll<Result<(), io::Error>> {
        let accepted = ready!(self.poll_accept(cx))?;
        match accepted.project() {
            TlsTcpStreamProj::Stream { fut: tls } => tls.poll_read(cx, buf),
            TlsTcpStreamProj::Empty => Poll::Ready(Ok(())),
            // `poll_accept` only succeeds in the `Stream` or `Empty` state.
            TlsTcpStreamProj::Accept { .. } => unreachable!(),
        }
    }
}
impl AsyncWrite for TlsTcpStream {
    /// Writes `buf` through the TLS stream, finishing the handshake first if
    /// needed.
    ///
    /// In the `Empty` state nothing is written and `0` is reported.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8]
    ) -> Poll<Result<usize, io::Error>> {
        let accepted = ready!(self.poll_accept(cx))?;
        match accepted.project() {
            TlsTcpStreamProj::Stream { fut: tls } => tls.poll_write(cx, buf),
            TlsTcpStreamProj::Empty => Poll::Ready(Ok(0)),
            // `poll_accept` only succeeds in the `Stream` or `Empty` state.
            TlsTcpStreamProj::Accept { .. } => unreachable!(),
        }
    }

    /// Flushes the TLS stream, finishing the handshake first if needed.
    ///
    /// In the `Empty` state there is nothing to flush and this succeeds.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Result<(), io::Error>> {
        let accepted = ready!(self.poll_accept(cx))?;
        match accepted.project() {
            TlsTcpStreamProj::Stream { fut: tls } => tls.poll_flush(cx),
            TlsTcpStreamProj::Empty => Poll::Ready(Ok(())),
            TlsTcpStreamProj::Accept { .. } => unreachable!(),
        }
    }

    /// Shuts down the TLS stream, finishing the handshake first if needed.
    ///
    /// In the `Empty` state there is nothing to shut down and this succeeds.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Result<(), io::Error>> {
        let accepted = ready!(self.poll_accept(cx))?;
        match accepted.project() {
            TlsTcpStreamProj::Stream { fut: tls } => tls.poll_shutdown(cx),
            TlsTcpStreamProj::Empty => Poll::Ready(Ok(())),
            TlsTcpStreamProj::Accept { .. } => unreachable!(),
        }
    }
}
/// An accepted TCP connection that may or may not be wrapped in TLS.
pub struct MaybeTlsTcpStream {
    /// Plain TCP on the left, TLS-over-TCP on the right.
    sock: Either<TcpStream, TlsTcpStream>,
}

impl MaybeTlsTcpStream {
    /// Wraps `sock`, starting a TLS handshake via `tls` when one is given
    /// and using the socket as plain TCP otherwise.
    pub fn new(sock: TcpStream, tls: Option<&TlsAcceptor>) -> Self {
        let inner = if let Some(acceptor) = tls {
            Either::Right(TlsTcpStream::new(sock, acceptor))
        } else {
            Either::Left(sock)
        };
        Self { sock: inner }
    }
}
impl AsyncRead for MaybeTlsTcpStream {
    /// Reads from whichever transport, plain or TLS, this stream wraps.
    fn poll_read(
        mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf
    ) -> Poll<Result<(), io::Error>> {
        // Both inner stream types are `Unpin`, so re-pinning a plain
        // mutable borrow of them is sound.
        match &mut self.sock {
            Either::Left(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Either::Right(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
impl AsyncWrite for MaybeTlsTcpStream {
    /// Writes to whichever transport, plain or TLS, this stream wraps.
    fn poll_write(
        mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]
    ) -> Poll<Result<usize, io::Error>> {
        // Both inner stream types are `Unpin`, so re-pinning a plain
        // mutable borrow of them is sound.
        match &mut self.sock {
            Either::Left(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Either::Right(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    /// Flushes whichever transport this stream wraps.
    fn poll_flush(
        mut self: Pin<&mut Self>, cx: &mut Context
    ) -> Poll<Result<(), io::Error>> {
        match &mut self.sock {
            Either::Left(tcp) => Pin::new(tcp).poll_flush(cx),
            Either::Right(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    /// Shuts down whichever transport this stream wraps.
    fn poll_shutdown(
        mut self: Pin<&mut Self>, cx: &mut Context
    ) -> Poll<Result<(), io::Error>> {
        match &mut self.sock {
            Either::Left(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Either::Right(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}