// sqlx-sqlserver 0.0.1-beta.1
//
// Independent Microsoft SQL Server driver crate for SQLx.
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::protocol::packet::{
    encode_message, PacketFrameError, PacketHeader, PacketHeaderError, PacketType,
    PACKET_HEADER_LEN,
};

/// Packet size passed to `encode_message` when framing outgoing TLS handshake bytes.
const PRELOGIN_TLS_PACKET_SIZE: usize = 4096;

/// Stream adapter that tunnels TLS handshake traffic through TDS PRELOGIN packets.
///
/// While the handshake flag is set, reads strip PRELOGIN packet headers so the TLS layer sees
/// only payload bytes. Writes are buffered and framed as one PRELOGIN message per flush. Outside
/// the handshake, reads and writes pass straight through to the wrapped stream.
pub(crate) struct TlsPreloginStream<S> {
    stream: S,
    /// `true` between `start_handshake` and `finish_handshake`.
    handshake: bool,
    /// Bytes of the PRELOGIN header currently being read; kept across `Poll::Pending`.
    header_buf: [u8; PACKET_HEADER_LEN],
    /// Number of valid bytes in `header_buf`.
    header_pos: usize,
    /// Payload bytes of the current incoming PRELOGIN packet not yet handed to the caller.
    read_remaining: usize,
    /// Raw TLS bytes written during the handshake that have not been framed yet.
    write_buf: Vec<u8>,
    /// Framed PRELOGIN packets waiting to be written to `stream`. This is kept separate from
    /// `write_buf` so a `Poll::Pending` from the inner writer neither drops nor re-frames them.
    frame_buf: Vec<u8>,
    /// Number of bytes of `frame_buf` already accepted by `stream`.
    frame_pos: usize,
}

impl<S> TlsPreloginStream<S> {
    /// Wraps `stream` in pass-through mode (the handshake flag starts cleared).
    pub(crate) fn new(stream: S) -> Self {
        Self {
            stream,
            handshake: false,
            header_buf: [0; PACKET_HEADER_LEN],
            header_pos: 0,
            read_remaining: 0,
            write_buf: Vec::new(),
            frame_buf: Vec::new(),
            frame_pos: 0,
        }
    }

    /// Enables PRELOGIN framing for subsequent reads and writes.
    pub(crate) fn start_handshake(&mut self) {
        self.handshake = true;
    }

    /// Disables PRELOGIN framing. Any handshake bytes that are still buffered are sent, framed,
    /// by the next `poll_write`, `poll_flush` or `poll_shutdown` before pass-through data.
    pub(crate) fn finish_handshake(&mut self) {
        self.handshake = false;
    }
}

impl<S> std::fmt::Debug for TlsPreloginStream<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsPreloginStream")
            .field("handshake", &self.handshake)
            .field("read_remaining", &self.read_remaining)
            .field("write_buf_len", &self.write_buf.len())
            .field(
                "frame_pending",
                &self.frame_buf.len().saturating_sub(self.frame_pos),
            )
            .finish_non_exhaustive()
    }
}

impl<S: AsyncRead + Unpin> TlsPreloginStream<S> {
    /// Reads the next TDS packet header from `stream` and checks that it is a PRELOGIN packet.
    ///
    /// A partially read header survives `Poll::Pending` through `header_buf`/`header_pos`. On
    /// success, `read_remaining` holds the packet's payload length (possibly zero) and
    /// `header_pos` is reset for the next header.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if the peer closes the connection before a full header arrives.
    /// `InvalidData` if the header fails to decode, is not a PRELOGIN packet, or declares a
    /// length shorter than the header itself.
    fn poll_prelogin_header(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.header_pos < PACKET_HEADER_LEN {
            let start = self.header_pos;
            let mut header_read = ReadBuf::new(&mut self.header_buf[start..]);
            ready!(Pin::new(&mut self.stream).poll_read(cx, &mut header_read))?;

            let read = header_read.filled().len();
            if read == 0 {
                let message = if start == 0 {
                    "SQL Server closed the connection before sending a TDS PRELOGIN packet during TLS handshake"
                } else {
                    "SQL Server closed the connection in the middle of a TDS PRELOGIN packet header during TLS handshake"
                };
                return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, message)));
            }
            self.header_pos += read;
        }

        let header = PacketHeader::decode(&self.header_buf).map_err(packet_header_io_error)?;
        if header.packet_type != PacketType::PRE_LOGIN {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "expected TLS handshake bytes in PRELOGIN packet, got packet type 0x{:02x}",
                    header.packet_type.code()
                ),
            )));
        }

        self.read_remaining = usize::from(header.length)
            .checked_sub(PACKET_HEADER_LEN)
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidData, "invalid TDS packet length")
            })?;
        self.header_pos = 0;

        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsPreloginStream<S> {
    /// During the handshake, yields only the payload bytes of incoming PRELOGIN packets and never
    /// reads past the current packet's end. Outside the handshake, delegates to the inner stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if !this.handshake {
            return Pin::new(&mut this.stream).poll_read(cx, buf);
        }

        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        // A header-only packet carries no TLS bytes. Keep reading headers so an empty packet is
        // never reported to the TLS layer as end-of-stream (an empty successful read).
        while this.read_remaining == 0 {
            ready!(this.poll_prelogin_header(cx))?;
        }

        let max_read = this.read_remaining.min(buf.remaining());
        let read = {
            // Initialize the destination inside `buf` first. This keeps the `advance` below
            // valid even when the caller supplied an uninitialized buffer.
            let dst = buf.initialize_unfilled_to(max_read);
            let mut limited_buf = ReadBuf::new(dst);
            ready!(Pin::new(&mut this.stream).poll_read(cx, &mut limited_buf))?;
            limited_buf.filled().len()
        };

        // `read_remaining > 0` and `max_read > 0` here, so zero bytes can only mean EOF.
        if read == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "SQL Server closed the connection in the middle of a TDS PRELOGIN TLS payload",
            )));
        }

        buf.advance(read);
        this.read_remaining -= read;

        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncWrite + Unpin> TlsPreloginStream<S> {
    /// Writes all pending handshake output to `stream`.
    ///
    /// Framed packets already in `frame_buf` go out first. Then any raw bytes in `write_buf` are
    /// framed as one PRELOGIN message and written. Progress is tracked in `frame_pos`, so a
    /// `Poll::Pending` from the inner writer resumes where it stopped without losing data.
    ///
    /// # Errors
    ///
    /// `WriteZero` if the inner writer accepts no bytes. `InvalidData` if framing fails.
    fn poll_write_handshake_packets(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        loop {
            while self.frame_pos < self.frame_buf.len() {
                let written = ready!(
                    Pin::new(&mut self.stream).poll_write(cx, &self.frame_buf[self.frame_pos..])
                )?;
                if written == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write TLS handshake packet",
                    )));
                }
                self.frame_pos += written;
            }
            self.frame_buf.clear();
            self.frame_pos = 0;

            if self.write_buf.is_empty() {
                return Poll::Ready(Ok(()));
            }

            self.frame_buf = wrap_prelogin_tls_payload(&self.write_buf, PRELOGIN_TLS_PACKET_SIZE)
                .map_err(packet_frame_error)?;
            self.write_buf.clear();
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsPreloginStream<S> {
    /// During the handshake, buffers `buf` until the next flush and reports it fully written.
    /// Otherwise, first sends any handshake output still pending (to preserve byte order), then
    /// writes `buf` straight to the inner stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if this.handshake {
            this.write_buf.extend_from_slice(buf);
            return Poll::Ready(Ok(buf.len()));
        }

        ready!(this.poll_write_handshake_packets(cx))?;
        Pin::new(&mut this.stream).poll_write(cx, buf)
    }

    /// Frames and writes buffered handshake bytes, then flushes the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        ready!(this.poll_write_handshake_packets(cx))?;
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Sends any buffered handshake bytes, then shuts down the inner stream. Shutdown implies a
    /// flush, so buffered data must not be dropped here.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        ready!(this.poll_write_handshake_packets(cx))?;
        Pin::new(&mut this.stream).poll_shutdown(cx)
    }
}

/// Frames `payload` (raw TLS handshake bytes) as a TDS PRELOGIN message.
///
/// Packetization is delegated to `encode_message` with the given `packet_size`; its framing
/// error is returned unchanged.
fn wrap_prelogin_tls_payload(
    payload: &[u8],
    packet_size: usize,
) -> Result<Vec<u8>, PacketFrameError> {
    let framed = encode_message(PacketType::PRE_LOGIN, payload, packet_size)?;
    Ok(framed)
}

/// Converts a PRELOGIN header decode failure into an `InvalidData` I/O error that wraps the
/// original error as its source.
fn packet_header_io_error(error: PacketHeaderError) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, error)
}

/// Converts a failure to frame outgoing TLS handshake bytes into an `InvalidData` I/O error whose
/// message embeds the framing error's `Display` text.
fn packet_frame_error(error: PacketFrameError) -> io::Error {
    let message =
        format!("failed to wrap TLS handshake bytes in a TDS PRELOGIN packet: {error}");
    io::Error::new(io::ErrorKind::InvalidData, message)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A payload that fits in one packet is framed as a single PRELOGIN packet: the header's
    /// length covers header plus payload, and the bytes after the header are the payload itself.
    #[test]
    fn wraps_tls_handshake_bytes_as_prelogin_packets() {
        let payload: &[u8] = b"hello";
        let packet = wrap_prelogin_tls_payload(payload, 512).unwrap();
        let (head, body) = packet.split_at(PACKET_HEADER_LEN);
        let header = PacketHeader::decode(head).unwrap();

        assert_eq!(PacketType::PRE_LOGIN, header.packet_type);
        assert_eq!(PACKET_HEADER_LEN + payload.len(), usize::from(header.length));
        assert_eq!(payload, body);
    }
}