use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Length in bytes of the packet header that precedes every framed chunk
/// of handshake data.
const HEADER_SIZE: usize = 8;
/// Packet-type byte (header byte 0) identifying a PreLogin packet.
const PACKET_TYPE_PRELOGIN: u8 = 0x12;
/// Status byte (header byte 1) written on every outgoing packet; the name
/// indicates the "end of message" flag.
const PACKET_STATUS_EOM: u8 = 1 << 0;
/// Stream adapter that wraps TLS handshake traffic in PreLogin packets.
///
/// While the handshake is pending, reads strip a `HEADER_SIZE`-byte packet
/// header from each incoming packet and return only its payload. Writes are
/// buffered and sent as a single PreLogin packet when the stream is flushed.
/// After [`handshake_complete`](Self::handshake_complete) is called, all I/O
/// passes straight through to the inner stream.
#[derive(Debug)]
pub struct TlsPreloginWrapper<S> {
    /// The wrapped transport.
    stream: S,
    /// `true` until `handshake_complete` is called; while set, I/O is framed.
    pending_handshake: bool,
    /// Header of the incoming packet currently being read.
    header_buf: [u8; HEADER_SIZE],
    /// Number of bytes of `header_buf` received so far.
    header_pos: usize,
    /// Payload bytes of the current incoming packet not yet handed to the reader.
    read_remaining: usize,
    /// Outgoing packet: a `HEADER_SIZE`-byte header slot followed by buffered payload.
    write_buf: Vec<u8>,
    /// Offset in `write_buf` of the next byte to send while a flush is in progress.
    write_pos: usize,
    /// Whether the header slot of `write_buf` has been filled for the packet being sent.
    header_written: bool,
}
impl<S> TlsPreloginWrapper<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
pending_handshake: true,
header_buf: [0u8; HEADER_SIZE],
header_pos: 0,
read_remaining: 0,
write_buf: vec![0u8; HEADER_SIZE], write_pos: HEADER_SIZE, header_written: false,
}
}
pub fn handshake_complete(&mut self) {
self.pending_handshake = false;
}
pub fn get_ref(&self) -> &S {
&self.stream
}
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
pub fn into_inner(self) -> S {
self.stream
}
}
impl<S: AsyncRead + Unpin> AsyncRead for TlsPreloginWrapper<S> {
    /// Reads handshake payload bytes and strips PreLogin packet headers while
    /// the handshake is pending. Otherwise it reads directly from the inner
    /// stream.
    ///
    /// Packets with an empty payload are consumed and skipped, so they are
    /// never reported as EOF. A clean EOF at a packet boundary is reported as
    /// `Ok(())` with nothing read, as usual for `AsyncRead`.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if a header's packet type is not PreLogin (0x12), or
    ///   if its declared length is shorter than the header itself.
    /// - `UnexpectedEof` if the inner stream ends in the middle of a packet.
    /// - Any error returned by the inner stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if !this.pending_handshake {
            return Pin::new(&mut this.stream).poll_read(cx, buf);
        }

        // Advance until we are inside a packet that still has payload to
        // deliver. This is a loop so that empty packets are skipped instead
        // of surfacing as a zero-byte read, which callers treat as EOF.
        while this.read_remaining == 0 {
            while this.header_pos < HEADER_SIZE {
                let mut header_buf = ReadBuf::new(&mut this.header_buf[this.header_pos..]);
                match Pin::new(&mut this.stream).poll_read(cx, &mut header_buf)? {
                    Poll::Ready(()) => {}
                    Poll::Pending => return Poll::Pending,
                }
                let n = header_buf.filled().len();
                if n == 0 {
                    if this.header_pos == 0 {
                        // Clean EOF between packets.
                        return Poll::Ready(Ok(()));
                    }
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "EOF in the middle of a PreLogin packet header",
                    )));
                }
                this.header_pos += n;
            }

            let packet_type = this.header_buf[0];
            if packet_type != PACKET_TYPE_PRELOGIN {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Expected PreLogin packet (0x12), got 0x{:02X}", packet_type),
                )));
            }

            // Header bytes 2..4 hold the total packet length (header included),
            // big-endian.
            let length =
                usize::from(u16::from_be_bytes([this.header_buf[2], this.header_buf[3]]));
            if length < HEADER_SIZE {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "PreLogin packet length {} is smaller than its {}-byte header",
                        length, HEADER_SIZE
                    ),
                )));
            }
            this.read_remaining = length - HEADER_SIZE;
            tracing::trace!(
                "TLS wrapper: reading {} bytes of payload",
                this.read_remaining
            );
            if this.read_remaining == 0 {
                // Empty packet: start collecting the next header.
                this.header_pos = 0;
            }
        }

        let max_read = cmp::min(this.read_remaining, buf.remaining());
        if max_read == 0 {
            return Poll::Ready(Ok(()));
        }

        // Read straight into the caller's buffer. The read is capped at the
        // payload left in this packet so the next header is never returned
        // as data.
        let dst = buf.initialize_unfilled_to(max_read);
        let mut payload = ReadBuf::new(dst);
        match Pin::new(&mut this.stream).poll_read(cx, &mut payload)? {
            Poll::Ready(()) => {}
            Poll::Pending => return Poll::Pending,
        }
        let n = payload.filled().len();
        if n == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "EOF in the middle of a PreLogin packet payload",
            )));
        }
        // The first `n` unfilled bytes were initialized above and now hold
        // the payload.
        buf.advance(n);
        this.read_remaining -= n;
        if this.read_remaining == 0 {
            this.header_pos = 0;
        }
        Poll::Ready(Ok(()))
    }
}
impl<S: AsyncWrite + Unpin> TlsPreloginWrapper<S> {
    /// Sends the buffered handshake bytes to the inner stream as one
    /// PreLogin packet.
    ///
    /// Does nothing if no payload is buffered. The first call fills in the
    /// header; later calls resume a partially written packet. Once the whole
    /// packet has been written, the buffer is reset to an empty header slot.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if the packet would not fit the 16-bit length field.
    /// - `WriteZero` if the inner stream accepts zero bytes.
    /// - Any error returned by the inner stream's `poll_write`.
    fn poll_write_packet(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.write_buf.len() <= HEADER_SIZE {
            return Poll::Ready(Ok(()));
        }

        if !self.header_written {
            let total_length = self.write_buf.len();
            if total_length > usize::from(u16::MAX) {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "PreLogin packet of {} bytes exceeds the maximum of {} bytes",
                        total_length,
                        u16::MAX
                    ),
                )));
            }
            // Lossless: bounded by u16::MAX just above.
            let [len_hi, len_lo] = (total_length as u16).to_be_bytes();
            // Header layout: type, status, big-endian total length, then bytes
            // 4..8 = 0, 0, 1, 0 (presumably SPID, packet id and window in
            // MS-TDS terms; confirm against the spec).
            self.write_buf[..HEADER_SIZE].copy_from_slice(&[
                PACKET_TYPE_PRELOGIN,
                PACKET_STATUS_EOM,
                len_hi,
                len_lo,
                0,
                0,
                1,
                0,
            ]);
            self.header_written = true;
            self.write_pos = 0;
            tracing::trace!("TLS wrapper: sending {} bytes", total_length);
        }

        while self.write_pos < self.write_buf.len() {
            match Pin::new(&mut self.stream).poll_write(cx, &self.write_buf[self.write_pos..])? {
                // A zero-length write would otherwise make this loop spin forever.
                Poll::Ready(0) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "inner stream accepted zero bytes of a PreLogin packet",
                    )));
                }
                Poll::Ready(n) => self.write_pos += n,
                Poll::Pending => return Poll::Pending,
            }
        }

        self.write_buf.truncate(HEADER_SIZE);
        self.write_pos = HEADER_SIZE;
        self.header_written = false;
        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncWrite + Unpin> AsyncWrite for TlsPreloginWrapper<S> {
    /// While the handshake is pending, buffers `buf` and reports it as fully
    /// written. The data reaches the inner stream on the next flush.
    /// Otherwise, writes directly to the inner stream.
    ///
    /// If an earlier flush left a framed packet partially sent, that packet
    /// is finished first. Its length header is already fixed, so appending
    /// to it would corrupt the framing.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if !this.pending_handshake {
            return Pin::new(&mut this.stream).poll_write(cx, buf);
        }
        if this.header_written {
            match this.poll_write_packet(cx)? {
                Poll::Ready(()) => {}
                Poll::Pending => return Poll::Pending,
            }
        }
        this.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// While the handshake is pending, sends any buffered data as one
    /// PreLogin packet, then flushes the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if this.pending_handshake {
            match this.poll_write_packet(cx)? {
                Poll::Ready(()) => {}
                Poll::Pending => return Poll::Pending,
            }
        }
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Shuts down the inner stream. Buffered handshake data is not flushed
    /// first.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
    }
}