use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Length in bytes of the fixed header that prefixes every packet.
const HEADER_SIZE: usize = 8;
/// Packet type byte (header offset 0) identifying a PreLogin packet.
const PACKET_TYPE_PRELOGIN: u8 = 0x12;
/// Status byte (header offset 1) flagging a packet as end-of-message.
const PACKET_STATUS_EOM: u8 = 0x01;
/// Largest framed handshake packet this module emits, header included.
const HANDSHAKE_PACKET_SIZE: usize = 4 * 1024;
/// Payload capacity of one handshake packet once its header is accounted for.
const MAX_HANDSHAKE_PAYLOAD: usize = HANDSHAKE_PACKET_SIZE - HEADER_SIZE;
/// Adapts a byte stream so a TLS handshake can be carried inside PreLogin packets.
///
/// While the handshake is pending, reads strip the packet header from each incoming
/// PreLogin packet and hand back only its payload, and writes are staged in memory
/// until flushed, at which point they are framed into PreLogin packets. After
/// [`TlsPreloginWrapper::handshake_complete`] is called, reads and writes are forwarded
/// to the inner stream without framing.
#[derive(Debug)]
pub struct TlsPreloginWrapper<S> {
    /// The wrapped transport.
    stream: S,
    /// `true` until `handshake_complete` is called; while set, PreLogin framing is active.
    pending_handshake: bool,
    /// Header of the incoming packet currently being parsed.
    header_buf: [u8; HEADER_SIZE],
    /// Number of bytes of `header_buf` received so far.
    header_pos: usize,
    /// Payload bytes of the current incoming packet not yet returned to the reader.
    read_remaining: usize,
    /// Outgoing bytes: when idle, a `HEADER_SIZE` placeholder followed by staged payload;
    /// while a flight is being sent, the fully framed packets.
    write_buf: Vec<u8>,
    /// Offset of the next byte of `write_buf` to send (`HEADER_SIZE` when idle).
    write_pos: usize,
    /// `true` once `write_buf` holds a framed flight that is being written out.
    header_written: bool,
}
impl<S> TlsPreloginWrapper<S> {
    /// Wraps `stream` with PreLogin framing switched on.
    ///
    /// The wrapper starts in handshake mode: reads expect PreLogin-framed input and
    /// writes are staged until flushed. Call [`Self::handshake_complete`] once the
    /// handshake has finished to switch reads and writes to pass-through.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            pending_handshake: true,
            // Read side: no header bytes seen yet and no payload outstanding.
            header_buf: [0; HEADER_SIZE],
            header_pos: 0,
            read_remaining: 0,
            // Write side idle state: a HEADER_SIZE placeholder with the cursor just past it.
            header_written: false,
            write_pos: HEADER_SIZE,
            write_buf: vec![0; HEADER_SIZE],
        }
    }

    /// Turns PreLogin framing off; subsequent reads and writes use the inner stream directly.
    pub fn handshake_complete(&mut self) {
        self.pending_handshake = false;
    }

    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        &self.stream
    }

    /// Mutably borrows the wrapped stream.
    ///
    /// Reading or writing through this reference bypasses the wrapper's framing state.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Consumes the wrapper and returns the inner stream, discarding any buffered state.
    pub fn into_inner(self) -> S {
        let Self { stream, .. } = self;
        stream
    }
}
impl<S: AsyncRead + Unpin> AsyncRead for TlsPreloginWrapper<S> {
    /// Reads handshake bytes out of PreLogin packets, or passes straight through once
    /// the handshake is complete.
    ///
    /// While framing is active, each packet's header is consumed and validated and only
    /// its payload is copied into `buf`; a single call never returns bytes from more than
    /// one packet. Packets with an empty payload are skipped rather than being reported
    /// as end-of-stream.
    ///
    /// # Errors
    /// - `InvalidData` if a packet is not a PreLogin packet, or its length field is
    ///   smaller than the header itself.
    /// - `UnexpectedEof` if the inner stream ends in the middle of a packet. EOF on a
    ///   packet boundary is reported as a normal zero-byte read.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if !this.pending_handshake {
            return Pin::new(&mut this.stream).poll_read(cx, buf);
        }

        // Acquire a packet with payload still outstanding. Looping (instead of returning
        // after one header) lets zero-length packets be skipped: returning a zero-byte
        // read here would be interpreted by the caller as EOF.
        while this.read_remaining == 0 {
            while this.header_pos < HEADER_SIZE {
                let mut header_buf = ReadBuf::new(&mut this.header_buf[this.header_pos..]);
                if Pin::new(&mut this.stream)
                    .poll_read(cx, &mut header_buf)?
                    .is_pending()
                {
                    return Poll::Pending;
                }
                let n = header_buf.filled().len();
                if n == 0 {
                    if this.header_pos == 0 {
                        // Clean EOF between packets.
                        return Poll::Ready(Ok(()));
                    }
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "connection closed inside a PreLogin packet header",
                    )));
                }
                this.header_pos += n;
            }

            let packet_type = this.header_buf[0];
            if packet_type != PACKET_TYPE_PRELOGIN {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Expected PreLogin packet (0x12), got 0x{packet_type:02X}"),
                )));
            }
            // Bytes 2..4 hold the big-endian total packet length, header included.
            let length = usize::from(u16::from_be_bytes([this.header_buf[2], this.header_buf[3]]));
            if length < HEADER_SIZE {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "PreLogin packet length {length} is smaller than its {HEADER_SIZE}-byte header"
                    ),
                )));
            }
            this.read_remaining = length - HEADER_SIZE;
            tracing::trace!(
                "TLS wrapper: reading {} bytes of payload",
                this.read_remaining
            );
            if this.read_remaining == 0 {
                // Empty packet: discard its header and go read the next one.
                this.header_pos = 0;
            }
        }

        let want = cmp::min(this.read_remaining, buf.remaining());
        if want == 0 {
            // The caller supplied no space; nothing to copy.
            return Poll::Ready(Ok(()));
        }
        // Read straight into the caller's buffer, capped so we never consume bytes that
        // belong to the next packet's header.
        let dst = buf.initialize_unfilled_to(want);
        let mut limited = ReadBuf::new(dst);
        if Pin::new(&mut this.stream)
            .poll_read(cx, &mut limited)?
            .is_pending()
        {
            return Poll::Pending;
        }
        let n = limited.filled().len();
        if n == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "connection closed inside a PreLogin packet payload",
            )));
        }
        buf.advance(n);
        this.read_remaining -= n;
        if this.read_remaining == 0 {
            // Packet fully delivered; the next call starts on a fresh header.
            this.header_pos = 0;
        }
        Poll::Ready(Ok(()))
    }
}
impl<S: AsyncWrite + Unpin> TlsPreloginWrapper<S> {
    /// Frames `payload` as a sequence of PreLogin packets ready to go on the wire.
    ///
    /// The payload is cut into chunks of at most `MAX_HANDSHAKE_PAYLOAD` bytes so that no
    /// packet (header included) exceeds `HANDSHAKE_PACKET_SIZE`. Every packet carries the
    /// EOM status, and the id byte (header offset 6) counts 1, 2, ... from the first packet
    /// of this flight, wrapping modulo 256. Header bytes 4, 5 and 7 are written as zero.
    fn frame_handshake_flight(payload: &[u8]) -> Vec<u8> {
        let packets = payload.len().div_ceil(MAX_HANDSHAKE_PAYLOAD);
        let mut framed = Vec::with_capacity(payload.len() + packets * HEADER_SIZE);
        for (i, chunk) in payload.chunks(MAX_HANDSHAKE_PAYLOAD).enumerate() {
            // `total` is at most HANDSHAKE_PACKET_SIZE, so the u16 length field cannot truncate.
            let total = HEADER_SIZE + chunk.len();
            let [len_hi, len_lo] = (total as u16).to_be_bytes();
            // Truncating `i` is intentional: the id byte wraps modulo 256.
            let packet_id = (i as u8).wrapping_add(1);
            framed.extend_from_slice(&[
                PACKET_TYPE_PRELOGIN,
                PACKET_STATUS_EOM,
                len_hi,
                len_lo,
                0,
                0,
                packet_id,
                0,
            ]);
            framed.extend_from_slice(chunk);
        }
        framed
    }

    /// Drives any staged handshake bytes out to the inner stream.
    ///
    /// The first call after `poll_write` has staged data frames it; later calls resume a
    /// partially written flight. Returns `Ready(Ok(()))` immediately when nothing is
    /// staged. Once the whole flight is written, the buffer returns to its idle state
    /// (a `HEADER_SIZE` placeholder with `write_pos == HEADER_SIZE`).
    ///
    /// # Errors
    /// Propagates inner write errors, and returns `WriteZero` if the inner stream accepts
    /// zero bytes (which would otherwise make this loop spin forever).
    fn poll_send_flight(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if !self.header_written {
            let payload_bytes = self.write_buf.len().saturating_sub(HEADER_SIZE);
            if payload_bytes == 0 {
                return Poll::Ready(Ok(()));
            }
            self.write_buf = Self::frame_handshake_flight(&self.write_buf[HEADER_SIZE..]);
            self.write_pos = 0;
            self.header_written = true;
            tracing::trace!(
                payload_bytes,
                packets = payload_bytes.div_ceil(MAX_HANDSHAKE_PAYLOAD),
                "TLS wrapper: sending handshake flight"
            );
        }
        while self.write_pos < self.write_buf.len() {
            let unsent = &self.write_buf[self.write_pos..];
            let n = match Pin::new(&mut self.stream).poll_write(cx, unsent)? {
                Poll::Ready(n) => n,
                Poll::Pending => return Poll::Pending,
            };
            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "inner stream accepted zero bytes of a PreLogin handshake flight",
                )));
            }
            self.write_pos += n;
        }
        self.write_buf.clear();
        self.write_buf.resize(HEADER_SIZE, 0);
        self.write_pos = HEADER_SIZE;
        self.header_written = false;
        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncWrite + Unpin> AsyncWrite for TlsPreloginWrapper<S> {
    /// While the handshake is pending, stages `buf` for the next flush and reports all of
    /// it as written. Afterwards, writes straight through to the inner stream.
    ///
    /// A staged flight is sent first if it is already partially on the wire, so raw bytes
    /// are never appended to framed output. It is also sent first if the handshake has
    /// since completed, so handshake bytes stay ordered ahead of application data.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if (this.header_written || !this.pending_handshake)
            && this.poll_send_flight(cx)?.is_pending()
        {
            return Poll::Pending;
        }
        if !this.pending_handshake {
            return Pin::new(&mut this.stream).poll_write(cx, buf);
        }
        this.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Sends any staged handshake flight as PreLogin packets, then flushes the inner stream.
    ///
    /// Staged bytes are sent even if `handshake_complete` was called after they were
    /// written; they were accepted by `poll_write` and must not be silently dropped.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if this.poll_send_flight(cx)?.is_pending() {
            return Poll::Pending;
        }
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Sends any staged handshake flight, then shuts down the inner stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if this.poll_send_flight(cx)?.is_pending() {
            return Poll::Pending;
        }
        Pin::new(&mut this.stream).poll_shutdown(cx)
    }
}
#[cfg(test)]
// A failed `unwrap` in a test is simply a test failure, which is the desired outcome here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Splits back-to-back packets into `(type, status, id, payload)` tuples.
    ///
    /// Panics on a truncated header or on a length field outside
    /// `HEADER_SIZE..=remaining`. That way a malformed stream fails the test instead of
    /// looping forever on a zero length.
    fn parse_packets(mut bytes: &[u8]) -> Vec<(u8, u8, u8, Vec<u8>)> {
        let mut packets = Vec::new();
        while !bytes.is_empty() {
            assert!(bytes.len() >= HEADER_SIZE, "truncated packet header");
            let total = usize::from(u16::from_be_bytes([bytes[2], bytes[3]]));
            assert!(
                (HEADER_SIZE..=bytes.len()).contains(&total),
                "invalid packet length {total}"
            );
            let (packet, rest) = bytes.split_at(total);
            packets.push((
                packet[0],
                packet[1],
                packet[6],
                packet[HEADER_SIZE..].to_vec(),
            ));
            bytes = rest;
        }
        packets
    }

    /// Builds a single PreLogin packet (EOM status, id 1) carrying `payload`.
    fn prelogin_packet(payload: &[u8]) -> Vec<u8> {
        let total = u16::try_from(HEADER_SIZE + payload.len()).unwrap();
        let [len_hi, len_lo] = total.to_be_bytes();
        let mut packet = vec![
            PACKET_TYPE_PRELOGIN,
            PACKET_STATUS_EOM,
            len_hi,
            len_lo,
            0,
            0,
            1,
            0,
        ];
        packet.extend_from_slice(payload);
        packet
    }

    #[tokio::test]
    async fn small_flight_is_one_prelogin_packet() {
        let (client, mut server) = tokio::io::duplex(1 << 20);
        let mut wrapper = TlsPreloginWrapper::new(client);
        let payload: Vec<u8> = (0..100u8).collect();
        wrapper.write_all(&payload).await.unwrap();
        wrapper.flush().await.unwrap();
        let mut received = vec![0u8; payload.len() + HEADER_SIZE];
        server.read_exact(&mut received).await.unwrap();
        let packets = parse_packets(&received);
        assert_eq!(packets.len(), 1);
        let (ptype, status, id, data) = &packets[0];
        assert_eq!(*ptype, PACKET_TYPE_PRELOGIN);
        assert_eq!(*status, PACKET_STATUS_EOM);
        assert_eq!(*id, 1);
        assert_eq!(*data, payload);
    }

    #[tokio::test]
    async fn oversized_flight_splits_at_packet_cap() {
        let (client, mut server) = tokio::io::duplex(1 << 20);
        let mut wrapper = TlsPreloginWrapper::new(client);
        let payload: Vec<u8> = (0..70_000u32).map(|i| (i % 251) as u8).collect();
        wrapper.write_all(&payload).await.unwrap();
        wrapper.flush().await.unwrap();
        let expected_packets = payload.len().div_ceil(MAX_HANDSHAKE_PAYLOAD);
        let mut received = vec![0u8; payload.len() + expected_packets * HEADER_SIZE];
        server.read_exact(&mut received).await.unwrap();
        let packets = parse_packets(&received);
        assert_eq!(packets.len(), expected_packets);
        let mut reassembled = Vec::new();
        for (i, (ptype, status, id, data)) in packets.iter().enumerate() {
            assert_eq!(*ptype, PACKET_TYPE_PRELOGIN);
            assert_eq!(
                *status, PACKET_STATUS_EOM,
                "each chunk is its own complete EOM message"
            );
            assert_eq!(*id, (i as u8).wrapping_add(1));
            assert!(data.len() + HEADER_SIZE <= HANDSHAKE_PACKET_SIZE);
            reassembled.extend_from_slice(data);
        }
        assert_eq!(reassembled, payload, "no bytes lost or reordered");
    }

    #[tokio::test]
    async fn consecutive_flights_reuse_the_wrapper_cleanly() {
        let (client, mut server) = tokio::io::duplex(1 << 20);
        let mut wrapper = TlsPreloginWrapper::new(client);
        for round in 0..3u8 {
            let payload = vec![round; 50];
            wrapper.write_all(&payload).await.unwrap();
            wrapper.flush().await.unwrap();
            let mut received = vec![0u8; payload.len() + HEADER_SIZE];
            server.read_exact(&mut received).await.unwrap();
            let packets = parse_packets(&received);
            assert_eq!(packets.len(), 1);
            assert_eq!(packets[0].3, payload, "round {round} payload intact");
        }
    }

    #[tokio::test]
    async fn read_strips_prelogin_headers_across_packets() {
        let (client, mut server) = tokio::io::duplex(1 << 16);
        let mut wrapper = TlsPreloginWrapper::new(client);
        let mut wire = prelogin_packet(b"hello ");
        wire.extend(prelogin_packet(b"world"));
        server.write_all(&wire).await.unwrap();
        let mut received = vec![0u8; 11];
        wrapper.read_exact(&mut received).await.unwrap();
        assert_eq!(received, b"hello world");
    }

    #[tokio::test]
    async fn read_rejects_non_prelogin_packet() {
        let (client, mut server) = tokio::io::duplex(1 << 16);
        let mut wrapper = TlsPreloginWrapper::new(client);
        let mut wire = prelogin_packet(b"x");
        wire[0] = 0x04;
        server.write_all(&wire).await.unwrap();
        let mut received = [0u8; 1];
        let err = wrapper.read(&mut received).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn completed_handshake_passes_bytes_through() {
        let (client, mut server) = tokio::io::duplex(1 << 16);
        let mut wrapper = TlsPreloginWrapper::new(client);
        wrapper.handshake_complete();

        wrapper.write_all(b"raw").await.unwrap();
        wrapper.flush().await.unwrap();
        let mut received = [0u8; 3];
        server.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"raw");

        server.write_all(b"back").await.unwrap();
        let mut echoed = [0u8; 4];
        wrapper.read_exact(&mut echoed).await.unwrap();
        assert_eq!(&echoed, b"back");
    }
}