use async_stream::stream;
use core::task::{Context, Poll};
use env_logger;
use futures_util::stream::Stream;
use hyper::body::Bytes;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use log::LevelFilter;
use rustls::internal::msgs::enums::AlertDescription;
use rustls::TLSError;
use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{read_one, Item};
use std::convert::Infallible;
use std::fs;
use std::io;
use std::io::Read;
use std::iter;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::process;
use std::sync::Arc;
use structopt::StructOpt;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{timeout, Duration};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
// Command-line options for the server. Every option can also be supplied through the
// environment variable named in its `env` attribute.
#[derive(StructOpt, Debug)]
#[structopt(name = "kcup")]
struct KCupOpts {
    // Host/IP to bind. A bare IPv6 literal such as `::1` is accepted.
    // NOTE(review): `-h` is claimed by this flag, so clap's help flag presumably only answers
    // to `--help` — confirm this is intended.
    #[structopt(short = "h", long = "host", default_value = "127.0.0.1", env = "HOST")]
    host: String,
    // TCP port to bind. Typed as `u16` so out-of-range values (negative, > 65535) are rejected
    // while parsing arguments instead of surfacing later as an unparseable address.
    #[structopt(short = "p", long = "port", default_value = "5000", env = "PORT")]
    port: u16,
    // How long to wait for STDIN to be read in full when no FILE is given.
    #[structopt(
        long = "stdin-read-timeout-seconds",
        default_value = "60",
        env = "STDIN_READ_TIMEOUT_SECONDS"
    )]
    stdin_read_timeout_seconds: u64,
    // File whose contents are served; when absent, the content is read from STDIN.
    #[structopt(
        name = "FILE",
        short = "f",
        long = "file",
        parse(from_os_str),
        env = "FILE"
    )]
    file_path: Option<PathBuf>,
    // PEM-encoded private key. HTTPS is used only when TLS_CERT is also supplied.
    #[structopt(
        name = "TLS_KEY",
        long = "tls-key",
        parse(from_os_str),
        env = "TLS_KEY"
    )]
    tls_key_path: Option<PathBuf>,
    // PEM-encoded certificate chain. HTTPS is used only when TLS_KEY is also supplied.
    #[structopt(
        name = "TLS_CERT",
        long = "tls-cert",
        parse(from_os_str),
        env = "TLS_CERT"
    )]
    tls_cert_path: Option<PathBuf>,
}
/// Request handler that answers with the same static `content` every time.
///
/// A `GET` request receives `200 OK` with `content` as the body. Any other method
/// receives `404 Not Found` with the body `No such resource`. The handler never fails.
async fn serve_static_content(
    req: Request<Body>,
    content: Bytes,
) -> Result<Response<Body>, Infallible> {
    if *req.method() != Method::GET {
        let mut rejection = Response::new(Body::from("No such resource"));
        *rejection.status_mut() = StatusCode::NOT_FOUND;
        return Ok(rejection);
    }
    Ok(Response::new(Body::from(content)))
}
/// Loads every PEM-encoded X.509 certificate from `filename`, preserving file order.
///
/// Private-key sections found in the file are skipped with a warning.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if reading a PEM section fails, or if the
/// file contains no certificate at all (so a wrong file is reported at startup rather than
/// as failed handshakes later).
fn load_certs(filename: &str) -> io::Result<Vec<Certificate>> {
    let cert_file = fs::File::open(filename).map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("failed to open {}: {}", filename, e),
        )
    })?;
    let mut reader = io::BufReader::new(cert_file);
    // Attach the filename to PEM read errors so the caller knows which file is bad.
    let pem_error = |e: io::Error| {
        io::Error::new(
            e.kind(),
            format!("failed to read PEM data from {}: {}", filename, e),
        )
    };
    let mut certificates = Vec::new();
    // `read_one` yields `Ok(None)` at end of input; the first error is propagated instead of
    // unwrapped, so malformed input no longer panics the server.
    while let Some(item) = read_one(&mut reader).map_err(pem_error)? {
        match item {
            Item::X509Certificate(cert) => certificates.push(Certificate(cert)),
            Item::RSAKey(_) => log::warn!("Unexpected RSAKey in TLS certificate file"),
            Item::PKCS8Key(_) => log::warn!("Unexpected PKCS8Key in TLS certificate file"),
        }
    }
    if certificates.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("no certificates found in {}", filename),
        ));
    }
    Ok(certificates)
}
/// Loads the private key from the PEM file at `filename`.
///
/// Both RSA (PKCS#1) and PKCS#8 keys are accepted. If the file holds several keys, the
/// last one is used and a warning is logged. Certificate sections are skipped with a warning.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if reading a PEM section fails, or if no
/// private key is found.
fn load_private_key(filename: &str) -> io::Result<PrivateKey> {
    let key_file = fs::File::open(filename).map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("failed to open {}: {}", filename, e),
        )
    })?;
    let mut reader = io::BufReader::new(key_file);
    // Attach the filename to PEM read errors so the caller knows which file is bad.
    let pem_error = |e: io::Error| {
        io::Error::new(
            e.kind(),
            format!("failed to read PEM data from {}: {}", filename, e),
        )
    };
    let mut key: Option<PrivateKey> = None;
    // Propagate the first read error instead of unwrapping it, so bad PEM no longer panics.
    while let Some(item) = read_one(&mut reader).map_err(pem_error)? {
        let key_bytes = match item {
            Item::X509Certificate(_) => {
                log::warn!("Unexpected X509Certificate in TLS key file");
                continue;
            }
            Item::RSAKey(bytes) | Item::PKCS8Key(bytes) => bytes,
        };
        if key.replace(PrivateKey(key_bytes)).is_some() {
            log::warn!("Multiple private keys in TLS key file; using the last one");
        }
    }
    key.ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("failed to parse an RSA or PKCS8 private key from {}", filename),
        )
    })
}
/// Adapts a stream of already-handshaken TLS connections to hyper's `Accept` trait, so
/// the stream can be passed to `Server::builder`.
struct HyperAcceptor<'a> {
    /// Source of connections. Items, including `Err`s, are handed to hyper unchanged.
    acceptor: Pin<Box<dyn Stream<Item = Result<TlsStream<TcpStream>, io::Error>> + 'a>>,
}

impl hyper::server::accept::Accept for HyperAcceptor<'_> {
    type Conn = TlsStream<TcpStream>;
    type Error = io::Error;

    /// Polls the wrapped stream for its next item; `None` means the stream has ended.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // The only field is a `Pin<Box<_>>`, so `HyperAcceptor` is `Unpin` and the outer pin
        // can be dropped; the boxed stream itself remains pinned.
        let this = self.get_mut();
        this.acceptor.as_mut().poll_next(cx)
    }
}
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
let mut filter_builder = env_logger::Builder::new();
filter_builder.filter(Some("rustls"), LevelFilter::Off);
filter_builder.parse_env(env_logger::Env::default().filter_or("LOG_LEVEL", "info"));
filter_builder.init();
let KCupOpts {
host,
port,
stdin_read_timeout_seconds,
file_path,
tls_key_path,
tls_cert_path,
} = KCupOpts::from_args();
let addr = String::from(format!("{}:{}", host, port)).parse::<SocketAddr>();
if let Err(_) = addr {
log::error!("Failed to parse host & port combination");
process::exit(1);
}
let addr = addr.unwrap(); log::info!("Server configured to run @ [{}]", addr);
let mut file_contents = String::new();
if let Some(path) = file_path {
log::info!("Reading file from path [{}]", path.to_string_lossy());
file_contents = fs::read_to_string(path)?;
} else {
log::info!(
"No file path provided, waiting for input on STDIN (max {} seconds)...",
stdin_read_timeout_seconds,
);
let stdin_read_task = tokio::task::spawn_blocking(move || {
let _ = io::stdin().read_to_string(&mut file_contents);
return file_contents;
});
match timeout(
Duration::from_secs(stdin_read_timeout_seconds),
stdin_read_task,
)
.await
{
Ok(Ok(contents)) => {
file_contents = contents;
log::info!("Successfully read input from STDIN");
}
_ => {
log::error!(
"Failed to read from STDIN after waiting {} seconds",
stdin_read_timeout_seconds
);
process::exit(1);
}
}
}
if file_contents.is_empty() {
log::error!(
"No file contents -- please ensure you've specified a file or fed in data via STDIN"
);
process::exit(1);
}
log::info!("Read [{}] characters", file_contents.len());
let file_contents_bytes = Bytes::from(file_contents);
if tls_key_path.is_none() || tls_cert_path.is_none() {
run_server_http(file_contents_bytes, &addr).await?;
} else {
let tls_key_path = tls_key_path.unwrap();
let tls_cert_path = tls_cert_path.unwrap();
run_server_https(file_contents_bytes, &addr, &tls_key_path, &tls_cert_path).await?;
};
Ok(())
}
async fn run_server_https(
file_contents_bytes: Bytes,
addr: &SocketAddr,
tls_key_path: &PathBuf,
tls_cert_path: &PathBuf,
) -> Result<(), std::io::Error> {
let svc_builder_fn = make_service_fn(move |_conn| {
let file_contents = file_contents_bytes.clone();
async {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
let file_contents = file_contents.clone();
serve_static_content(req, file_contents)
}))
}
});
log::info!(
"Building TLS configuration with key [{}] and (CA) certs [{}] ...",
tls_key_path.to_string_lossy(),
tls_cert_path.to_string_lossy(),
);
let tls_cfg = {
let certs = load_certs(&tls_cert_path.to_string_lossy())?;
let key = load_private_key(&tls_key_path.to_string_lossy())?;
let mut cfg = rustls::ServerConfig::new(
rustls::AllowAnyAnonymousOrAuthenticatedClient::new(rustls::RootCertStore::empty()),
);
cfg.set_single_cert(certs, key).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed setting cert for use: {}", e),
)
})?;
cfg.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
Arc::new(cfg)
};
log::info!("Binding TCP on port [{}]...", &addr);
let tcp = TcpListener::bind(&addr).await?;
let tls_acceptor = TlsAcceptor::from(tls_cfg);
let incoming_tls_stream = stream! {
loop {
let (socket, _) = tcp.accept().await?;
let stream = tls_acceptor.accept(socket);
match stream.await {
result @ Ok(_) => { yield result; },
Err(mut err) => {
if let Some(inner_err) = err.get_mut() {
if let Some(downcasted) = inner_err.downcast_mut::<TLSError>() {
match downcasted {
TLSError::AlertReceived(AlertDescription::BadCertificate) => {
log::debug!("TLS Error (ignored): {}", downcasted);
},
_ => log::warn!("TLS Error: {}", downcasted),
}
} else {
log::warn!("TLS Error: {}", inner_err);
}
}
}
}
}
};
let server = Server::builder(HyperAcceptor {
acceptor: Box::pin(incoming_tls_stream),
})
.serve(svc_builder_fn);
log::info!("Starting HTTPS server...");
if let Err(e) = server.await {
log::error!("Server error: {}", &e);
eprintln!("Server error: {}", e);
}
Ok(())
}
async fn run_server_http(file_contents_bytes: Bytes, addr: &SocketAddr) -> Result<(), io::Error> {
let svc_builder_fn = make_service_fn(move |_conn| {
let file_contents = file_contents_bytes.clone();
async {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
let file_contents = file_contents.clone();
serve_static_content(req, file_contents)
}))
}
});
let server = Server::bind(&addr).serve(svc_builder_fn);
log::info!("Starting HTTP server...");
if let Err(e) = server.await {
log::error!("Server error: {}", &e);
eprintln!("Server error: {}", e);
}
Ok(())
}