use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use futures_util::{SinkExt, StreamExt};
use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::{Mutex, Notify};
use crate::error::CodecError;
use crate::framed::{PacketReader, PacketWriter};
use crate::message::{Message, MessageAssembler};
use crate::packet_codec::{Packet, TdsCodec};
/// A packet/message-level connection over a split async transport.
///
/// The transport is split into a read half, owned exclusively by this
/// struct, and a write half that sits behind a shared async mutex so that
/// a [`CancelHandle`] can also write to it.
pub struct Connection<T>
where
T: AsyncRead + AsyncWrite,
{
/// Packet stream over the read half of the transport.
reader: PacketReader<ReadHalf<T>>,
/// Packet sink over the write half; shared with every `CancelHandle`.
writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
/// Accumulates incoming packets until a complete message is available.
assembler: MessageAssembler,
/// Signalled by the drain loop once a cancellation has been processed.
cancel_notify: Arc<Notify>,
/// Set by `CancelHandle::cancel`; while set, `read_message` drains instead of assembling.
cancelling: Arc<std::sync::atomic::AtomicBool>,
}
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Creates a connection over `transport`.
    ///
    /// The transport is split into independent read and write halves which
    /// are wrapped with `PacketReader::new` and `PacketWriter::new`.
    pub fn new(transport: T) -> Self {
        let (read_half, write_half) = tokio::io::split(transport);
        Self {
            reader: PacketReader::new(read_half),
            writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
            assembler: MessageAssembler::new(),
            cancel_notify: Arc::new(Notify::new()),
            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Creates a connection over `transport` using caller-supplied codecs
    /// for the read and write directions.
    pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
        let (read_half, write_half) = tokio::io::split(transport);
        Self {
            reader: PacketReader::with_codec(read_half, read_codec),
            writer: Arc::new(Mutex::new(PacketWriter::with_codec(
                write_half,
                write_codec,
            ))),
            assembler: MessageAssembler::new(),
            cancel_notify: Arc::new(Notify::new()),
            cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Returns a handle that shares this connection's writer, notifier and
    /// cancellation flag, and can therefore request cancellation while the
    /// connection itself is busy reading.
    #[must_use]
    pub fn cancel_handle(&self) -> CancelHandle<T> {
        CancelHandle {
            writer: Arc::clone(&self.writer),
            notify: Arc::clone(&self.cancel_notify),
            cancelling: Arc::clone(&self.cancelling),
        }
    }

    /// Returns `true` while a cancellation has been requested and its
    /// acknowledgement has not yet been drained.
    #[must_use]
    pub fn is_cancelling(&self) -> bool {
        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Reads packets until the assembler produces a complete message.
    ///
    /// If a cancellation is in progress (checked before every packet read),
    /// the call switches to draining and returns the drain result instead.
    ///
    /// # Returns
    /// * `Ok(Some(message))` - a complete message was assembled.
    /// * `Ok(None)` - the stream ended with no partial message pending, or a
    ///   cancellation drain finished.
    ///
    /// # Errors
    /// * [`CodecError::ConnectionClosed`] if the stream ended while a message
    ///   was only partially assembled.
    /// * Any error yielded by the packet reader.
    pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
        loop {
            if self.is_cancelling() {
                return self.drain_after_cancel().await;
            }
            match self.reader.next().await {
                Some(Ok(packet)) => {
                    if let Some(message) = self.assembler.push(packet) {
                        return Ok(Some(message));
                    }
                }
                Some(Err(e)) => return Err(e),
                None => {
                    return if self.assembler.has_partial() {
                        Err(CodecError::ConnectionClosed)
                    } else {
                        Ok(None)
                    };
                }
            }
        }
    }

    /// Reads a single packet, bypassing message assembly.
    ///
    /// Returns `Ok(None)` when the underlying stream has ended.
    pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
        self.reader.next().await.transpose()
    }

    /// Sends one packet through the shared writer.
    pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
        let mut writer = self.writer.lock().await;
        writer.send(packet).await
    }

    /// Sends `payload` as a message of type `packet_type`, split into packets
    /// of at most `max_packet_size` bytes (header included).
    ///
    /// Equivalent to [`Self::send_message_with_reset`] with
    /// `reset_connection = false`.
    pub async fn send_message(
        &mut self,
        packet_type: PacketType,
        payload: Bytes,
        max_packet_size: usize,
    ) -> Result<(), CodecError> {
        self.send_message_with_reset(packet_type, payload, max_packet_size, false)
            .await
    }

    /// Sends `payload` as a message, optionally flagging the first packet
    /// with `RESET_CONNECTION`.
    ///
    /// The payload is cut into chunks of `max_packet_size - PACKET_HEADER_SIZE`
    /// bytes. Every packet but the last carries `NORMAL` status; the last one
    /// carries `END_OF_MESSAGE`. The writer lock is held for the whole
    /// message so packets of different messages are not interleaved.
    ///
    /// Edge cases:
    /// * An empty `payload` still produces exactly one header-only packet
    ///   marked `END_OF_MESSAGE`, so the message is never silently dropped.
    /// * If `max_packet_size` does not leave room for at least one payload
    ///   byte, the per-packet payload is clamped to one byte rather than
    ///   underflowing or panicking.
    ///
    /// # Errors
    /// Returns the first error reported by the packet writer.
    pub async fn send_message_with_reset(
        &mut self,
        packet_type: PacketType,
        payload: Bytes,
        max_packet_size: usize,
        reset_connection: bool,
    ) -> Result<(), CodecError> {
        // `saturating_sub` avoids integer underflow; `.max(1)` avoids the
        // panic `chunks(0)` would raise.
        let max_payload = max_packet_size.saturating_sub(PACKET_HEADER_SIZE).max(1);

        let mut writer = self.writer.lock().await;
        let mut chunks = payload.chunks(max_payload).peekable();
        let mut is_first = true;

        loop {
            // An empty payload yields no chunks at all; fall back to a single
            // empty chunk so that one terminating packet is still sent.
            let chunk: &[u8] = chunks.next().unwrap_or_default();
            let is_last = chunks.peek().is_none();

            let mut status = if is_last {
                PacketStatus::END_OF_MESSAGE
            } else {
                PacketStatus::NORMAL
            };
            if is_first && reset_connection {
                status |= PacketStatus::RESET_CONNECTION;
            }

            let header = PacketHeader::new(packet_type, status, 0);
            writer.send(Packet::new(header, BytesMut::from(chunk))).await?;

            if is_last {
                return Ok(());
            }
            is_first = false;
        }
    }

    /// Flushes the shared writer.
    pub async fn flush(&mut self) -> Result<(), CodecError> {
        let mut writer = self.writer.lock().await;
        writer.flush().await
    }

    /// Discards incoming packets until the cancellation is acknowledged.
    ///
    /// Any partially assembled message is dropped first. Packets are then
    /// read and thrown away until a non-empty `TabularResult` packet passes
    /// [`Self::check_attention_done`], the reader fails, or the stream ends.
    ///
    /// On every exit path the `cancelling` flag is cleared and waiters on
    /// the cancel notifier are woken, so `CancelHandle::wait_cancelled`
    /// cannot be left waiting after an error or a closed connection.
    async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
        tracing::debug!("draining packets after cancellation");
        self.assembler.clear();

        let outcome = loop {
            match self.reader.next().await {
                Some(Ok(packet)) => {
                    let is_candidate = packet.header.packet_type == PacketType::TabularResult
                        && !packet.payload.is_empty();
                    if is_candidate && self.check_attention_done(&packet) {
                        tracing::debug!("received DONE with ATTENTION, cancellation complete");
                        break Ok(None);
                    }
                }
                Some(Err(e)) => break Err(e),
                None => break Ok(None),
            }
        };

        // Clear the flag before notifying so that woken waiters observe the
        // connection as no longer cancelling.
        self.cancelling
            .store(false, std::sync::atomic::Ordering::Release);
        self.cancel_notify.notify_waiters();
        outcome
    }

    /// Scans a packet payload for a `0xFD` byte followed by a little-endian
    /// 16-bit status word with bit `0x0020` set.
    ///
    /// This is a raw byte scan, not a token parser: any three-byte window
    /// matching that shape counts, and a match split across two packets is
    /// not detected.
    fn check_attention_done(&self, packet: &Packet) -> bool {
        // Byte value the scan treats as the start of a DONE token.
        const DONE_TOKEN: u8 = 0xFD;
        // Status bit the scan treats as the attention acknowledgement.
        const DONE_ATTN: u16 = 0x0020;

        packet.payload.windows(3).any(|window| {
            window[0] == DONE_TOKEN
                && u16::from_le_bytes([window[1], window[2]]) & DONE_ATTN != 0
        })
    }

    /// Returns a shared reference to the codec used by the reader.
    pub fn read_codec(&self) -> &TdsCodec {
        self.reader.codec()
    }

    /// Returns a mutable reference to the codec used by the reader.
    pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
        self.reader.codec_mut()
    }
}
/// Debug output shows only the cancellation flag and whether a message is
/// partially assembled; the transport itself is never printed, so no
/// `Debug` bound on `T` is required.
impl<T> std::fmt::Debug for Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelling = self.is_cancelling();
        let has_partial_message = self.assembler.has_partial();
        f.debug_struct("Connection")
            .field("cancelling", &cancelling)
            .field("has_partial_message", &has_partial_message)
            .finish_non_exhaustive()
    }
}
/// A cloneable handle for requesting cancellation on a [`Connection`].
///
/// It holds clones of the connection's shared writer, notifier and
/// cancellation flag, so it can be used while the connection is reading.
pub struct CancelHandle<T>
where
T: AsyncRead + AsyncWrite,
{
/// Packet sink shared with the owning connection.
writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
/// Awaited by `wait_cancelled`; signalled by the connection's drain loop.
notify: Arc<Notify>,
/// Set to `true` by `cancel`; cleared by the connection once draining ends.
cancelling: Arc<std::sync::atomic::AtomicBool>,
}
impl<T> CancelHandle<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Requests cancellation of the in-flight request.
    ///
    /// Sets the shared `cancelling` flag, then sends a header-only
    /// `Attention` packet marked `END_OF_MESSAGE` through the shared writer
    /// and flushes it. The flag is set before the writer lock is acquired,
    /// so the reading side can start draining even while the send waits.
    ///
    /// # Errors
    /// Returns any error from sending or flushing the packet. The
    /// `cancelling` flag is left set in that case.
    pub async fn cancel(&self) -> Result<(), CodecError> {
        self.cancelling
            .store(true, std::sync::atomic::Ordering::Release);
        tracing::debug!("sending Attention packet for query cancellation");

        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );
        let attention = Packet::new(header, BytesMut::new());

        let mut writer = self.writer.lock().await;
        writer.send(attention).await?;
        writer.flush().await?;
        Ok(())
    }

    /// Waits until an in-progress cancellation has been processed.
    ///
    /// Returns immediately when no cancellation is in progress.
    pub async fn wait_cancelled(&self) {
        // Create the `Notified` future BEFORE inspecting the flag: it
        // receives `notify_waiters()` wakeups from the moment it exists, so
        // a notification landing between the check and the await is not
        // lost.
        let notified = self.notify.notified();
        if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
            notified.await;
        }
    }

    /// Returns `true` while a cancellation is requested and not yet drained.
    #[must_use]
    pub fn is_cancelling(&self) -> bool {
        self.cancelling.load(std::sync::atomic::Ordering::Acquire)
    }
}
/// Cloning a handle only bumps the reference counts of the shared state;
/// every clone targets the same writer, notifier and cancellation flag.
impl<T> Clone for CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    fn clone(&self) -> Self {
        let Self {
            writer,
            notify,
            cancelling,
        } = self;
        Self {
            writer: Arc::clone(writer),
            notify: Arc::clone(notify),
            cancelling: Arc::clone(cancelling),
        }
    }
}
/// Debug output shows only the cancellation flag.
///
/// The flag is read directly from the atomic, so this impl needs no bound
/// beyond the one on the struct itself (in particular, no `Unpin`).
impl<T> std::fmt::Debug for CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelling = self.cancelling.load(std::sync::atomic::Ordering::Acquire);
        f.debug_struct("CancelHandle")
            .field("cancelling", &cancelling)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A header built for an Attention packet keeps the type, status and
    /// length it was constructed with.
    #[test]
    fn test_attention_packet_header() {
        let expected_len = PACKET_HEADER_SIZE as u16;
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            expected_len,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, expected_len);
    }

    /// The byte scan reports a match only when the status word following
    /// `0xFD` has bit `0x0020` set.
    #[test]
    fn test_check_attention_done() {
        // Local copy of the scan used by the connection's drain loop.
        fn scan(packet: &Packet) -> bool {
            let bytes = &packet.payload;
            (0..bytes.len()).any(|at| {
                bytes[at] == 0xFD
                    && at + 3 <= bytes.len()
                    && u16::from_le_bytes([bytes[at + 1], bytes[at + 2]]) & 0x0020 != 0
            })
        }

        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);

        let raw_with_attn: [u8; 13] = [
            0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let raw_no_attn: [u8; 13] = [
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        let packet_with_attn = Packet::new(header, BytesMut::from(&raw_with_attn[..]));
        let packet_no_attn = Packet::new(header, BytesMut::from(&raw_no_attn[..]));

        assert!(scan(&packet_with_attn));
        assert!(!scan(&packet_no_attn));
    }
}