use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use quinn::{ClientConfig, Connection, Endpoint, RecvStream, SendStream, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use tokio::sync::Mutex;
use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};
pub struct QuicStreamChannel {
send: Arc<Mutex<Option<SendStream>>>,
recv: Arc<Mutex<Option<RecvStream>>>,
}
impl QuicStreamChannel {
fn from_bi(send: SendStream, recv: RecvStream) -> Self {
Self {
send: Arc::new(Mutex::new(Some(send))),
recv: Arc::new(Mutex::new(Some(recv))),
}
}
pub async fn send(&self, data: Bytes) -> crate::error::Result<()> {
let mut g = self.send.lock().await;
let s = g
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
write_length_prefixed(s, &data).await
}
pub async fn recv(&self) -> crate::error::Result<Bytes> {
let mut g = self.recv.lock().await;
let r = g
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
let v = read_length_prefixed(r).await?;
Ok(Bytes::from(v))
}
pub async fn is_healthy(&self) -> bool {
self.send.lock().await.is_some() && self.recv.lock().await.is_some()
}
pub async fn close(&self) -> crate::error::Result<()> {
let mut s = self.send.lock().await;
if let Some(mut stream) = s.take() {
let _ = stream.finish();
}
self.recv.lock().await.take();
Ok(())
}
}
pub struct QuicTransport {
conn: Arc<Connection>,
default_stream: QuicStreamChannel,
}
impl QuicTransport {
fn from_bi(conn: Connection, send: SendStream, recv: RecvStream) -> Self {
Self {
conn: Arc::new(conn),
default_stream: QuicStreamChannel::from_bi(send, recv),
}
}
pub fn server_config() -> crate::error::Result<(ServerConfig, CertificateDer<'static>)> {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let cert_der = CertificateDer::from(cert.cert);
let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
let server_config = ServerConfig::with_single_cert(vec![cert_der.clone()], key.into())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok((server_config, cert_der))
}
pub fn client_config_trust_server(
cert_der: &CertificateDer<'_>,
) -> crate::error::Result<ClientConfig> {
let mut roots = rustls::RootCertStore::empty();
roots
.add(cert_der.clone())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
ClientConfig::with_root_certificates(Arc::new(roots))
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))
}
pub fn client_config_webpki() -> crate::error::Result<ClientConfig> {
let mut roots = rustls::RootCertStore::empty();
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
ClientConfig::with_root_certificates(Arc::new(roots))
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))
}
pub async fn connect(
bind: SocketAddr,
server: SocketAddr,
server_name: &str,
server_cert_der: &CertificateDer<'_>,
) -> crate::error::Result<Self> {
Self::connect_pinned(bind, server, server_name, server_cert_der, &[]).await
}
pub async fn connect_pinned(
bind: SocketAddr,
server: SocketAddr,
server_name: &str,
server_cert_der: &CertificateDer<'_>,
pins: &[CertPin],
) -> crate::error::Result<Self> {
let mut endpoint = Endpoint::client(bind)
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let cfg = Self::client_config_trust_server(server_cert_der)?;
endpoint.set_default_client_config(cfg);
let conn = endpoint
.connect(server, server_name)
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Self::verify_connection_pins(&conn, pins)?;
let (send, recv) = conn
.open_bi()
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self::from_bi(conn, send, recv))
}
pub async fn accept_bi(conn: Connection) -> crate::error::Result<Self> {
let (send, recv) = conn
.accept_bi()
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self::from_bi(conn, send, recv))
}
pub async fn open_stream(&self) -> crate::error::Result<QuicStreamChannel> {
let (send, recv) =
self.conn.open_bi().await.map_err(|e| {
SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
})?;
Ok(QuicStreamChannel::from_bi(send, recv))
}
pub async fn accept_stream(&self) -> crate::error::Result<QuicStreamChannel> {
let (send, recv) =
self.conn.accept_bi().await.map_err(|e| {
SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
})?;
Ok(QuicStreamChannel::from_bi(send, recv))
}
fn verify_connection_pins(conn: &Connection, pins: &[CertPin]) -> crate::error::Result<()> {
if pins.is_empty() {
return Ok(());
}
let identity = conn.peer_identity().ok_or_else(|| {
SrxError::Transport(TransportError::ConnectionFailed(
"missing QUIC peer identity for pin verification".into(),
))
})?;
let certs = identity
.downcast::<Vec<CertificateDer<'static>>>()
.map_err(|_| {
SrxError::Transport(TransportError::ConnectionFailed(
"unexpected QUIC peer identity type for pin verification".into(),
))
})?;
verify_peer_cert_pins(Some(certs.as_slice()), pins)
}
}
#[async_trait]
impl Transport for QuicTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::Quic
    }

    /// Sends one frame on the default bidirectional stream.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        self.default_stream.send(data).await
    }

    /// Receives one frame from the default bidirectional stream.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        self.default_stream.recv().await
    }

    /// Healthy while the QUIC connection is still open and the default stream
    /// has not been closed locally.
    async fn is_healthy(&self) -> bool {
        // The stream slots only reflect local `close` calls. `close_reason`
        // turns `Some` once the connection itself has terminated (peer close,
        // idle timeout, loss), which the slots alone cannot observe.
        self.conn.close_reason().is_none() && self.default_stream.is_healthy().await
    }

    /// Closes only the default stream. The connection is not closed here, and
    /// streams obtained via `open_stream`/`accept_stream` are left untouched.
    async fn close(&self) -> crate::error::Result<()> {
        self.default_stream.close().await
    }
}
#[cfg(test)]
mod tests {
// These tests run real QUIC connections over 127.0.0.1 using the self-signed
// config from `QuicTransport::server_config`.
use super::*;
// Building the webpki-roots client config must succeed.
#[test]
fn client_config_webpki_builds() {
assert!(QuicTransport::client_config_webpki().is_ok());
}
// Ping/pong over the default stream: the client sends first, the server
// answers on the same stream.
#[tokio::test]
async fn quic_length_prefixed_roundtrip() {
let (server_config, cert_der) = QuicTransport::server_config().unwrap();
// Port 0 lets the OS pick a free port; the actual address is read back below.
let listen = SocketAddr::from(([127, 0, 0, 1], 0));
let endpoint = Endpoint::server(server_config, listen).unwrap();
let server_addr = endpoint.local_addr().unwrap();
let server = tokio::spawn(async move {
let incoming = endpoint.accept().await.expect("accept");
let conn = incoming.await.expect("conn");
let t = QuicTransport::accept_bi(conn).await.expect("bi");
let got = t.recv().await.expect("recv");
assert_eq!(got.as_ref(), b"quic-ping");
t.send(Bytes::from_static(b"quic-pong"))
.await
.expect("send");
// Result ignored: presumably this keeps the server transport alive until the
// client closes its side, so the pong is not lost to an early drop — confirm.
let _ = t.recv().await;
});
let client = QuicTransport::connect(
SocketAddr::from(([127, 0, 0, 1], 0)),
server_addr,
"localhost",
&cert_der,
)
.await
.expect("client connect");
client
.send(Bytes::from_static(b"quic-ping"))
.await
.expect("send");
let reply = client.recv().await.expect("recv");
assert_eq!(reply.as_ref(), b"quic-pong");
client.close().await.ok();
server.await.unwrap();
}
// A pin of 32 bytes of 0x33 will not match the freshly generated certificate,
// so `connect_pinned` must return an error.
#[tokio::test]
async fn quic_pinned_rejects_wrong_pin() {
let (server_config, cert_der) = QuicTransport::server_config().unwrap();
let listen = SocketAddr::from(([127, 0, 0, 1], 0));
let endpoint = Endpoint::server(server_config, listen).unwrap();
let server_addr = endpoint.local_addr().unwrap();
let wrong_pin = [0x33u8; 32];
// The server only drives the handshake; the task is aborted at the end.
let server = tokio::spawn(async move {
let incoming = endpoint.accept().await.expect("accept");
let _ = incoming.await;
});
let res = QuicTransport::connect_pinned(
SocketAddr::from(([127, 0, 0, 1], 0)),
server_addr,
"localhost",
&cert_der,
&[wrong_pin],
)
.await;
assert!(res.is_err(), "expected pin mismatch");
server.abort();
}
// Default stream first, then two extra streams opened and used one after the
// other; the server accepts them in the same order. Each stream is written to
// before the next is opened — presumably required because a peer only sees a
// new QUIC stream once data arrives on it.
#[tokio::test]
async fn quic_multi_stream_roundtrip() {
let (server_config, cert_der) = QuicTransport::server_config().unwrap();
let endpoint =
Endpoint::server(server_config, SocketAddr::from(([127, 0, 0, 1], 0))).unwrap();
let server_addr = endpoint.local_addr().unwrap();
let server = tokio::spawn(async move {
let incoming = endpoint.accept().await.expect("accept");
let conn = incoming.await.expect("conn");
let server_t = QuicTransport::accept_bi(conn)
.await
.expect("default stream");
let got = server_t.recv().await.expect("default recv");
assert_eq!(got.as_ref(), b"default-ping");
server_t
.send(Bytes::from_static(b"default-pong"))
.await
.expect("default send");
let s1 = server_t.accept_stream().await.expect("accept stream1");
let got1 = s1.recv().await.expect("recv stream1");
assert_eq!(got1.as_ref(), b"stream-1");
s1.send(Bytes::from_static(b"ack-1"))
.await
.expect("send stream1");
let s2 = server_t.accept_stream().await.expect("accept stream2");
let got2 = s2.recv().await.expect("recv stream2");
assert_eq!(got2.as_ref(), b"stream-2");
s2.send(Bytes::from_static(b"ack-2"))
.await
.expect("send stream2");
// Final "done" on the default stream: presumably keeps the server task (and
// its connection) alive until the client has finished with the extra streams.
let done = server_t.recv().await.expect("done recv");
assert_eq!(done.as_ref(), b"done");
});
let client_t = QuicTransport::connect(
SocketAddr::from(([127, 0, 0, 1], 0)),
server_addr,
"localhost",
&cert_der,
)
.await
.expect("client connect");
client_t
.send(Bytes::from_static(b"default-ping"))
.await
.expect("default send");
let default_reply = client_t.recv().await.expect("default recv");
assert_eq!(default_reply.as_ref(), b"default-pong");
let c1 = client_t.open_stream().await.expect("open stream1");
c1.send(Bytes::from_static(b"stream-1"))
.await
.expect("send stream1");
let ack1 = c1.recv().await.expect("recv stream1");
assert_eq!(ack1.as_ref(), b"ack-1");
let c2 = client_t.open_stream().await.expect("open stream2");
c2.send(Bytes::from_static(b"stream-2"))
.await
.expect("send stream2");
let ack2 = c2.recv().await.expect("recv stream2");
assert_eq!(ack2.as_ref(), b"ack-2");
client_t
.send(Bytes::from_static(b"done"))
.await
.expect("done send");
server.await.unwrap();
}
}