#![warn(missing_docs)]
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
sync::Arc,
};
pub use rcgen;
pub use rustls::{
CertificateError,
pki_types::{SubjectPublicKeyInfoDer, UnixTime},
};
pub use url::Url;
pub use web_transport_quinn::{self as web_transport, quinn, quinn::rustls};
use bytes::{Buf, BufMut, Bytes};
use quinn::{
ApplicationClose, ConnectionError,
congestion::ControllerFactory,
crypto::rustls::{QuicClientConfig, QuicServerConfig},
};
use rcgen::{CertificateParams, DistinguishedName as Dn, DnType, KeyPair};
use rustls::{
DigitallySignedStruct, DistinguishedName, KeyLogFile, SignatureScheme,
client::{
ResolvesClientCert,
danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
},
crypto::{CryptoProvider, WebPkiSupportedAlgorithms, verify_tls13_signature},
pki_types::{CertificateDer, PrivateKeyDer, ServerName, alg_id},
server::{
ClientHello, ParsedCertificate, ResolvesServerCert,
danger::{ClientCertVerified, ClientCertVerifier},
},
sign::CertifiedKey,
};
use time::{Duration, OffsetDateTime};
use tokio::sync::Mutex;
use tracing::trace;
use web_transport::{ALPN, SessionError};
/// Installs mushi's process-wide default rustls [`CryptoProvider`].
///
/// The provider is rustls' `ring` provider with every RSA signature-verification
/// algorithm (and every signature-scheme mapping whose name mentions RSA) removed.
/// The certificate verifiers and the certificate resolver used by [`Endpoint`] look up
/// the process-wide default provider, so call this before any handshake happens.
///
/// # Panics
///
/// Panics if a default provider has already been installed for this process.
pub fn install_crypto_provider() {
    let mut provider = rustls::crypto::ring::default_provider();
    let supported = provider.signature_verification_algorithms;

    let without_rsa: Vec<_> = supported
        .all
        .iter()
        .filter(|alg| alg.public_key_alg_id() != alg_id::RSA_ENCRYPTION)
        .cloned()
        .collect();
    let schemes_without_rsa: Vec<_> = supported
        .mapping
        .iter()
        .filter(|(scheme, _)| matches!(scheme.as_str(), Some(name) if !name.contains("RSA")))
        .cloned()
        .collect();

    // `WebPkiSupportedAlgorithms` only holds `'static` slices, so the filtered lists are
    // leaked (once per call) to obtain them.
    provider.signature_verification_algorithms = WebPkiSupportedAlgorithms {
        all: Box::leak(without_rsa.into_boxed_slice()),
        mapping: Box::leak(schemes_without_rsa.into_boxed_slice()),
    };
    trace!(?provider, "mushi crypto provider");
    provider.install_default().unwrap();
}
/// The signing key that identifies an endpoint, together with the signature scheme it signs with.
///
/// No certificate is stored: a short-lived self-signed certificate is minted from this key
/// each time rustls asks the key for one.
#[derive(Debug, Clone)]
pub struct EndpointKey {
    // TLS signature scheme advertised for this key, paired with the rcgen algorithm producing it.
    scheme: SigScheme,
    // The key pair itself; shared so clones do not duplicate key material.
    key: Arc<KeyPair>,
    /// How long a freshly minted certificate stays valid after it is created
    /// (its `not_before` is additionally backdated by one minute).
    pub validity: Duration,
}
impl std::ops::Deref for EndpointKey {
    type Target = KeyPair;

    /// Exposes the wrapped [`KeyPair`], so its methods can be called on the endpoint key directly.
    fn deref(&self) -> &KeyPair {
        &*self.key
    }
}
/// A TLS signature scheme paired with the `rcgen` algorithm used to produce signatures for it.
pub type SigScheme = (SignatureScheme, &'static rcgen::SignatureAlgorithm);
/// Ed25519 signatures; this is the scheme used by [`EndpointKey::generate`].
pub const SIGSCHEME_ED25519: SigScheme = (SignatureScheme::ED25519, &rcgen::PKCS_ED25519);
/// ECDSA over NIST P-256 with SHA-256.
pub const SIGSCHEME_ECDSA256: SigScheme = (
    SignatureScheme::ECDSA_NISTP256_SHA256,
    &rcgen::PKCS_ECDSA_P256_SHA256,
);
/// ECDSA over NIST P-384 with SHA-384.
pub const SIGSCHEME_ECDSA384: SigScheme = (
    SignatureScheme::ECDSA_NISTP384_SHA384,
    &rcgen::PKCS_ECDSA_P384_SHA384,
);
// Top-level label (an "xn--" IDNA A-label) appended to every generated certificate name.
const MUSHI_TLD: &str = "xn--zqsr9q";
impl EndpointKey {
pub fn generate() -> Result<Self, rcgen::Error> {
Self::generate_for(SIGSCHEME_ED25519)
}
pub fn generate_for(scheme: SigScheme) -> Result<Self, rcgen::Error> {
Ok(Self {
scheme,
key: Arc::new(KeyPair::generate_for(scheme.1)?),
validity: Duration::MINUTE * 2,
})
}
pub fn load(key: KeyPair, scheme: SigScheme) -> Self {
if !key.compatible_algs().any(|alg| alg == scheme.1) {
panic!("KeyPair is not compatible with {scheme:?}");
}
Self {
scheme,
key: Arc::new(key),
validity: Duration::MINUTE * 2,
}
}
fn supports_sigschemes(&self, requested: &[SignatureScheme]) -> bool {
requested.contains(&self.scheme.0)
}
fn get_certificate(&self) -> Option<Arc<CertifiedKey>> {
let cert = self.make_certificate().ok()?;
let provider = CryptoProvider::get_default().expect("a default CryptoProvider must be set");
Some(Arc::new(
CertifiedKey::from_der(
vec![cert.der().to_owned()],
PrivateKeyDer::Pkcs8(self.key.serialize_der().into()),
&provider,
)
.ok()?,
))
}
pub fn make_certificate(&self) -> Result<rcgen::Certificate, rcgen::Error> {
let print = ring::digest::digest(&ring::digest::SHA256, &self.key.public_key_der());
let puny = idna::punycode::encode_str(&base65536::encode(&print, None))
.unwrap_or(MUSHI_TLD.to_string());
let san = format!("xn--{puny}.{MUSHI_TLD}");
let mut cert = CertificateParams::new(vec![san.clone()])?;
cert.distinguished_name = Dn::new();
cert.distinguished_name.push(DnType::CommonName, san);
let start = OffsetDateTime::now_utc() - Duration::MINUTE;
cert.not_before = start;
cert.not_after = start + Duration::MINUTE + self.validity;
cert.self_signed(&self.key)
}
}
impl ResolvesClientCert for EndpointKey {
    /// Offers a freshly minted client certificate, but only when the server's acceptable
    /// signature schemes include this key's scheme; otherwise no certificate is offered.
    fn resolve(
        &self,
        _hints: &[&[u8]],
        schemes: &[SignatureScheme],
    ) -> Option<Arc<CertifiedKey>> {
        if !self.supports_sigschemes(schemes) {
            return None;
        }
        self.get_certificate()
    }

    /// Always `true`: a certificate can be minted from the key on demand.
    fn has_certs(&self) -> bool {
        true
    }
}
impl ResolvesServerCert for EndpointKey {
fn resolve(&self, _hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
self.get_certificate()
}
}
/// Policy deciding which peers may connect, based on the public key in the certificate
/// they present.
///
/// An [`Endpoint`] consults the same policy in both directions: for servers it dials and
/// for clients that dial it.
pub trait AllowConnection: std::fmt::Debug + Send + Sync + 'static {
    /// Decides whether a peer presenting `key` (its SubjectPublicKeyInfo, DER-encoded) is allowed.
    ///
    /// `now` is the verification time supplied by rustls. Returning `Err` rejects the
    /// certificate; the [`CertificateError`] is converted into a [`rustls::Error`].
    fn allow_public_key(
        &self,
        key: SubjectPublicKeyInfoDer<'_>,
        now: UnixTime,
    ) -> Result<(), CertificateError>;
    /// Whether incoming connections must present a client certificate (reported to rustls as
    /// `client_auth_mandatory`). Defaults to `true`.
    fn require_client_auth(&self) -> bool {
        true
    }
}
/// An [`AllowConnection`] policy that accepts every peer key.
///
/// It keeps the default [`AllowConnection::require_client_auth`], so clients must still
/// present a certificate.
#[derive(Debug, Clone, Copy)]
pub struct AllowAllConnections;
impl AllowConnection for AllowAllConnections {
    /// Accepts any key at any time.
    fn allow_public_key(
        &self,
        _key: SubjectPublicKeyInfoDer<'_>,
        _now: UnixTime,
    ) -> Result<(), CertificateError> {
        Ok(())
    }
}
#[derive(Debug, Clone)]
struct ConnectionAllower(Arc<dyn AllowConnection>);
impl ServerCertVerifier for ConnectionAllower {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
let cert = ParsedCertificate::try_from(end_entity)?;
self.0
.allow_public_key(cert.subject_public_key_info(), now)
.map_err(rustls::Error::from)
.and(Ok(ServerCertVerified::assertion()))
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
unimplemented!("mushi works exclusively over TLS 1.3")
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
let algos = CryptoProvider::get_default()
.expect("a default CryptoProvider must be set")
.signature_verification_algorithms;
verify_tls13_signature(message, cert, dss, &algos)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
CryptoProvider::get_default()
.expect("a default CryptoProvider must be set")
.signature_verification_algorithms
.supported_schemes()
}
}
impl ClientCertVerifier for ConnectionAllower {
fn root_hint_subjects(&self) -> &[DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
now: UnixTime,
) -> Result<ClientCertVerified, rustls::Error> {
let cert = ParsedCertificate::try_from(end_entity)?;
self.0
.allow_public_key(cert.subject_public_key_info(), now)
.map_err(rustls::Error::from)
.and(Ok(ClientCertVerified::assertion()))
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
unimplemented!("mushi works exclusively over TLS 1.3")
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
let algos = CryptoProvider::get_default()
.expect("a default CryptoProvider must be set")
.signature_verification_algorithms;
verify_tls13_signature(message, cert, dss, &algos)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
CryptoProvider::get_default()
.expect("a default CryptoProvider must be set")
.signature_verification_algorithms
.supported_schemes()
}
fn client_auth_mandatory(&self) -> bool {
self.0.require_client_auth()
}
}
/// A QUIC endpoint that can both accept and initiate WebTransport sessions.
///
/// Clones share the same WebTransport server and key (both are held in `Arc`s).
#[derive(Clone)]
pub struct Endpoint {
    // Configuration used for outgoing connections.
    client_config: quinn::ClientConfig,
    // WebTransport server over `endpoint`; behind an async mutex so `accept` works through `&self`.
    server: Arc<Mutex<web_transport::Server>>,
    // The identity this endpoint presents to peers.
    key: Arc<EndpointKey>,
    // The underlying quinn endpoint.
    endpoint: quinn::Endpoint,
}
impl std::fmt::Debug for Endpoint {
    /// Formats the endpoint.
    ///
    /// The WebTransport server is shown as an opaque `web_transport_quinn::Server { .. }`
    /// placeholder because it is not `Debug` itself. The previous implementation wrote that
    /// placeholder straight to the formatter before the `Endpoint` struct was opened, and
    /// then printed `()` as the field value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Endpoint")
            .field("client_config", &self.client_config)
            .field(
                "server",
                &format_args!("web_transport_quinn::Server {{ .. }}"),
            )
            .field("key", &self.key)
            .field("endpoint", &self.endpoint)
            .finish()
    }
}
impl Endpoint {
pub fn new(
bind_to: impl ToSocketAddrs,
key: EndpointKey,
allower: Arc<dyn AllowConnection>,
cc: Option<Arc<(dyn ControllerFactory + Send + Sync + 'static)>>,
) -> Result<Self, Error> {
let provider = Arc::new(rustls::crypto::ring::default_provider());
let key = Arc::new(key);
let allower = Arc::new(ConnectionAllower(allower));
let mut server_config = rustls::ServerConfig::builder_with_provider(provider.clone())
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap()
.with_client_cert_verifier(allower.clone())
.with_cert_resolver(key.clone());
server_config.alpn_protocols = vec![ALPN.to_vec()];
let mut client_config = rustls::ClientConfig::builder_with_provider(provider.clone())
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap()
.dangerous()
.with_custom_certificate_verifier(allower)
.with_client_cert_resolver(key.clone());
client_config.alpn_protocols = vec![ALPN.to_vec()];
if cfg!(debug_assertions) {
server_config.key_log = Arc::new(KeyLogFile::new());
client_config.key_log = server_config.key_log.clone();
}
let mut transport = quinn::TransportConfig::default();
if let Some(cc) = cc {
transport.congestion_controller_factory(cc.clone());
}
let transport = Arc::new(transport);
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
QuicServerConfig::try_from(server_config).unwrap(),
));
server_config.transport_config(transport.clone());
let mut client_config =
quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_config).unwrap()));
client_config.transport_config(transport);
let mut last_err = None;
let mut endpoint = None;
for addr in bind_to.to_socket_addrs()? {
match quinn::Endpoint::server(server_config.clone(), addr) {
Ok(s) => { endpoint = Some(s); break; },
Err(err) => { last_err = Some(err); }
}
}
let mut endpoint = match (endpoint, last_err) {
(Some(e), _) => e,
(None, Some(err)) => return Err(err.into()),
(None, None) => return Err(Error::NoAddrs),
};
endpoint.set_default_client_config(client_config.clone());
Ok(Self {
key,
client_config,
server: Arc::new(Mutex::new(web_transport::Server::new(endpoint.clone()))),
endpoint,
})
}
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.endpoint.local_addr().map_err(Error::from)
}
pub fn open_connections(&self) -> usize {
self.endpoint.open_connections()
}
pub fn stats(&self) -> quinn::EndpointStats {
self.endpoint.stats()
}
pub async fn connect(&self, addrs: impl ToSocketAddrs) -> Result<Session, Error> {
let mut last_err = None;
for mut addr in addrs.to_socket_addrs()? {
if addr.ip().is_unspecified() {
addr.set_ip(match addr.ip() {
IpAddr::V4(_) => Ipv4Addr::LOCALHOST.into(),
IpAddr::V6(_) => Ipv6Addr::LOCALHOST.into(),
});
}
let url = Url::parse(&format!("https://{addr}")).unwrap();
let conn =
self.endpoint
.connect_with(self.client_config.clone(), addr, "mushi.mushi")?;
let conn = conn.await?;
match web_transport::Session::connect(conn, &url).await {
Ok(s) => return Ok(Session::new(s)),
Err(e) => last_err = Some(Error::from(e)),
}
}
Err(last_err.unwrap_or(Error::NoAddrs))
}
pub async fn accept(&self) -> Option<Result<Session, Error>> {
match self.server.lock().await.accept().await {
Some(session) => Some(
session
.ok()
.await
.map(Session::new)
.map_err(|e| Error::Write(e.into())),
),
None => None,
}
}
pub fn key(&self) -> Arc<KeyPair> {
self.key.key.clone()
}
pub async fn wait_idle(&self) {
self.endpoint.wait_idle().await
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Session {
inner: web_transport::Session,
peer_key: Option<SubjectPublicKeyInfoDer<'static>>,
}
impl Session {
fn new(session: web_transport::Session) -> Self {
let peer_key = session.peer_identity().and_then(|id| {
let certs: Vec<CertificateDer> = *id.downcast().ok()?;
for cert in certs {
let Ok(cert) = ParsedCertificate::try_from(&cert) else {
continue;
};
return Some(cert.subject_public_key_info());
}
None
});
Self {
inner: session,
peer_key,
}
}
pub fn peer_key(&self) -> Option<&SubjectPublicKeyInfoDer<'_>> {
self.peer_key.as_ref()
}
pub unsafe fn as_quic(&self) -> &quinn::Connection {
&self.inner
}
pub async fn accept_uni(&self) -> Result<RecvStream, Error> {
let inner = self.inner.clone();
let stream = inner.accept_uni().await?;
Ok(RecvStream::new(stream))
}
pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), Error> {
let (s, r) = self.inner.accept_bi().await?;
Ok((SendStream::new(s), RecvStream::new(r)))
}
pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), Error> {
let inner = self.inner.clone();
Ok(inner
.open_bi()
.await
.map(|(s, r)| (SendStream::new(s), RecvStream::new(r)))?)
}
pub async fn open_uni(&self) -> Result<SendStream, Error> {
let inner = self.inner.clone();
Ok(inner.open_uni().await.map(SendStream::new)?)
}
pub fn send_datagram(&self, payload: Bytes) -> Result<(), Error> {
let inner = self.inner.clone();
Ok(inner.send_datagram(payload)?)
}
pub async fn max_datagram_size(&self) -> usize {
self.inner.max_datagram_size()
}
pub async fn recv_datagram(&self) -> Result<Bytes, Error> {
let inner = self.inner.clone();
Ok(inner.read_datagram().await?)
}
pub fn close(&self, code: u32, reason: &str) {
let inner = self.inner.clone();
inner.close(code, reason.as_bytes())
}
pub async fn closed(&self) -> Result<Option<ApplicationClose>, Error> {
match self.inner.closed().await {
SessionError::ConnectionError(ConnectionError::LocallyClosed) => Ok(None),
SessionError::ConnectionError(ConnectionError::ApplicationClosed(ac)) => Ok(Some(ac)),
e => Err(Error::Session(e)),
}
}
}
/// The sending half of a WebTransport stream.
#[derive(Debug)]
pub struct SendStream {
    // Wrapped WebTransport send stream.
    inner: web_transport::SendStream,
}

impl SendStream {
    /// Wraps a raw WebTransport send stream.
    fn new(inner: web_transport::SendStream) -> Self {
        SendStream { inner }
    }

    /// Writes the whole of `buf` to the stream.
    pub async fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
        self.inner.write_all(buf).await?;
        Ok(())
    }

    /// Writes everything remaining in `buf`, advancing it by each accepted amount.
    ///
    /// On error, `buf` has already been advanced past whatever was written before the failure.
    pub async fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Result<(), Error> {
        loop {
            if !buf.has_remaining() {
                break Ok(());
            }
            let written = self.inner.write(buf.chunk()).await?;
            buf.advance(written);
        }
    }

    /// Sets the stream's send priority; any error is ignored.
    pub fn set_priority(&mut self, order: i32) {
        let _ = self.inner.set_priority(order);
    }

    /// Abruptly resets the stream with error `code`; any error is ignored.
    pub fn reset(&mut self, code: u32) {
        let _ = self.inner.reset(code);
    }
}
/// The receiving half of a WebTransport stream.
#[derive(Debug)]
pub struct RecvStream {
    // Wrapped WebTransport receive stream.
    inner: web_transport::RecvStream,
}

impl RecvStream {
    /// Wraps a raw WebTransport receive stream.
    fn new(inner: web_transport::RecvStream) -> Self {
        Self { inner }
    }

    /// Reads the next in-order chunk of at most `max` bytes.
    ///
    /// Returns `Ok(None)` when the stream has no more data.
    pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
        Ok(self
            .inner
            .read_chunk(max, true)
            .await?
            .map(|chunk| chunk.bytes))
    }

    /// Reads in-order data into `buf`, at most as much as fits in its next contiguous
    /// writable chunk.
    ///
    /// Returns the number of bytes appended, or `Ok(None)` when the stream has no more data.
    /// If `buf` has no spare room at all, returns `Ok(Some(0))` without touching the stream.
    pub async fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Result<Option<usize>, Error> {
        let capacity = buf.chunk_mut().len();
        if capacity == 0 {
            return Ok(Some(0));
        }
        // `chunk_mut()` may be uninitialized memory, so it must never be viewed as `&mut [u8]`
        // (the previous pointer cast did exactly that, which is undefined behaviour). Instead,
        // take an owned chunk from the stream and copy it in through the safe `BufMut` API.
        // `put_slice` cannot panic: the chunk is at most `capacity` bytes.
        let Some(chunk) = self.inner.read_chunk(capacity, true).await? else {
            return Ok(None);
        };
        let size = chunk.bytes.len();
        buf.put_slice(&chunk.bytes);
        Ok(Some(size))
    }

    /// Asks the peer to stop sending with error `code`; any error is ignored.
    pub fn stop(&mut self, code: u32) {
        self.inner.stop(code).ok();
    }
}
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum Error {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("session error: {0}")]
Session(#[from] web_transport::SessionError),
#[error("client error: {0}")]
Client(#[from] web_transport::ClientError),
#[error("connect error: {0}")]
Connect(#[from] quinn::ConnectError),
#[error("connect error: {0}")]
Connection(#[from] quinn::ConnectionError),
#[error("write error: {0}")]
Write(web_transport::WriteError),
#[error("read error: {0}")]
Read(web_transport::ReadError),
#[error("no addresses found")]
NoAddrs,
}
impl From<web_transport::WriteError> for Error {
    /// Lifts session-level failures into [`Error::Session`]; every other write failure
    /// becomes [`Error::Write`].
    fn from(err: web_transport::WriteError) -> Self {
        if let web_transport::WriteError::SessionError(session) = err {
            Self::Session(session)
        } else {
            Self::Write(err)
        }
    }
}
impl From<web_transport::ReadError> for Error {
    /// Lifts session-level failures into [`Error::Session`]; every other read failure
    /// becomes [`Error::Read`].
    fn from(err: web_transport::ReadError) -> Self {
        if let web_transport::ReadError::SessionError(session) = err {
            Self::Session(session)
        } else {
            Self::Read(err)
        }
    }
}