use std::sync::Arc;
use vox_core::ConnectionAcceptor;
use super::{ServeError, VoxListener, serve_listener};
/// Listener that accepts plain TCP connections and wraps each one in TLS
/// before handing it to the WebSocket layer (see the `VoxListener` impl).
pub struct WssListener {
/// Bound TCP socket that incoming connections are accepted from.
tcp: tokio::net::TcpListener,
/// TLS acceptor applied to every accepted TCP stream.
tls: tokio_rustls::TlsAcceptor,
}
impl WssListener {
    /// Binds a new TCP listener on `addr` and pairs it with a TLS acceptor
    /// built from the PEM files at `cert_path` and `key_path`.
    ///
    /// The TLS material is loaded first, so an unreadable or invalid
    /// certificate/key is reported before any socket is bound.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while loading the certificate or key,
    /// or the error from binding the TCP socket.
    pub async fn bind(
        addr: impl tokio::net::ToSocketAddrs,
        cert_path: &std::path::Path,
        key_path: &std::path::Path,
    ) -> std::io::Result<Self> {
        let tls = build_tls_acceptor(cert_path, key_path)?;
        tokio::net::TcpListener::bind(addr)
            .await
            .map(|tcp| Self { tcp, tls })
    }

    /// Wraps an already-bound TCP listener with a TLS acceptor built from the
    /// PEM files at `cert_path` and `key_path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while loading the certificate or key.
    pub fn from_tcp(
        tcp: tokio::net::TcpListener,
        cert_path: &std::path::Path,
        key_path: &std::path::Path,
    ) -> std::io::Result<Self> {
        build_tls_acceptor(cert_path, key_path).map(|tls| Self { tcp, tls })
    }
}
/// Builds a server-side TLS acceptor from a PEM certificate file and a PEM
/// private-key file.
///
/// The certificate file is read and validated before the key file is opened.
/// Client authentication is disabled.
///
/// # Errors
///
/// * Any error from opening or parsing either file is returned as-is.
/// * `InvalidData` if the certificate file contains no certificates.
/// * `InvalidData` if the key file contains no private key.
/// * `InvalidData` wrapping the rustls error if the certificate chain and key
///   are rejected when building the server configuration.
fn build_tls_acceptor(
    cert_path: &std::path::Path,
    key_path: &std::path::Path,
) -> std::io::Result<tokio_rustls::TlsAcceptor> {
    use std::io::{BufReader, Error, ErrorKind};

    // Certificate chain: every PEM certificate entry in the file, in order.
    let mut cert_reader = BufReader::new(std::fs::File::open(cert_path)?);
    let chain = rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
    if chain.is_empty() {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "no certificates found in cert file",
        ));
    }

    // Private key: the key that `rustls_pemfile::private_key` yields, if any.
    let mut key_reader = BufReader::new(std::fs::File::open(key_path)?);
    let Some(private_key) = rustls_pemfile::private_key(&mut key_reader)? else {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "no private key found in key file",
        ));
    };

    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
}
impl VoxListener for WssListener {
    /// WebSocket link running over a server-side TLS stream on top of TCP.
    type Link = vox_websocket::WsLink<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>;

    /// Accepts the next TCP connection, performs the TLS handshake on it and
    /// then hands the encrypted stream to `WsLink::server`.
    ///
    /// The peer address reported by the TCP accept is discarded.
    ///
    /// # Errors
    ///
    /// Returns the first error from the TCP accept, the TLS handshake, or
    /// `WsLink::server`.
    ///
    /// NOTE(review): the TLS step and `WsLink::server` are awaited inline, so
    /// this call does not return until they finish for the current peer —
    /// confirm how `serve_listener` deals with slow or failing peers.
    async fn accept(&mut self) -> std::io::Result<Self::Link> {
        let (tcp_stream, _peer) = self.tcp.accept().await?;
        let secured = self.tls.accept(tcp_stream).await?;
        vox_websocket::WsLink::server(secured).await
    }
}
/// Splits `s` at the first `?` into the part before it and a map of the
/// `key=value` pairs that follow.
///
/// * Without a `?`, the whole input is returned alongside an empty map.
/// * Pairs are separated by `&`; each is split at its first `=`.
/// * Pairs that contain no `=` are ignored.
/// * When a key repeats, the last occurrence wins.
/// * Values are stored verbatim as paths; no percent-decoding is applied.
fn parse_query_params(s: &str) -> (&str, std::collections::HashMap<String, std::path::PathBuf>) {
    let Some((authority, query)) = s.split_once('?') else {
        return (s, std::collections::HashMap::new());
    };

    let mut params = std::collections::HashMap::new();
    for pair in query.split('&') {
        if let Some((name, value)) = pair.split_once('=') {
            params.insert(name.to_owned(), std::path::PathBuf::from(value));
        }
    }
    (authority, params)
}
/// Serves `acceptor` over a TLS-secured WebSocket listener described by `host`.
///
/// `host` is the bind address followed by a query string that must supply
/// both `cert` and `key` parameters, each a path to a PEM file, e.g.
/// `0.0.0.0:443?cert=/path/to/cert.pem&key=/path/to/key.pem`.
///
/// # Errors
///
/// * `ServeError::Io` with kind `InvalidInput` when `cert` or `key` is
///   missing (`cert` is checked first).
/// * Any error from [`WssListener::bind`] or `serve_listener`, converted
///   into `ServeError`.
pub(super) async fn serve_wss(
    host: &str,
    acceptor: impl ConnectionAcceptor,
) -> Result<(), ServeError> {
    // Shared constructor for the "required parameter missing" error.
    let missing = |message: &'static str| {
        ServeError::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            message,
        ))
    };

    let (bind_addr, query) = parse_query_params(host);
    let cert_path = query
        .get("cert")
        .ok_or_else(|| missing("wss:// requires ?cert=/path/to/cert.pem query parameter"))?;
    let key_path = query
        .get("key")
        .ok_or_else(|| missing("wss:// requires ?key=/path/to/key.pem query parameter"))?;

    let listener = WssListener::bind(bind_addr, cert_path.as_path(), key_path.as_path()).await?;
    Ok(serve_listener(listener, acceptor).await?)
}