use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use futures_util::{SinkExt, StreamExt};
use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::{Mutex, Notify};
use crate::error::CodecError;
use crate::framed::{PacketReader, PacketWriter};
use crate::message::{Message, MessageAssembler};
use crate::packet_codec::{Packet, TdsCodec};
/// A TDS connection split into a framed packet reader and a shared, lockable
/// packet writer, with support for out-of-band query cancellation.
///
/// The writer and the cancellation state are reference-counted so that a
/// [`CancelHandle`] obtained from [`Connection::cancel_handle`] can send an
/// Attention packet while this connection is blocked in a read.
pub struct Connection<T: AsyncRead + AsyncWrite> {
    /// Incoming packet stream over the read half of the transport.
    reader: PacketReader<ReadHalf<T>>,
    /// Outgoing packet sink; the same `Arc` is handed to every `CancelHandle`.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Accumulates packets until a complete message is available.
    assembler: MessageAssembler,
    /// Notified when a cancellation ends; awaited by `CancelHandle::wait_cancelled`.
    cancel_notify: Arc<Notify>,
    /// Raised by `CancelHandle::cancel`, lowered when the cancellation ends.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
impl<T> Connection<T>
where
T: AsyncRead + AsyncWrite,
{
pub fn new(transport: T) -> Self {
let (read_half, write_half) = tokio::io::split(transport);
Self {
reader: PacketReader::new(read_half),
writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
assembler: MessageAssembler::new(),
cancel_notify: Arc::new(Notify::new()),
cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
let (read_half, write_half) = tokio::io::split(transport);
Self {
reader: PacketReader::with_codec(read_half, read_codec),
writer: Arc::new(Mutex::new(PacketWriter::with_codec(
write_half,
write_codec,
))),
assembler: MessageAssembler::new(),
cancel_notify: Arc::new(Notify::new()),
cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
#[must_use]
pub fn cancel_handle(&self) -> CancelHandle<T> {
CancelHandle {
writer: Arc::clone(&self.writer),
notify: Arc::clone(&self.cancel_notify),
cancelling: Arc::clone(&self.cancelling),
}
}
#[must_use]
pub fn is_cancelling(&self) -> bool {
self.cancelling.load(std::sync::atomic::Ordering::Acquire)
}
pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
loop {
if self.is_cancelling() {
return self.drain_after_cancel().await;
}
match self.reader.next().await {
Some(Ok(packet)) => {
if let Some(message) = self.assembler.push(packet) {
if self.is_cancelling() {
if Self::payload_ends_with_attention_done(&message.payload) {
tracing::debug!(
"received DONE with ATTENTION, cancellation complete"
);
self.finish_cancel();
return Err(CodecError::Cancelled);
}
tracing::debug!("discarding message from cancelled request");
continue;
}
return Ok(Some(message));
}
}
Some(Err(e)) => return Err(e),
None => {
if self.assembler.has_partial() {
return Err(CodecError::ConnectionClosed);
}
return Ok(None);
}
}
}
}
pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
match self.reader.next().await {
Some(result) => result.map(Some),
None => Ok(None),
}
}
pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
let mut writer = self.writer.lock().await;
writer.send(packet).await
}
pub async fn send_message(
&mut self,
packet_type: PacketType,
payload: Bytes,
max_packet_size: usize,
) -> Result<(), CodecError> {
self.send_message_with_reset(packet_type, payload, max_packet_size, false)
.await
}
pub async fn send_message_with_reset(
&mut self,
packet_type: PacketType,
payload: Bytes,
max_packet_size: usize,
reset_connection: bool,
) -> Result<(), CodecError> {
let max_payload = max_packet_size - PACKET_HEADER_SIZE;
let chunks: Vec<_> = payload.chunks(max_payload).collect();
let total_chunks = chunks.len();
let mut writer = self.writer.lock().await;
for (i, chunk) in chunks.into_iter().enumerate() {
let is_first = i == 0;
let is_last = i == total_chunks - 1;
let mut status = if is_last {
PacketStatus::END_OF_MESSAGE
} else {
PacketStatus::NORMAL
};
if is_first && reset_connection {
status |= PacketStatus::RESET_CONNECTION;
}
let header = PacketHeader::new(packet_type, status, 0);
let packet = Packet::new(header, BytesMut::from(chunk));
writer.send(packet).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<(), CodecError> {
let mut writer = self.writer.lock().await;
writer.flush().await
}
async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
tracing::debug!("draining packets after cancellation");
self.assembler.clear();
loop {
match self.reader.next().await {
Some(Ok(packet)) => {
if let Some(message) = self.assembler.push(packet) {
if message.packet_type == PacketType::TabularResult
&& Self::payload_ends_with_attention_done(&message.payload)
{
tracing::debug!("received DONE with ATTENTION, cancellation complete");
self.finish_cancel();
return Err(CodecError::Cancelled);
}
tracing::debug!("discarding message from cancelled request");
}
}
Some(Err(e)) => {
self.cancelling
.store(false, std::sync::atomic::Ordering::Release);
return Err(e);
}
None => {
self.cancelling
.store(false, std::sync::atomic::Ordering::Release);
return Err(CodecError::ConnectionClosed);
}
}
}
}
fn finish_cancel(&self) {
self.cancelling
.store(false, std::sync::atomic::Ordering::Release);
self.cancel_notify.notify_waiters();
}
fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
let Some(start) = payload.len().checked_sub(13) else {
return false;
};
payload[start] == 0xFD
&& u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
}
pub fn read_codec(&self) -> &TdsCodec {
self.reader.codec()
}
pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
self.reader.codec_mut()
}
}
impl<T> std::fmt::Debug for Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Shows the cancellation flag and whether a message is half-assembled.
    ///
    /// The transport itself is never printed, so `T` is not required to implement
    /// `Debug`. This keeps the impl available for transports such as TLS streams
    /// that do not implement it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Connection")
            .field("cancelling", &self.is_cancelling())
            .field("has_partial_message", &self.assembler.has_partial())
            .finish_non_exhaustive()
    }
}
/// Cloneable handle for cancelling the request currently running on a
/// [`Connection`] from another task.
///
/// It shares the connection's packet writer (to send the Attention packet) and its
/// cancellation flag and completion notifier.
pub struct CancelHandle<T: AsyncRead + AsyncWrite> {
    /// The same writer the owning `Connection` sends through.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Woken by the owning connection when a cancellation ends.
    notify: Arc<Notify>,
    /// Shared with the owning connection; raised by `cancel`.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
impl<T> CancelHandle<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
pub async fn cancel(&self) -> Result<(), CodecError> {
self.cancelling
.store(true, std::sync::atomic::Ordering::Release);
tracing::debug!("sending Attention packet for query cancellation");
let mut writer = self.writer.lock().await;
let header = PacketHeader::new(
PacketType::Attention,
PacketStatus::END_OF_MESSAGE,
PACKET_HEADER_SIZE as u16,
);
let packet = Packet::new(header, BytesMut::new());
writer.send(packet).await?;
writer.flush().await?;
Ok(())
}
pub async fn wait_cancelled(&self) {
if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
self.notify.notified().await;
}
}
#[must_use]
pub fn is_cancelling(&self) -> bool {
self.cancelling.load(std::sync::atomic::Ordering::Acquire)
}
}
// Implemented by hand: `#[derive(Clone)]` would demand `T: Clone`, although only
// the `Arc`s are duplicated.
impl<T> Clone for CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Returns another handle to the same writer, notifier and cancellation flag.
    fn clone(&self) -> Self {
        let Self {
            writer,
            notify,
            cancelling,
        } = self;
        Self {
            writer: Arc::clone(writer),
            notify: Arc::clone(notify),
            cancelling: Arc::clone(cancelling),
        }
    }
}
impl<T> std::fmt::Debug for CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Shows only the cancellation flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Load the flag directly instead of calling `is_cancelling`. That method sits
        // in the `T: Unpin` impl, and calling it would force the same bound here.
        let cancelling = self.cancelling.load(std::sync::atomic::Ordering::Acquire);
        f.debug_struct("CancelHandle")
            .field("cancelling", &cancelling)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// The Attention header as built by `CancelHandle::cancel` keeps its type,
    /// END_OF_MESSAGE status and header-only length.
    #[test]
    fn test_attention_packet_header() {
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );
        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
    }

    /// `payload_ends_with_attention_done` accepts a trailing DONE with status 0x0020,
    /// and rejects both a trailing DONE without it and an `FD 20` byte pair that
    /// appears only in the interior of the payload.
    #[test]
    fn test_check_attention_done() {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);
        // 13-byte DONE token (0xFD) with status 0x0020 (little-endian).
        let payload_with_attn = BytesMut::from(
            &[
                0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_with_attn = Packet::new(header, payload_with_attn);
        // Same token with a zero status.
        let payload_no_attn = BytesMut::from(
            &[
                0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_no_attn = Packet::new(header, payload_no_attn);
        assert!(
            Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_with_attn.payload
            )
        );
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_no_attn.payload
            )
        );
        // `FD 20` appears inside the payload, but the trailing DONE has no ATTN bit.
        let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
        interior.extend_from_slice(&[
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
        );
    }

    /// Frames `payload` as one single-packet server message.
    ///
    /// Header layout: type 0x04 (tabular result), status 0x01 (end of message), a
    /// big-endian length counting the 8-byte header, then four bytes `0, 0, 1, 0`
    /// (presumably SPID 0, packet id 1, window 0 per the TDS header layout).
    fn raw_message(payload: &[u8]) -> Vec<u8> {
        let mut v = vec![0x04, 0x01];
        v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
        v.extend_from_slice(&[0, 0, 1, 0]);
        v.extend_from_slice(payload);
        v
    }

    /// Builds a 13-byte DONE token with the given status word (little-endian).
    /// The remaining bytes are fixed: `C1 00` (presumably the current-command field)
    /// followed by an all-zero 8-byte row count.
    fn done_token(status: u16) -> [u8; 13] {
        let s = status.to_le_bytes();
        [
            0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]
    }

    /// A read already parked in `reader.next()` when the cancel is raised must
    /// discard the cancelled output, stop exactly at the ATTN ack, and leave the
    /// following response for the next read.
    #[tokio::test]
    async fn test_cancel_mid_read_discards_cancelled_stream() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;
        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();
        let mut read_fut = Box::pin(conn.read_message());
        // Poll once by hand (no-op waker) so the read parks before the cancel.
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");
        // Server side: tail of the cancelled request, then the ATTN ack (0x0020),
        // then the next request's response (0x0010).
        server_io
            .write_all(&raw_message(&done_token(0x0002)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();
        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "parked read must consume the cancelled stream and report \
             Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling(), "cancel flag must be cleared");
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(message.payload[0], 0xFD);
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010,
            "next response must not be eaten by a stale drain"
        );
    }

    /// A cancel issued before `read_message` is called goes through the drain path.
    /// The ack here also carries other status bits (0x0022); only the ATTN bit matters.
    #[tokio::test]
    async fn test_cancel_before_read_drains_to_attention_ack() {
        use tokio::io::AsyncWriteExt;
        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();
        cancel.cancel().await.expect("send attention");
        server_io
            .write_all(&raw_message(&done_token(0x0022)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();
        let result = conn.read_message().await;
        assert!(matches!(result, Err(CodecError::Cancelled)));
        assert!(!conn.is_cancelling());
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010
        );
    }

    /// Row bytes containing `FD 20` inside a cancelled message must not be taken as
    /// the ack. Only a DONE at the very end of a message counts.
    #[tokio::test]
    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;
        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();
        let mut read_fut = Box::pin(conn.read_message());
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");
        // Message whose interior holds `FD 20` but whose trailing DONE is 0x0001.
        let mut row_data = vec![0xD1, 0x08];
        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
        row_data.extend_from_slice(&done_token(0x0001));
        server_io.write_all(&raw_message(&row_data)).await.unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();
        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "cancelled read must end in Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling());
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
        assert_eq!(
            status, 0x0010,
            "next request's response must come through intact; 0x0020 means \
             the interior row bytes were mistaken for the ack and the real \
             ack leaked into the next request"
        );
    }
}