use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;
use tokio_rustls::TlsStream;
use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};
/// Size in bytes (64 KiB) of the scratch buffer each raw `Transport::recv` call allocates.
const RECV_BUF: usize = 64 * 1024;
/// A TLS-over-TCP transport, created either as a client (`connect`/`connect_pinned`)
/// or as a server (`accept`), or wrapped around an existing stream (`from_tls_stream`).
///
/// The TLS stream sits behind a single async mutex shared by reads, writes and
/// `close`, so operations on one transport are serialized. `None` means the
/// stream has been released (e.g. taken by `close()`).
pub struct TlsTcpTransport {
    // Shared, lockable TLS stream; `None` once it has been taken/closed.
    stream: Arc<Mutex<Option<TlsStream<TcpStream>>>>,
}
impl TlsTcpTransport {
pub fn server_config_from_der(
cert_der: CertificateDer<'static>,
key_der: PrivatePkcs8KeyDer<'static>,
) -> crate::error::Result<Arc<rustls::ServerConfig>> {
let cfg = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert_der], key_der.into())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Arc::new(cfg))
}
pub fn client_config_trust_server_cert(
cert_der: &CertificateDer<'_>,
) -> crate::error::Result<Arc<rustls::ClientConfig>> {
let mut roots = rustls::RootCertStore::empty();
roots
.add(cert_der.clone())
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let cfg = rustls::ClientConfig::builder()
.with_root_certificates(Arc::new(roots))
.with_no_client_auth();
Ok(Arc::new(cfg))
}
pub fn server_name_dns(host: &str) -> crate::error::Result<ServerName<'static>> {
ServerName::try_from(host.to_string()).map_err(|_| {
SrxError::Transport(TransportError::ConnectionFailed(format!(
"invalid TLS server name: {host}"
)))
})
}
pub async fn connect(
addr: impl tokio::net::ToSocketAddrs,
server_name: ServerName<'static>,
client_config: Arc<rustls::ClientConfig>,
) -> crate::error::Result<Self> {
Self::connect_pinned(addr, server_name, client_config, &[]).await
}
pub async fn connect_pinned(
addr: impl tokio::net::ToSocketAddrs,
server_name: ServerName<'static>,
client_config: Arc<rustls::ClientConfig>,
pins: &[CertPin],
) -> crate::error::Result<Self> {
let tcp = TcpStream::connect(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let connector = TlsConnector::from(client_config);
let tls_client = connector
.connect(server_name, tcp)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
verify_peer_cert_pins(tls_client.get_ref().1.peer_certificates(), pins)?;
let tls: TlsStream<TcpStream> = tls_client.into();
Ok(Self {
stream: Arc::new(Mutex::new(Some(tls))),
})
}
pub async fn accept(
listener: &TcpListener,
server_config: Arc<rustls::ServerConfig>,
) -> crate::error::Result<Self> {
let (tcp, _) = listener
.accept()
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let acceptor = TlsAcceptor::from(server_config);
let tls: TlsStream<TcpStream> = acceptor
.accept(tcp)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?
.into();
Ok(Self {
stream: Arc::new(Mutex::new(Some(tls))),
})
}
#[must_use]
pub fn from_tls_stream(stream: TlsStream<TcpStream>) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
}
}
pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
write_length_prefixed(stream, payload).await
}
pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
let v = read_length_prefixed(stream).await?;
Ok(Bytes::from(v))
}
}
#[async_trait]
impl Transport for TlsTcpTransport {
    /// Reports [`TransportKind::Tcp`]: the TLS session runs over a plain TCP socket.
    fn kind(&self) -> TransportKind {
        TransportKind::Tcp
    }

    /// Writes `data` to the TLS stream as raw bytes (no length prefix), then flushes.
    ///
    /// # Errors
    ///
    /// - `TransportError::ChannelClosed` if the stream has been released.
    /// - `TransportError::ConnectionFailed` if the write or flush fails.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        stream
            .write_all(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        stream
            .flush()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    /// Performs a single read of at most `RECV_BUF` bytes from the TLS stream.
    ///
    /// Message boundaries are not preserved: the result is whatever one read
    /// produced. Use `recv_framed` for length-prefixed messages.
    ///
    /// When the peer closes the connection (the read yields 0 bytes), the stream
    /// is shut down on a best-effort basis and released. After that,
    /// `is_healthy` reports `false` and later `send`/`recv` calls fail with
    /// `ChannelClosed`.
    ///
    /// NOTE(review): the lock is held while awaiting the read, so a concurrent
    /// `send` or `close` on this transport waits until data or EOF arrives.
    ///
    /// # Errors
    ///
    /// - `TransportError::ChannelClosed` if the stream was already released or
    ///   the peer has closed it.
    /// - `TransportError::ConnectionFailed` if the read fails.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; RECV_BUF];
        let n = stream
            .read(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        if n == 0 {
            // EOF: the connection is finished. Release the stream instead of
            // leaving a dead session in place, which would otherwise keep
            // `is_healthy()` returning true. As in `close()`, send our own
            // close_notify best-effort; the peer may already be gone.
            if let Some(mut closed) = guard.take() {
                let _ = closed.shutdown().await;
            }
            return Err(SrxError::Transport(TransportError::ChannelClosed));
        }
        buf.truncate(n);
        Ok(Bytes::from(buf))
    }

    /// Returns `true` while the stream is still held, i.e. until `close` (or a
    /// peer-side EOF observed by `recv`) releases it. It does not probe the
    /// connection itself.
    async fn is_healthy(&self) -> bool {
        self.stream.lock().await.is_some()
    }

    /// Shuts the TLS stream down and releases it.
    ///
    /// Shutdown errors are deliberately ignored because the stream is dropped
    /// either way. Calling this on an already-released transport is a no-op
    /// that returns `Ok(())`.
    async fn close(&self) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        if let Some(mut s) = guard.take() {
            let _ = s.shutdown().await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};

    /// Generates a self-signed `localhost` certificate and returns:
    /// - a server config presenting it;
    /// - a client config trusting only it;
    /// - the matching server name.
    fn localhost_tls_fixture() -> (
        Arc<rustls::ServerConfig>,
        Arc<rustls::ClientConfig>,
        ServerName<'static>,
    ) {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert = CertificateDer::from(generated.cert);
        let key = PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der());

        let server_cfg = TlsTcpTransport::server_config_from_der(cert.clone(), key).unwrap();
        let client_cfg = TlsTcpTransport::client_config_trust_server_cert(&cert).unwrap();
        let name = TlsTcpTransport::server_name_dns("localhost").unwrap();
        (server_cfg, client_cfg, name)
    }

    /// A framed request and reply make it across a real TLS session on loopback.
    #[tokio::test]
    async fn tls_framed_roundtrip() {
        let (server_cfg, client_cfg, name) = localhost_tls_fixture();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let peer = TlsTcpTransport::accept(&listener, server_cfg)
                .await
                .unwrap();
            let request = peer.recv_framed().await.unwrap();
            assert_eq!(request.as_ref(), b"tls-payload");
            peer.send_framed(b"tls-ack").await.unwrap();
        });

        let client = TlsTcpTransport::connect(addr, name, client_cfg)
            .await
            .unwrap();
        client.send_framed(b"tls-payload").await.unwrap();
        let ack = client.recv_framed().await.unwrap();
        assert_eq!(ack.as_ref(), b"tls-ack");

        client.close().await.unwrap();
        server.await.unwrap();
    }

    /// A pin that matches no presented certificate makes `connect_pinned` fail.
    #[tokio::test]
    async fn tls_pinned_rejects_wrong_pin() {
        let (server_cfg, client_cfg, name) = localhost_tls_fixture();
        let bogus_pin = [0xA5u8; 32];
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The server side only needs to drive the handshake; its outcome is irrelevant.
        let server = tokio::spawn(async move {
            let _ = TlsTcpTransport::accept(&listener, server_cfg).await;
        });

        let outcome =
            TlsTcpTransport::connect_pinned(addr, name, client_cfg, &[bogus_pin]).await;
        assert!(outcome.is_err(), "expected pin mismatch");
        server.abort();
    }
}