use {
alloc::sync::Arc,
core::net::{Ipv6Addr, SocketAddr},
derive_more::{Display, Error},
rustls::pki_types::{CertificateDer, PrivateKeyDer},
tokio_tungstenite::tungstenite::{
handshake::server::{ErrorResponse, Request, Response},
protocol::WebSocketConfig,
},
};
#[derive(Clone)]
pub struct HandshakeHandler(
Arc<dyn Fn(&Request, Response) -> Result<Response, ErrorResponse> + Send + Sync>,
);
impl HandshakeHandler {
pub fn new(
pred: impl Fn(&Request, Response) -> Result<Response, ErrorResponse> + Send + Sync + 'static,
) -> Self {
Self(Arc::new(pred))
}
pub fn from_arc(
pred: Arc<dyn Fn(&Request, Response) -> Result<Response, ErrorResponse> + Send + Sync>,
) -> Self {
Self(pred)
}
#[expect(
clippy::result_large_err,
reason = "`tokio_tungstenite` requires that we return the error unboxed,
so we cannot box it here"
)]
pub(crate) fn handle(&self, req: &Request, resp: Response) -> Result<Response, ErrorResponse> {
self.0(req, resp)
}
}
/// Complete configuration for a WebSocket server.
///
/// Obtain one through [`ServerConfig::builder`]. The builder first asks for a
/// bind address and then for a TLS choice.
#[derive(Clone)]
#[must_use]
pub struct ServerConfig {
    /// Socket address the server binds to.
    pub(crate) bind_address: SocketAddr,
    /// rustls server configuration, or `None` when the server was built
    /// without encryption.
    pub(crate) tls: Option<Arc<rustls::ServerConfig>>,
    /// WebSocket protocol settings. The builder initializes this to
    /// `WebSocketConfig::default()`.
    pub(crate) socket: WebSocketConfig,
    /// Optional callback run during the opening handshake.
    pub(crate) handshake_handler: Option<HandshakeHandler>,
}

impl ServerConfig {
    /// Starts a typestate builder. Its first step is choosing a bind address.
    pub const fn builder() -> ServerConfigBuilder<WantsBindAddress> {
        let initial = WantsBindAddress(());
        ServerConfigBuilder(initial)
    }

    /// Returns the socket address this configuration binds to.
    #[must_use]
    pub const fn bind_address(&self) -> SocketAddr {
        self.bind_address
    }
}
/// Typestate builder for [`ServerConfig`].
///
/// The type parameter `S` holds the data gathered so far and determines
/// which builder methods are callable next.
#[must_use]
pub struct ServerConfigBuilder<S>(S);
/// Builder state: no bind address has been chosen yet.
pub struct WantsBindAddress(());
/// Builder state: the bind address is known; TLS (or no encryption) must be
/// chosen next.
pub struct WantsTlsConfig {
    // Address carried forward into the final `ServerConfig`.
    bind_address: SocketAddr,
}
impl ServerConfigBuilder<WantsBindAddress> {
    /// Binds to the IPv6 unspecified address (`::`) on `listening_port`.
    ///
    /// Whether IPv4 clients can also connect through this socket depends on
    /// the operating system's dual-stack settings.
    pub fn with_bind_default(self, listening_port: u16) -> ServerConfigBuilder<WantsTlsConfig> {
        let unspecified = SocketAddr::from((Ipv6Addr::UNSPECIFIED, listening_port));
        self.with_bind_address(unspecified)
    }

    /// Binds to exactly `bind_address` and advances to the TLS step.
    pub const fn with_bind_address(
        self,
        bind_address: SocketAddr,
    ) -> ServerConfigBuilder<WantsTlsConfig> {
        let next = WantsTlsConfig { bind_address };
        ServerConfigBuilder(next)
    }
}
impl ServerConfigBuilder<WantsTlsConfig> {
pub fn with_identity(self, identity: Identity) -> ServerConfig {
let crypto = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(identity.cert_chain, identity.key_der)
.expect("identity is not valid");
self.with_tls_config(crypto)
}
pub fn with_tls_config(self, tls: impl Into<Arc<rustls::ServerConfig>>) -> ServerConfig {
let tls = tls.into();
self.with_tls(Some(tls))
}
pub fn with_no_encryption(self) -> ServerConfig {
self.with_tls(None)
}
fn with_tls(self, tls: Option<Arc<rustls::ServerConfig>>) -> ServerConfig {
ServerConfig {
bind_address: self.0.bind_address,
tls,
socket: WebSocketConfig::default(),
handshake_handler: None,
}
}
}
impl ServerConfig {
    /// Returns this configuration with its WebSocket protocol settings
    /// replaced by `socket`.
    pub fn with_socket_config(mut self, socket: WebSocketConfig) -> Self {
        self.socket = socket;
        self
    }

    /// Returns this configuration with `handshake_handler` installed. Any
    /// previously set handler is replaced.
    pub fn with_handshake_handler(mut self, handshake_handler: HandshakeHandler) -> Self {
        self.handshake_handler = Some(handshake_handler);
        self
    }
}
/// A TLS identity: a certificate chain plus the private key for its
/// end-entity certificate, both as owned DER data.
#[derive(Debug)]
pub struct Identity {
    /// DER-encoded certificate chain. It is passed to rustls unchanged, and
    /// rustls expects the end-entity certificate to come first.
    pub cert_chain: Vec<CertificateDer<'static>>,
    /// DER-encoded private key for the end-entity certificate.
    pub key_der: PrivateKeyDer<'static>,
}

impl Identity {
    /// Creates an identity from any iterable of certificates and a private
    /// key. The certificates are collected in iteration order.
    #[must_use]
    pub fn new(
        cert_chain: impl IntoIterator<Item = CertificateDer<'static>>,
        key_der: PrivateKeyDer<'static>,
    ) -> Self {
        Self {
            cert_chain: cert_chain.into_iter().collect(),
            key_der,
        }
    }

    /// Generates a fresh self-signed certificate and an ECDSA P-256/SHA-256
    /// key pair.
    ///
    /// The certificate's subject common name is "aeronet self-signed". Each
    /// entry of `subject_alt_names` is passed to rcgen as a subject
    /// alternative name. The result contains a single certificate and the
    /// PKCS#8-encoded private key.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSan`] if rcgen rejects `subject_alt_names` when
    /// building the certificate parameters.
    #[cfg(feature = "self-signed")]
    #[expect(clippy::missing_panics_doc, reason = "shouldn't panic")]
    pub fn self_signed(
        subject_alt_names: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> Result<Self, InvalidSan> {
        use {
            rcgen::{
                CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256,
            },
            rustls::pki_types::PrivatePkcs8KeyDer,
        };

        let subject_alt_names = subject_alt_names
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect::<Vec<_>>();

        let mut dname = DistinguishedName::new();
        dname.push(DnType::CommonName, "aeronet self-signed");

        let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
            .expect("algorithm for key pair should be supported");

        let mut params = CertificateParams::new(subject_alt_names).map_err(|_| InvalidSan)?;
        // The distinguished name must be attached explicitly. Otherwise rcgen
        // keeps its own default subject and the name built above is lost.
        params.distinguished_name = dname;

        let cert = params
            .self_signed(&key_pair)
            .expect("inner params should be valid");

        Ok(Self::new(
            [cert.der().clone()],
            PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der())),
        ))
    }
}
/// Error returned by [`Identity::self_signed`] when rcgen rejects the given
/// subject alternative names while building certificate parameters.
#[cfg(feature = "self-signed")]
#[derive(Debug, Display, Error)]
#[display("invalid SANs for self-signed certificate")]
pub struct InvalidSan;