1use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
/// A TDS packet connection over a transport that has been split into read and write halves.
///
/// Incoming packets are read through `reader` and assembled into whole messages by
/// `assembler`. Outgoing packets go through `writer`, which sits behind an async mutex in an
/// `Arc` so that a [`CancelHandle`] (see [`Connection::cancel_handle`]) can write an Attention
/// packet while another task is parked reading from this connection.
pub struct Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    // Packet-level reader over the read half of the transport.
    reader: PacketReader<ReadHalf<T>>,
    // Packet-level writer over the write half; shared with every `CancelHandle`, and locked for
    // the whole of a multi-packet message so an Attention cannot interleave with its packets.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    // Accumulates packets until a complete message is available.
    assembler: MessageAssembler,
    // Woken (via `notify_waiters`) when a pending cancellation is resolved.
    cancel_notify: Arc<Notify>,
    // Raised by `CancelHandle::cancel`, lowered once the cancellation is resolved.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
61
62impl<T> Connection<T>
63where
64 T: AsyncRead + AsyncWrite,
65{
66 pub fn new(transport: T) -> Self {
70 let (read_half, write_half) = tokio::io::split(transport);
71
72 Self {
73 reader: PacketReader::new(read_half),
74 writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75 assembler: MessageAssembler::new(),
76 cancel_notify: Arc::new(Notify::new()),
77 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78 }
79 }
80
81 pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83 let (read_half, write_half) = tokio::io::split(transport);
84
85 Self {
86 reader: PacketReader::with_codec(read_half, read_codec),
87 writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88 write_half,
89 write_codec,
90 ))),
91 assembler: MessageAssembler::new(),
92 cancel_notify: Arc::new(Notify::new()),
93 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94 }
95 }
96
97 #[must_use]
101 pub fn cancel_handle(&self) -> CancelHandle<T> {
102 CancelHandle {
103 writer: Arc::clone(&self.writer),
104 notify: Arc::clone(&self.cancel_notify),
105 cancelling: Arc::clone(&self.cancelling),
106 }
107 }
108
109 #[must_use]
111 pub fn is_cancelling(&self) -> bool {
112 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113 }
114
115 pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
123 loop {
124 if self.is_cancelling() {
126 return self.drain_after_cancel().await;
128 }
129
130 match self.reader.next().await {
131 Some(Ok(packet)) => {
132 if let Some(message) = self.assembler.push(packet) {
133 if self.is_cancelling() {
141 if Self::payload_ends_with_attention_done(&message.payload) {
142 tracing::debug!(
143 "received DONE with ATTENTION, cancellation complete"
144 );
145 self.finish_cancel();
146 return Err(CodecError::Cancelled);
147 }
148 tracing::debug!("discarding message from cancelled request");
149 continue;
150 }
151 return Ok(Some(message));
152 }
153 }
155 Some(Err(e)) => return Err(e),
156 None => {
157 if self.assembler.has_partial() {
159 return Err(CodecError::ConnectionClosed);
160 }
161 return Ok(None);
162 }
163 }
164 }
165 }
166
167 pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
171 match self.reader.next().await {
172 Some(result) => result.map(Some),
173 None => Ok(None),
174 }
175 }
176
177 pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
179 let mut writer = self.writer.lock().await;
180 writer.send(packet).await
181 }
182
183 pub async fn send_message(
190 &mut self,
191 packet_type: PacketType,
192 payload: Bytes,
193 max_packet_size: usize,
194 ) -> Result<(), CodecError> {
195 self.send_message_with_reset(packet_type, payload, max_packet_size, false)
196 .await
197 }
198
199 pub async fn send_message_with_reset(
206 &mut self,
207 packet_type: PacketType,
208 payload: Bytes,
209 max_packet_size: usize,
210 reset_connection: bool,
211 ) -> Result<(), CodecError> {
212 let max_payload = max_packet_size - PACKET_HEADER_SIZE;
213 let chunks: Vec<_> = payload.chunks(max_payload).collect();
214 let total_chunks = chunks.len();
215
216 let mut writer = self.writer.lock().await;
217
218 for (i, chunk) in chunks.into_iter().enumerate() {
219 let is_first = i == 0;
220 let is_last = i == total_chunks - 1;
221
222 let mut status = if is_last {
224 PacketStatus::END_OF_MESSAGE
225 } else {
226 PacketStatus::NORMAL
227 };
228
229 if is_first && reset_connection {
231 status |= PacketStatus::RESET_CONNECTION;
232 }
233
234 let header = PacketHeader::new(packet_type, status, 0);
235 let packet = Packet::new(header, BytesMut::from(chunk));
236
237 writer.send(packet).await?;
238 }
239
240 Ok(())
241 }
242
243 pub async fn flush(&mut self) -> Result<(), CodecError> {
245 let mut writer = self.writer.lock().await;
246 writer.flush().await
247 }
248
249 async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
254 tracing::debug!("draining packets after cancellation");
255
256 self.assembler.clear();
258
259 loop {
260 match self.reader.next().await {
261 Some(Ok(packet)) => {
262 if let Some(message) = self.assembler.push(packet) {
266 if message.packet_type == PacketType::TabularResult
267 && Self::payload_ends_with_attention_done(&message.payload)
268 {
269 tracing::debug!("received DONE with ATTENTION, cancellation complete");
270 self.finish_cancel();
271 return Err(CodecError::Cancelled);
272 }
273 tracing::debug!("discarding message from cancelled request");
274 }
275 }
277 Some(Err(e)) => {
278 self.cancelling
279 .store(false, std::sync::atomic::Ordering::Release);
280 return Err(e);
281 }
282 None => {
283 self.cancelling
286 .store(false, std::sync::atomic::Ordering::Release);
287 return Err(CodecError::ConnectionClosed);
288 }
289 }
290 }
291 }
292
293 fn finish_cancel(&self) {
295 self.cancelling
296 .store(false, std::sync::atomic::Ordering::Release);
297 self.cancel_notify.notify_waiters();
298 }
299
300 fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
313 let Some(start) = payload.len().checked_sub(13) else {
314 return false;
315 };
316 payload[start] == 0xFD
318 && u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
319 }
320
321 pub fn read_codec(&self) -> &TdsCodec {
323 self.reader.codec()
324 }
325
326 pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
328 self.reader.codec_mut()
329 }
330}
331
332impl<T> std::fmt::Debug for Connection<T>
333where
334 T: AsyncRead + AsyncWrite + std::fmt::Debug,
335{
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_struct("Connection")
338 .field("cancelling", &self.is_cancelling())
339 .field("has_partial_message", &self.assembler.has_partial())
340 .finish_non_exhaustive()
341 }
342}
343
/// Cloneable handle for cancelling the request in flight on a [`Connection`].
///
/// Obtained from [`Connection::cancel_handle`]. It shares the connection's writer, its
/// completion notifier and its cancelling flag, so it can be moved to another task and
/// used while the connection is busy reading.
pub struct CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    // The connection's shared packet writer, used to send the Attention packet.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    // The connection's `cancel_notify`; woken when the cancellation is resolved.
    notify: Arc<Notify>,
    // The connection's `cancelling` flag.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
356
357impl<T> CancelHandle<T>
358where
359 T: AsyncRead + AsyncWrite + Unpin,
360{
361 pub async fn cancel(&self) -> Result<(), CodecError> {
366 self.cancelling
368 .store(true, std::sync::atomic::Ordering::Release);
369
370 tracing::debug!("sending Attention packet for query cancellation");
371
372 let mut writer = self.writer.lock().await;
374
375 let header = PacketHeader::new(
377 PacketType::Attention,
378 PacketStatus::END_OF_MESSAGE,
379 PACKET_HEADER_SIZE as u16,
380 );
381 let packet = Packet::new(header, BytesMut::new());
382
383 writer.send(packet).await?;
384 writer.flush().await?;
385
386 Ok(())
387 }
388
389 pub async fn wait_cancelled(&self) {
394 if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
395 self.notify.notified().await;
396 }
397 }
398
399 #[must_use]
401 pub fn is_cancelling(&self) -> bool {
402 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
403 }
404}
405
406impl<T> Clone for CancelHandle<T>
407where
408 T: AsyncRead + AsyncWrite,
409{
410 fn clone(&self) -> Self {
411 Self {
412 writer: Arc::clone(&self.writer),
413 notify: Arc::clone(&self.notify),
414 cancelling: Arc::clone(&self.cancelling),
415 }
416 }
417}
418
419impl<T> std::fmt::Debug for CancelHandle<T>
420where
421 T: AsyncRead + AsyncWrite + Unpin,
422{
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 f.debug_struct("CancelHandle")
425 .field("cancelling", &self.is_cancelling())
426 .finish_non_exhaustive()
427 }
428}
429
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// The Attention packet built by `CancelHandle::cancel` is header-only: its type is
    /// Attention, it is marked END_OF_MESSAGE, and its length equals the header size.
    #[test]
    fn test_attention_packet_header() {
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
    }

    /// `payload_ends_with_attention_done` accepts a trailing DONE token with the 0x0020
    /// status bit set, rejects one without it, and ignores DONE-looking bytes
    /// (`0xFD, 0x20`) that appear before the final 13-byte token.
    #[test]
    fn test_check_attention_done() {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);

        // DONE token (0xFD) whose little-endian status is 0x0020.
        let payload_with_attn = BytesMut::from(
            &[
                0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_with_attn = Packet::new(header, payload_with_attn);

        // Same token with a zero status.
        let payload_no_attn = BytesMut::from(
            &[
                0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_no_attn = Packet::new(header, payload_no_attn);

        assert!(
            Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_with_attn.payload
            )
        );
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_no_attn.payload
            )
        );

        // `0xFD, 0x20` sits inside earlier data; the trailing token has a zero status,
        // so the payload must NOT be treated as the acknowledgement.
        let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
        interior.extend_from_slice(&[
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
        );
    }

    /// Frames `payload` as one raw single-packet message, as a server would write it.
    ///
    /// The 8-byte header is: type 0x04 (tabular result), status 0x01 (end of message),
    /// big-endian total length including the header, then the bytes 0, 0, 1, 0.
    fn raw_message(payload: &[u8]) -> Vec<u8> {
        let mut v = vec![0x04, 0x01];
        v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
        v.extend_from_slice(&[0, 0, 1, 0]);
        v.extend_from_slice(payload);
        v
    }

    /// Builds a 13-byte DONE token (0xFD) with the given little-endian `status`; the
    /// remaining ten bytes are fixed.
    fn done_token(status: u16) -> [u8; 13] {
        let s = status.to_le_bytes();
        [
            0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]
    }

    /// A read that is already parked when `cancel()` is called must discard the
    /// cancelled request's output (status 0x0002) and stop at the acknowledgement
    /// (status 0x0020). It must report `Cancelled` and clear the flag, leaving the
    /// following response (status 0x0010) for the next read.
    #[tokio::test]
    async fn test_cancel_mid_read_discards_cancelled_stream() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        let mut read_fut = Box::pin(conn.read_message());
        // Poll once with a no-op waker so the read is parked on the empty socket
        // before the cancel is issued.
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));

        cancel.cancel().await.expect("send attention");
        // Output of the cancelled request.
        server_io
            .write_all(&raw_message(&done_token(0x0002)))
            .await
            .unwrap();
        // The acknowledgement.
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();
        // Response to the next request.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "parked read must consume the cancelled stream and report \
             Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling(), "cancel flag must be cleared");

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(message.payload[0], 0xFD);
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010,
            "next response must not be eaten by a stale drain"
        );
    }

    /// When cancellation is requested before any read, the first read drains up to
    /// the acknowledgement and reports `Cancelled`. Status 0x0022 includes the 0x0020
    /// bit. The following response (0x0010) is then delivered normally.
    #[tokio::test]
    async fn test_cancel_before_read_drains_to_attention_ack() {
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        cancel.cancel().await.expect("send attention");
        server_io
            .write_all(&raw_message(&done_token(0x0022)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = conn.read_message().await;
        assert!(matches!(result, Err(CodecError::Cancelled)));
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010
        );
    }

    /// A message from the cancelled request carries row bytes containing `0xFD, 0x20`
    /// and ends in a DONE with status 0x0001. It must be discarded rather than taken
    /// as the acknowledgement. Only the real ack (0x0020) ends the cancellation, so
    /// the next read sees 0x0010.
    #[tokio::test]
    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        let mut read_fut = Box::pin(conn.read_message());
        // Park the read before cancelling (see the mid-read test above).
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");

        // Row-like data whose interior contains DONE-looking bytes, followed by a
        // trailing DONE without the 0x0020 bit.
        let mut row_data = vec![0xD1, 0x08];
        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
        row_data.extend_from_slice(&done_token(0x0001));
        server_io.write_all(&raw_message(&row_data)).await.unwrap();

        // The real acknowledgement.
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();

        // Response to the next request.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "cancelled read must end in Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
        assert_eq!(
            status, 0x0010,
            "next request's response must come through intact; 0x0020 means \
             the interior row bytes were mistaken for the ack and the real \
             ack leaked into the next request"
        );
    }
}