use std::{
io,
os::fd::{AsRawFd, RawFd},
};
use rustls::ExtractedSecrets;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use crate::{
stream::{cork::CorkStream, KtlsStream},
utils::async_read_ready::AsyncReadReady,
Error,
};
pub struct Setup<IO> {
inner: Option<TlsStream<IO>>,
drained: Option<Vec<u8>>,
}
impl<IO> Setup<IO> {
#[inline]
pub const fn new_client_stream(inner: tokio_rustls::client::TlsStream<CorkStream<IO>>) -> Self
where
IO: AsRawFd + AsyncRead + AsyncWrite + Unpin,
{
Setup {
inner: Some(TlsStream::Client(inner)),
drained: None,
}
}
#[inline]
pub const fn new_server_stream<'a>(
inner: tokio_rustls::server::TlsStream<CorkStream<IO>>,
) -> Self
where
IO: AsRawFd + AsyncRead + AsyncReadReady<'a> + AsyncWrite + Unpin,
{
Setup {
inner: Some(TlsStream::Server(inner)),
drained: None,
}
}
pub fn try_recover(&mut self) -> Option<(Option<Vec<u8>>, TlsStream<IO>)> {
self.inner.take().map(|inner| (self.drained.take(), inner))
}
pub async fn execute(&mut self) -> Result<KtlsStream<IO>, Error>
where
IO: AsRawFd + AsyncRead + AsyncWrite + Unpin,
{
{
let Some(inner) = self.inner.as_ref() else {
return Err(crate::Error::ReuseAfterKtlsSetup);
};
crate::ffi::setup_ulp(inner.as_raw_fd()).map_err(Error::UlpError)?;
}
let Some(mut inner) = self.inner.take() else {
unreachable!("has checked for None");
};
{
inner.set_corked(true);
self.drained = inner.drain().await.map_err(Error::DrainError)?
}
let (CorkStream { io, .. }, tls_conn) = inner.into_inner();
let cipher_suite = tls_conn
.negotiated_cipher_suite()
.ok_or(Error::NoNegotiatedCipherSuite)?;
let ExtractedSecrets { tx, rx } = tls_conn
.dangerous_extract_secrets()
.map_err(Error::ExportSecrets)?;
{
let fd = io.as_raw_fd();
let tx = crate::ffi::CryptoInfo::from_rustls(cipher_suite, tx)?;
crate::ffi::setup_tls_info(fd, crate::ffi::Direction::Tx, tx)?;
let rx = crate::ffi::CryptoInfo::from_rustls(cipher_suite, rx)?;
crate::ffi::setup_tls_info(fd, crate::ffi::Direction::Rx, rx)?;
}
Ok(KtlsStream::new(io, self.drained.take()))
}
}
/// A tokio-rustls stream layered over a [`CorkStream`], from either side of the handshake.
pub enum TlsStream<IO> {
    /// Client-side connection.
    Client(tokio_rustls::client::TlsStream<CorkStream<IO>>),
    /// Server-side connection.
    Server(tokio_rustls::server::TlsStream<CorkStream<IO>>),
}
impl<IO> TlsStream<IO> {
    /// Raw file descriptor of the I/O object wrapped by the cork layer.
    #[inline]
    fn as_raw_fd(&self) -> RawFd
    where
        IO: AsRawFd,
    {
        let cork = match self {
            Self::Client(client) => client.get_ref().0,
            Self::Server(server) => server.get_ref().0,
        };
        cork.io.as_raw_fd()
    }
    /// Sets the `corked` flag on the underlying cork layer.
    #[inline]
    fn set_corked(&mut self, corked: bool) {
        let cork = match self {
            Self::Client(client) => client.get_mut().0,
            Self::Server(server) => server.get_mut().0,
        };
        cork.corked = corked;
    }
    /// Reads this stream to its end with the module-level `drain` helper.
    ///
    /// Returns `Ok(None)` when nothing was read.
    #[inline]
    async fn drain(&mut self) -> io::Result<Option<Vec<u8>>>
    where
        IO: AsyncRead + AsyncWrite + Unpin,
    {
        // Two monomorphic calls (rather than one `dyn` call) keep the future's auto traits.
        match self {
            Self::Client(client) => drain(client).await,
            Self::Server(server) => drain(server).await,
        }
    }
    /// Splits the stream into its cork layer and the side-agnostic rustls connection.
    #[inline]
    fn into_inner(self) -> (CorkStream<IO>, rustls::Connection) {
        match self {
            Self::Client(client) => {
                let (cork, session) = client.into_inner();
                (cork, rustls::Connection::Client(session))
            }
            Self::Server(server) => {
                let (cork, session) = server.into_inner();
                (cork, rustls::Connection::Server(session))
            }
        }
    }
}
/// Reads `stream` until it reports end-of-stream and returns every byte read.
///
/// End-of-stream is either a zero-length read or an
/// [`std::io::ErrorKind::UnexpectedEof`] error; both simply stop the loop.
/// The read buffer starts at 128 KiB and doubles whenever it fills up, so no
/// buffered data is silently left behind.
///
/// # Returns
///
/// `Ok(None)` if nothing was read, otherwise `Ok(Some(bytes))` holding exactly the
/// bytes read.
///
/// # Errors
///
/// Any read error other than `UnexpectedEof` is returned as-is; bytes read before
/// the error are discarded.
async fn drain(stream: &mut (impl AsyncRead + Unpin)) -> std::io::Result<Option<Vec<u8>>> {
    /// Size of the first read buffer; it doubles each time it becomes full.
    const INITIAL_BUFFER_LEN: usize = 128 * 1024;

    tracing::trace!("Draining rustls stream");
    let mut drained = vec![0u8; INITIAL_BUFFER_LEN];
    let mut filled = 0;
    loop {
        // With a full buffer `read` would get an empty slice and return `Ok(0)`,
        // which the check below takes as end-of-stream -- truncating the output.
        // Grow first so every remaining byte can still be read.
        if filled == drained.len() {
            drained.resize(drained.len() * 2, 0);
        }
        tracing::trace!("stream.read called");
        let n = match stream.read(&mut drained[filled..]).await {
            Ok(n) => n,
            Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                tracing::trace!("stream.read returned UnexpectedEof, that's expected for us");
                break;
            }
            Err(e) => {
                tracing::trace!("stream.read returned error: {e}");
                return Err(e);
            }
        };
        tracing::trace!("stream.read returned {n}");
        if n == 0 {
            break;
        }
        filled += n;
    }
    if filled == 0 {
        return Ok(None);
    }
    tracing::trace!("Draining rustls stream done: drained {filled} bytes");
    drained.truncate(filled);
    Ok(Some(drained))
}
/// Converts a server-side rustls stream into a [`KtlsStream`].
///
/// Thin wrapper over [`Setup::new_server_stream`] followed by [`Setup::execute`]; the
/// intermediate `Setup` is dropped afterwards, so the stream cannot be recovered on error.
#[deprecated(
    since = "7.0.0-rc.1",
    note = "use `Setup::new_server_stream(...).execute()` instead"
)]
pub async fn config_ktls_server<'a, IO>(
    inner: tokio_rustls::server::TlsStream<CorkStream<IO>>,
) -> Result<KtlsStream<IO>, Error>
where
    IO: AsRawFd + AsyncRead + AsyncReadReady<'a> + AsyncWrite + Unpin,
{
    let mut setup = Setup::new_server_stream(inner);
    setup.execute().await
}
/// Converts a client-side rustls stream into a [`KtlsStream`].
///
/// Thin wrapper over [`Setup::new_client_stream`] followed by [`Setup::execute`]; the
/// intermediate `Setup` is dropped afterwards, so the stream cannot be recovered on error.
#[deprecated(
    since = "7.0.0-rc.1",
    note = "use `Setup::new_client_stream(...).execute()` instead"
)]
pub async fn config_ktls_client<IO>(
    inner: tokio_rustls::client::TlsStream<CorkStream<IO>>,
) -> Result<KtlsStream<IO>, Error>
where
    IO: AsRawFd + AsyncRead + AsyncWrite + Unpin,
{
    let mut setup = Setup::new_client_stream(inner);
    setup.execute().await
}