use std::{
ascii, fs, io,
net::SocketAddr,
path::{self, Path, PathBuf},
str,
sync::Arc,
};
use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use tracing::{error, info, info_span};
use tracing_futures::Instrument as _;
mod common;
// Command-line options for the server, parsed in `main` via `Opt::parse()`.
#[derive(Parser, Debug)]
#[clap(name = "server")]
struct Opt {
// `--keylog`: when set, `run` installs a `rustls::KeyLogFile` as the TLS key log.
#[clap(long = "keylog")]
keylog: bool,
// Positional argument: directory that request paths are resolved against.
// `run` refuses to start if this path does not exist.
#[clap(parse(from_os_str))]
root: PathBuf,
// `-k` / `--key`: private key file; must be given together with `--cert`.
// `run` treats a `.der` extension as raw DER and anything else as PEM.
#[clap(parse(from_os_str), short = 'k', long = "key", requires = "cert")]
key: Option<PathBuf>,
// `-c` / `--cert`: certificate (chain) file; must be given together with `--key`.
// `run` treats a `.der` extension as a single DER certificate and anything else as PEM.
#[clap(parse(from_os_str), short = 'c', long = "cert", requires = "key")]
cert: Option<PathBuf>,
// `--stateless-retry`: when set, `run` calls `use_retry(true)` on the server config.
#[clap(long = "stateless-retry")]
stateless_retry: bool,
// `--listen`: socket address the endpoint binds to; defaults to `[::1]:4433`.
#[clap(long = "listen", default_value = "[::1]:4433")]
listen: SocketAddr,
}
/// Process entry point.
///
/// Installs the global tracing subscriber (filtered through the default
/// environment filter), parses the command line and runs the server. Any error
/// returned by [`run`] is printed to stderr and turned into exit status 1;
/// success exits with status 0.
fn main() {
    let subscriber = tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();
    tracing::subscriber::set_global_default(subscriber).unwrap();

    let exit_code = match run(Opt::parse()) {
        Ok(()) => 0,
        Err(err) => {
            eprintln!("ERROR: {}", err);
            1
        }
    };
    ::std::process::exit(exit_code);
}
#[tokio::main]
#[allow(clippy::field_reassign_with_default)] async fn run(options: Opt) -> Result<()> {
let (certs, key) = if let (Some(key_path), Some(cert_path)) = (&options.key, &options.cert) {
let key = fs::read(key_path).context("failed to read private key")?;
let key = if key_path.extension().map_or(false, |x| x == "der") {
rustls::PrivateKey(key)
} else {
let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*key)
.context("malformed PKCS #8 private key")?;
match pkcs8.into_iter().next() {
Some(x) => rustls::PrivateKey(x),
None => {
let rsa = rustls_pemfile::rsa_private_keys(&mut &*key)
.context("malformed PKCS #1 private key")?;
match rsa.into_iter().next() {
Some(x) => rustls::PrivateKey(x),
None => {
anyhow::bail!("no private keys found");
}
}
}
}
};
let cert_chain = fs::read(cert_path).context("failed to read certificate chain")?;
let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") {
vec![rustls::Certificate(cert_chain)]
} else {
rustls_pemfile::certs(&mut &*cert_chain)
.context("invalid PEM-encoded certificate")?
.into_iter()
.map(rustls::Certificate)
.collect()
};
(cert_chain, key)
} else {
let dirs = directories_next::ProjectDirs::from("org", "quinn", "quinn-examples").unwrap();
let path = dirs.data_local_dir();
let cert_path = path.join("cert.der");
let key_path = path.join("key.der");
let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
Ok(x) => x,
Err(ref e) if e.kind() == io::ErrorKind::NotFound => {
info!("generating self-signed certificate");
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let key = cert.serialize_private_key_der();
let cert = cert.serialize_der().unwrap();
fs::create_dir_all(path).context("failed to create certificate directory")?;
fs::write(&cert_path, &cert).context("failed to write certificate")?;
fs::write(&key_path, &key).context("failed to write private key")?;
(cert, key)
}
Err(e) => {
bail!("failed to read certificate: {}", e);
}
};
let key = rustls::PrivateKey(key);
let cert = rustls::Certificate(cert);
(vec![cert], key)
};
let mut server_crypto = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
server_crypto.alpn_protocols = common::ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect();
if options.keylog {
server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
}
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
Arc::get_mut(&mut server_config.transport)
.unwrap()
.max_concurrent_uni_streams(0_u8.into());
if options.stateless_retry {
server_config.use_retry(true);
}
let root = Arc::<Path>::from(options.root.clone());
if !root.exists() {
bail!("root path does not exist");
}
let endpoint = quinn::Endpoint::server(server_config, options.listen)?;
eprintln!("listening on {}", endpoint.local_addr()?);
while let Some(conn) = endpoint.accept().await {
info!("connection incoming");
let fut = handle_connection(root.clone(), conn);
tokio::spawn(async move {
if let Err(e) = fut.await {
error!("connection failed: {reason}", reason = e.to_string())
}
});
}
Ok(())
}
/// Drives a single QUIC connection.
///
/// Completes the handshake, then accepts bidirectional streams in a loop and
/// spawns one task per stream to answer the request on it. Returns `Ok(())`
/// when the peer closes the connection at the application level; a failed
/// handshake or any other connection error is returned to the caller.
async fn handle_connection(root: Arc<Path>, conn: quinn::Connecting) -> Result<()> {
    let connection = conn.await?;

    // The field expressions are kept inside the macro invocation, exactly as
    // before, so when they are evaluated is unchanged.
    let span = info_span!(
        "connection",
        remote = %connection.remote_address(),
        protocol = %connection
            .handshake_data()
            .unwrap()
            .downcast::<quinn::crypto::rustls::HandshakeData>().unwrap()
            .protocol
            .map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned())
    );

    let serve = async {
        info!("established");
        loop {
            let stream = match connection.accept_bi().await {
                Ok(stream) => stream,
                Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
                    info!("connection closed");
                    return Ok(());
                }
                Err(err) => return Err(err),
            };

            let request = handle_request(root.clone(), stream);
            let task = async move {
                if let Err(e) = request.await {
                    error!("failed: {reason}", reason = e.to_string());
                }
            };
            tokio::spawn(task.instrument(info_span!("request")));
        }
    };

    serve.instrument(span).await?;
    Ok(())
}
async fn handle_request(
root: Arc<Path>,
(mut send, recv): (quinn::SendStream, quinn::RecvStream),
) -> Result<()> {
let req = recv
.read_to_end(64 * 1024)
.await
.map_err(|e| anyhow!("failed reading request: {}", e))?;
let mut escaped = String::new();
for &x in &req[..] {
let part = ascii::escape_default(x).collect::<Vec<_>>();
escaped.push_str(str::from_utf8(&part).unwrap());
}
info!(content = %escaped);
let resp = process_get(&root, &req).unwrap_or_else(|e| {
error!("failed: {}", e);
format!("failed to process request: {}\n", e).into_bytes()
});
send.write_all(&resp)
.await
.map_err(|e| anyhow!("failed to send response: {}", e))?;
send.finish()
.await
.map_err(|e| anyhow!("failed to shutdown stream: {}", e))?;
info!("complete");
Ok(())
}
fn process_get(root: &Path, x: &[u8]) -> Result<Vec<u8>> {
if x.len() < 4 || &x[0..4] != b"GET " {
bail!("missing GET");
}
if x[4..].len() < 2 || &x[x.len() - 2..] != b"\r\n" {
bail!("missing \\r\\n");
}
let x = &x[4..x.len() - 2];
let end = x.iter().position(|&c| c == b' ').unwrap_or(x.len());
let path = str::from_utf8(&x[..end]).context("path is malformed UTF-8")?;
let path = Path::new(&path);
let mut real_path = PathBuf::from(root);
let mut components = path.components();
match components.next() {
Some(path::Component::RootDir) => {}
_ => {
bail!("path must be absolute");
}
}
for c in components {
match c {
path::Component::Normal(x) => {
real_path.push(x);
}
x => {
bail!("illegal component in path: {:?}", x);
}
}
}
let data = fs::read(&real_path).context("failed reading file")?;
Ok(data)
}