use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{env, fs, io, sync};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
/// Process entry point.
///
/// Runs the server to completion; if it returns an error, the error is
/// printed to stderr prefixed with `FAILED: ` and the process exits with
/// status 1.
fn main() {
    match run_server() {
        Ok(()) => {}
        Err(e) => {
            eprintln!("FAILED: {}", e);
            std::process::exit(1);
        }
    }
}
/// Wraps a plain message in an [`io::Error`] of kind [`io::ErrorKind::Other`].
fn error(err: String) -> io::Error {
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, err)
}
#[tokio::main]
async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port = match env::args().nth(1) {
Some(ref p) => p.to_owned(),
None => "1337".to_owned(),
};
let addr = format!("127.0.0.1:{}", port).parse()?;
let tls_cfg = {
let certs = load_certs("examples/sample.pem")?;
let key = load_private_key("examples/sample.rsa")?;
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?;
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
sync::Arc::new(cfg)
};
let incoming = AddrIncoming::bind(&addr)?;
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
println!("Starting to serve on https://{}.", addr);
server.await?;
Ok(())
}
/// Progress of a single accepted connection through TLS setup.
///
/// A `TlsStream` starts in `Handshaking` and is moved to `Streaming` by
/// `poll_read` / `poll_write` once the accept future resolves successfully.
enum State {
/// The future returned by `tokio_rustls::TlsAcceptor::accept`; it has not
/// yet produced a stream.
Handshaking(tokio_rustls::Accept<AddrStream>),
/// The TLS stream produced by the accept future; reads, writes, flush and
/// shutdown are forwarded to it.
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
/// Connection type handed to hyper by `TlsAcceptor`.
///
/// Wraps an accepted `AddrStream` together with its TLS state. The TLS accept
/// future is driven from inside `poll_read` / `poll_write` instead of being
/// awaited before the connection is returned from `poll_accept`.
pub struct TlsStream {
// Either the pending accept future or the established TLS stream.
state: State,
}
impl TlsStream {
    /// Creates a stream in the `Handshaking` state.
    ///
    /// A `tokio_rustls::TlsAcceptor` is built from `config` and its `accept`
    /// future for `stream` is stored; nothing is polled here.
    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
        let acceptor = tokio_rustls::TlsAcceptor::from(config);
        let state = State::Handshaking(acceptor.accept(stream));
        TlsStream { state }
    }
}
impl AsyncRead for TlsStream {
    /// Reads from the TLS stream, first driving the pending accept future if
    /// the connection is still in the `Handshaking` state.
    ///
    /// When the accept future resolves successfully, the read is attempted on
    /// the new stream and the state becomes `Streaming` regardless of that
    /// read's outcome. An accept error is returned as the read error and the
    /// state is left unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();

        // Fast path: already streaming, so forward directly. Otherwise poll
        // the accept future and bail out on `Pending` or failure.
        let mut tls = match &mut this.state {
            State::Streaming(established) => return Pin::new(established).poll_read(cx, buf),
            State::Handshaking(handshake) => match ready!(Pin::new(handshake).poll(cx)) {
                Ok(tls) => tls,
                Err(err) => return Poll::Ready(Err(err)),
            },
        };

        // Accept just completed: perform the read, then store the stream.
        let outcome = Pin::new(&mut tls).poll_read(cx, buf);
        this.state = State::Streaming(tls);
        outcome
    }
}
impl AsyncWrite for TlsStream {
    /// Writes to the TLS stream, first driving the pending accept future if
    /// the connection is still in the `Handshaking` state.
    ///
    /// When the accept future resolves successfully, the write is attempted on
    /// the new stream and the state becomes `Streaming` regardless of that
    /// write's outcome. An accept error is returned as the write error and the
    /// state is left unchanged.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();

        // Fast path: already streaming, so forward directly. Otherwise poll
        // the accept future and bail out on `Pending` or failure.
        let mut tls = match &mut this.state {
            State::Streaming(established) => return Pin::new(established).poll_write(cx, buf),
            State::Handshaking(handshake) => match ready!(Pin::new(handshake).poll(cx)) {
                Ok(tls) => tls,
                Err(err) => return Poll::Ready(Err(err)),
            },
        };

        // Accept just completed: perform the write, then store the stream.
        let outcome = Pin::new(&mut tls).poll_write(cx, buf);
        this.state = State::Streaming(tls);
        outcome
    }

    /// Flushes the TLS stream once streaming; while still handshaking this
    /// reports success immediately without polling anything.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(established) => Pin::new(established).poll_flush(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }

    /// Shuts down the TLS stream once streaming; while still handshaking this
    /// reports success immediately without polling anything.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(established) => Pin::new(established).poll_shutdown(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }
}
/// hyper `Accept` implementation that yields `TlsStream` connections.
///
/// Each socket produced by `incoming` is wrapped in a `TlsStream` that uses a
/// clone of the shared `config` handle.
pub struct TlsAcceptor {
// Shared rustls server configuration; cloned (Arc) per accepted socket.
config: Arc<ServerConfig>,
// Source of accepted `AddrStream` sockets.
incoming: AddrIncoming,
}
impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}
impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
/// Request handler for the example server.
///
/// * `GET /` responds with the text `Try POST /echo`.
/// * `POST /echo` responds with the request body as the response body.
/// * Anything else gets an empty `404 Not Found`.
///
/// This function itself never returns `Err`.
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
    let response = match (req.method(), req.uri().path()) {
        (&Method::GET, "/") => Response::new(Body::from("Try POST /echo\n")),
        (&Method::POST, "/echo") => Response::new(req.into_body()),
        _ => {
            let mut not_found = Response::new(Body::empty());
            *not_found.status_mut() = StatusCode::NOT_FOUND;
            not_found
        }
    };
    Ok(response)
}
fn load_certs(filename: &str) -> io::Result<Vec<rustls::Certificate>> {
let certfile = fs::File::open(filename)
.map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = io::BufReader::new(certfile);
let certs = rustls_pemfile::certs(&mut reader)
.map_err(|_| error("failed to load certificate".into()))?;
Ok(certs
.into_iter()
.map(rustls::Certificate)
.collect())
}
fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> {
let keyfile = fs::File::open(filename)
.map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = io::BufReader::new(keyfile);
let keys = rustls_pemfile::rsa_private_keys(&mut reader)
.map_err(|_| error("failed to load private key".into()))?;
if keys.len() != 1 {
return Err(error("expected a single private key".into()));
}
Ok(rustls::PrivateKey(keys[0].clone()))
}