use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use rustls::pki_types::ServerName;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
sync::Mutex,
};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{WebSocketStream, accept_async, client_async};
use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
enum WsStream {
Plain(TcpStream),
ClientTls(tokio_rustls::client::TlsStream<TcpStream>),
ServerTls(tokio_rustls::server::TlsStream<TcpStream>),
}
/// Forwards reads to whichever stream the active variant holds.
impl AsyncRead for WsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `WsStream` is `Unpin`, so the outer pin can be unwrapped and the
        // inner stream re-pinned for the delegated call.
        let inner = Pin::into_inner(self);
        match inner {
            Self::Plain(io) => Pin::new(io).poll_read(cx, buf),
            Self::ClientTls(io) => Pin::new(io).poll_read(cx, buf),
            Self::ServerTls(io) => Pin::new(io).poll_read(cx, buf),
        }
    }
}
/// Forwards writes, flushes and shutdown to the stream held by the active variant.
impl AsyncWrite for WsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let inner = Pin::into_inner(self);
        match inner {
            Self::Plain(io) => Pin::new(io).poll_write(cx, buf),
            Self::ClientTls(io) => Pin::new(io).poll_write(cx, buf),
            Self::ServerTls(io) => Pin::new(io).poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = Pin::into_inner(self);
        match inner {
            Self::Plain(io) => Pin::new(io).poll_flush(cx),
            Self::ClientTls(io) => Pin::new(io).poll_flush(cx),
            Self::ServerTls(io) => Pin::new(io).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = Pin::into_inner(self);
        match inner {
            Self::Plain(io) => Pin::new(io).poll_shutdown(cx),
            Self::ClientTls(io) => Pin::new(io).poll_shutdown(cx),
            Self::ServerTls(io) => Pin::new(io).poll_shutdown(cx),
        }
    }
}
// `WsStream` is `Unpin` automatically because every variant wraps an `Unpin`
// stream; the poll methods rely on that to re-pin the inner stream. Rather than
// overriding the auto trait by hand (which would keep claiming `Unpin` even if a
// non-`Unpin` variant were added later), let the compiler check it here.
const _: fn() = || {
    fn assert_unpin<T: Unpin>() {}
    assert_unpin::<WsStream>();
};
pub struct WebSocketTransport {
ws: Arc<Mutex<Option<WebSocketStream<WsStream>>>>,
}
impl WebSocketTransport {
pub async fn connect(uri: impl AsRef<str>) -> crate::error::Result<Self> {
let uri_str = uri.as_ref();
let addr = uri_str
.strip_prefix("ws://")
.unwrap_or(uri_str)
.trim_end_matches('/');
let tcp = TcpStream::connect(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let request = http::Request::builder()
.method("GET")
.header("Host", addr)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header(
"Sec-WebSocket-Key",
tokio_tungstenite::tungstenite::handshake::client::generate_key(),
)
.uri(uri_str)
.body(())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let (ws, _) = client_async(request, WsStream::Plain(tcp))
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
ws: Arc::new(Mutex::new(Some(ws))),
})
}
pub async fn connect_tls(
addr: impl tokio::net::ToSocketAddrs,
server_name: ServerName<'static>,
client_config: Arc<rustls::ClientConfig>,
) -> crate::error::Result<Self> {
Self::connect_tls_pinned(addr, server_name, client_config, &[]).await
}
pub async fn connect_tls_pinned(
addr: impl tokio::net::ToSocketAddrs,
server_name: ServerName<'static>,
client_config: Arc<rustls::ClientConfig>,
pins: &[CertPin],
) -> crate::error::Result<Self> {
let tcp = TcpStream::connect(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let host = server_name.to_str().to_string();
let connector = TlsConnector::from(client_config);
let tls_stream = connector
.connect(server_name, tcp)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
verify_peer_cert_pins(tls_stream.get_ref().1.peer_certificates(), pins)?;
let uri = format!("wss://{host}/");
let request = http::Request::builder()
.method("GET")
.header("Host", host.as_str())
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header(
"Sec-WebSocket-Key",
tokio_tungstenite::tungstenite::handshake::client::generate_key(),
)
.uri(&uri)
.body(())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let (ws, _) = client_async(request, WsStream::ClientTls(tls_stream))
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
ws: Arc::new(Mutex::new(Some(ws))),
})
}
pub async fn accept(stream: TcpStream) -> crate::error::Result<Self> {
let ws = accept_async(WsStream::Plain(stream))
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
ws: Arc::new(Mutex::new(Some(ws))),
})
}
pub async fn accept_tls(
stream: TcpStream,
server_config: Arc<rustls::ServerConfig>,
) -> crate::error::Result<Self> {
let acceptor = TlsAcceptor::from(server_config);
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let ws = accept_async(WsStream::ServerTls(tls_stream))
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
ws: Arc::new(Mutex::new(Some(ws))),
})
}
}
/// [`Transport`] over a single WebSocket connection, one binary frame per message.
///
/// Every method locks the same async mutex. `recv` keeps it locked while waiting
/// for a frame, so a concurrent `send` or `close` waits until `recv` returns.
/// NOTE(review): full-duplex use may need split sink/stream halves.
#[async_trait]
impl Transport for WebSocketTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::WebSocket
    }

    /// Sends `data` as one binary WebSocket message.
    ///
    /// # Errors
    ///
    /// `ChannelClosed` if the transport was shut down; `ConnectionFailed` if the
    /// write fails.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut guard = self.ws.lock().await;
        let ws = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        ws.send(Message::Binary(data))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    /// Returns the payload of the next binary message, skipping all others.
    ///
    /// Pings are not answered here: tungstenite queues the pong itself and
    /// flushes it on the next read, which the loop performs immediately.
    /// Answering manually would send a second pong.
    ///
    /// # Errors
    ///
    /// `ChannelClosed` if the transport was shut down, the peer sent Close, or
    /// the stream ended; `ConnectionFailed` on a read error. In all of these
    /// terminal cases the connection is dropped, so `is_healthy` reports `false`.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.ws.lock().await;
        let ws = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let failure = loop {
            match ws.next().await {
                Some(Ok(Message::Binary(payload))) => return Ok(payload),
                Some(Ok(
                    Message::Text(_) | Message::Ping(_) | Message::Pong(_) | Message::Frame(_),
                )) => continue,
                Some(Ok(Message::Close(_))) => {
                    // Best effort: flush the close reply tungstenite queued so the
                    // closing handshake completes before the stream is dropped.
                    let _ = SinkExt::close(&mut *ws).await;
                    break SrxError::Transport(TransportError::ChannelClosed);
                }
                Some(Err(e)) => {
                    break SrxError::Transport(TransportError::ConnectionFailed(e.to_string()));
                }
                None => break SrxError::Transport(TransportError::ChannelClosed),
            }
        };
        // The connection is finished; forget it so later calls fail fast and
        // `is_healthy` stops reporting a dead socket as usable.
        *guard = None;
        Err(failure)
    }

    /// Returns `true` while the transport still holds a connection.
    async fn is_healthy(&self) -> bool {
        self.ws.lock().await.is_some()
    }

    /// Shuts the connection down, best effort; closing twice is a no-op.
    async fn close(&self) -> crate::error::Result<()> {
        let mut guard = self.ws.lock().await;
        if let Some(mut ws) = guard.take() {
            // Close errors are ignored: the connection is discarded either way.
            let _ = SinkExt::close(&mut ws).await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};

    /// Issues a self-signed `localhost` certificate and returns a server config
    /// presenting it plus a client config whose only trust anchor is that cert.
    fn localhost_tls_pair() -> (Arc<rustls::ServerConfig>, Arc<rustls::ClientConfig>) {
        let issued = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert = CertificateDer::from(issued.cert);
        let key = PrivatePkcs8KeyDer::from(issued.signing_key.serialize_der());
        let server_side = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.clone()], key.into())
            .unwrap();
        let mut trusted = rustls::RootCertStore::empty();
        trusted.add(cert).unwrap();
        let client_side = rustls::ClientConfig::builder()
            .with_root_certificates(Arc::new(trusted))
            .with_no_client_auth();
        (Arc::new(server_side), Arc::new(client_side))
    }

    /// Server name matching the certificate from [`localhost_tls_pair`].
    fn localhost_name() -> ServerName<'static> {
        ServerName::try_from("localhost".to_string()).unwrap()
    }

    #[tokio::test]
    async fn ws_binary_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let uri = format!("ws://{}/", listener.local_addr().unwrap());
        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let transport = WebSocketTransport::accept(socket).await.unwrap();
            let inbound = transport.recv().await.unwrap();
            assert_eq!(inbound.as_ref(), b"ws-ping");
            transport.send(Bytes::from_static(b"ws-pong")).await.unwrap();
        });

        let client = WebSocketTransport::connect(&uri).await.unwrap();
        client.send(Bytes::from_static(b"ws-ping")).await.unwrap();
        let answer = client.recv().await.unwrap();
        assert_eq!(answer.as_ref(), b"ws-pong");
        client.close().await.unwrap();
        server.await.unwrap();
    }

    #[tokio::test]
    async fn ws_tls_binary_roundtrip() {
        let (server_cfg, client_cfg) = localhost_tls_pair();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let transport = WebSocketTransport::accept_tls(socket, server_cfg)
                .await
                .unwrap();
            let inbound = transport.recv().await.unwrap();
            assert_eq!(inbound.as_ref(), b"wss-ping");
            transport.send(Bytes::from_static(b"wss-pong")).await.unwrap();
        });

        let client = WebSocketTransport::connect_tls(addr, localhost_name(), client_cfg)
            .await
            .unwrap();
        client.send(Bytes::from_static(b"wss-ping")).await.unwrap();
        let answer = client.recv().await.unwrap();
        assert_eq!(answer.as_ref(), b"wss-pong");
        client.close().await.unwrap();
        server.await.unwrap();
    }

    #[tokio::test]
    async fn ws_tls_pinned_rejects_wrong_pin() {
        let (server_cfg, client_cfg) = localhost_tls_pair();
        let bogus_pin = [0x42u8; 32];
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            // The handshake outcome on the server side is irrelevant here.
            let _ = WebSocketTransport::accept_tls(socket, server_cfg).await;
        });

        let outcome =
            WebSocketTransport::connect_tls_pinned(addr, localhost_name(), client_cfg, &[bogus_pin])
                .await;
        assert!(outcome.is_err(), "expected pin mismatch");
        server.abort();
    }
}