1use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use futures_util::{SinkExt, StreamExt};
10use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType};
11use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
12use tokio::sync::{Mutex, Notify};
13
14use crate::error::CodecError;
15use crate::framed::{PacketReader, PacketWriter};
16use crate::message::{Message, MessageAssembler};
17use crate::packet_codec::{Packet, TdsCodec};
18
/// A TDS connection over a transport that has been split into a read half and
/// a write half.
///
/// Incoming packets are read from the read half and assembled into complete
/// messages. Outgoing packets go through a mutex-guarded write half. The write
/// half, the cancellation flag and the cancellation notifier are all held in
/// `Arc`s so they can be shared with [`CancelHandle`]s created from this
/// connection.
pub struct Connection<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// Read half of the transport, yielding decoded packets.
    reader: PacketReader<ReadHalf<T>>,
    /// Write half of the transport behind an async mutex; shared with every
    /// `CancelHandle` so an Attention packet can be written while a read is
    /// in progress.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Accumulates packets until a complete message is available.
    assembler: MessageAssembler,
    /// Notified (via `notify_waiters`) when a cancellation is resolved.
    cancel_notify: Arc<Notify>,
    /// Set when a cancel is requested; cleared once the cancellation resolves.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
61
62impl<T> Connection<T>
63where
64 T: AsyncRead + AsyncWrite,
65{
66 pub fn new(transport: T) -> Self {
70 let (read_half, write_half) = tokio::io::split(transport);
71
72 Self {
73 reader: PacketReader::new(read_half),
74 writer: Arc::new(Mutex::new(PacketWriter::new(write_half))),
75 assembler: MessageAssembler::new(),
76 cancel_notify: Arc::new(Notify::new()),
77 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
78 }
79 }
80
81 pub fn with_codecs(transport: T, read_codec: TdsCodec, write_codec: TdsCodec) -> Self {
83 let (read_half, write_half) = tokio::io::split(transport);
84
85 Self {
86 reader: PacketReader::with_codec(read_half, read_codec),
87 writer: Arc::new(Mutex::new(PacketWriter::with_codec(
88 write_half,
89 write_codec,
90 ))),
91 assembler: MessageAssembler::new(),
92 cancel_notify: Arc::new(Notify::new()),
93 cancelling: Arc::new(std::sync::atomic::AtomicBool::new(false)),
94 }
95 }
96
97 #[must_use]
101 pub fn cancel_handle(&self) -> CancelHandle<T> {
102 CancelHandle {
103 writer: Arc::clone(&self.writer),
104 notify: Arc::clone(&self.cancel_notify),
105 cancelling: Arc::clone(&self.cancelling),
106 }
107 }
108
109 #[must_use]
111 pub fn is_cancelling(&self) -> bool {
112 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
113 }
114
115 pub async fn read_message(&mut self) -> Result<Option<Message>, CodecError> {
123 loop {
124 if self.is_cancelling() {
126 return self.drain_after_cancel().await;
128 }
129
130 match self.reader.next().await {
131 Some(Ok(packet)) => {
132 if let Some(message) = self.assembler.push(packet) {
133 if self.is_cancelling() {
141 if Self::payload_ends_with_attention_done(&message.payload) {
142 tracing::debug!(
143 "received DONE with ATTENTION, cancellation complete"
144 );
145 self.finish_cancel();
146 return Err(CodecError::Cancelled);
147 }
148 tracing::debug!("discarding message from cancelled request");
149 continue;
150 }
151 return Ok(Some(message));
152 }
153 }
155 Some(Err(e)) => return Err(e),
156 None => {
157 if self.assembler.has_partial() {
159 return Err(CodecError::ConnectionClosed);
160 }
161 return Ok(None);
162 }
163 }
164 }
165 }
166
167 pub async fn read_packet(&mut self) -> Result<Option<Packet>, CodecError> {
171 match self.reader.next().await {
172 Some(result) => result.map(Some),
173 None => Ok(None),
174 }
175 }
176
177 pub async fn send_packet(&mut self, packet: Packet) -> Result<(), CodecError> {
179 let mut writer = self.writer.lock().await;
180 writer.send(packet).await
181 }
182
183 pub async fn send_message(
190 &mut self,
191 packet_type: PacketType,
192 payload: Bytes,
193 max_packet_size: usize,
194 ) -> Result<(), CodecError> {
195 self.send_message_with_reset(packet_type, payload, max_packet_size, false)
196 .await
197 }
198
199 pub async fn send_message_with_reset(
206 &mut self,
207 packet_type: PacketType,
208 payload: Bytes,
209 max_packet_size: usize,
210 reset_connection: bool,
211 ) -> Result<(), CodecError> {
212 let max_payload = max_packet_size - PACKET_HEADER_SIZE;
213 let chunks: Vec<&[u8]> = if payload.is_empty() {
218 vec![&[]]
219 } else {
220 payload.chunks(max_payload).collect()
221 };
222 let total_chunks = chunks.len();
223
224 let mut writer = self.writer.lock().await;
225
226 for (i, chunk) in chunks.into_iter().enumerate() {
227 let is_first = i == 0;
228 let is_last = i == total_chunks - 1;
229
230 let mut status = if is_last {
232 PacketStatus::END_OF_MESSAGE
233 } else {
234 PacketStatus::NORMAL
235 };
236
237 if is_first && reset_connection {
239 status |= PacketStatus::RESET_CONNECTION;
240 }
241
242 let header = PacketHeader::new(packet_type, status, 0);
243 let packet = Packet::new(header, BytesMut::from(chunk));
244
245 writer.send(packet).await?;
246 }
247
248 Ok(())
249 }
250
251 pub async fn flush(&mut self) -> Result<(), CodecError> {
253 let mut writer = self.writer.lock().await;
254 writer.flush().await
255 }
256
257 async fn drain_after_cancel(&mut self) -> Result<Option<Message>, CodecError> {
262 tracing::debug!("draining packets after cancellation");
263
264 self.assembler.clear();
266
267 loop {
268 match self.reader.next().await {
269 Some(Ok(packet)) => {
270 if let Some(message) = self.assembler.push(packet) {
274 if message.packet_type == PacketType::TabularResult
275 && Self::payload_ends_with_attention_done(&message.payload)
276 {
277 tracing::debug!("received DONE with ATTENTION, cancellation complete");
278 self.finish_cancel();
279 return Err(CodecError::Cancelled);
280 }
281 tracing::debug!("discarding message from cancelled request");
282 }
283 }
285 Some(Err(e)) => {
286 self.cancelling
287 .store(false, std::sync::atomic::Ordering::Release);
288 return Err(e);
289 }
290 None => {
291 self.cancelling
294 .store(false, std::sync::atomic::Ordering::Release);
295 return Err(CodecError::ConnectionClosed);
296 }
297 }
298 }
299 }
300
301 fn finish_cancel(&self) {
303 self.cancelling
304 .store(false, std::sync::atomic::Ordering::Release);
305 self.cancel_notify.notify_waiters();
306 }
307
308 fn payload_ends_with_attention_done(payload: &[u8]) -> bool {
321 let Some(start) = payload.len().checked_sub(13) else {
322 return false;
323 };
324 payload[start] == 0xFD
326 && u16::from_le_bytes([payload[start + 1], payload[start + 2]]) & 0x0020 != 0
327 }
328
329 pub fn read_codec(&self) -> &TdsCodec {
331 self.reader.codec()
332 }
333
334 pub fn read_codec_mut(&mut self) -> &mut TdsCodec {
336 self.reader.codec_mut()
337 }
338}
339
340impl<T> std::fmt::Debug for Connection<T>
341where
342 T: AsyncRead + AsyncWrite + std::fmt::Debug,
343{
344 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 f.debug_struct("Connection")
346 .field("cancelling", &self.is_cancelling())
347 .field("has_partial_message", &self.assembler.has_partial())
348 .finish_non_exhaustive()
349 }
350}
351
/// A handle for cancelling the in-flight request of a [`Connection`] from
/// another task.
///
/// It shares the connection's write half, cancellation notifier and
/// cancelling flag through `Arc`s. Cloning a handle therefore yields another
/// view of the same connection.
pub struct CancelHandle<T>
where
    T: AsyncRead + AsyncWrite,
{
    /// The connection's write half, used to send the Attention packet.
    writer: Arc<Mutex<PacketWriter<WriteHalf<T>>>>,
    /// Shared with the connection's `cancel_notify`; awaited in `wait_cancelled`.
    notify: Arc<Notify>,
    /// Shared with the connection's `cancelling` flag.
    cancelling: Arc<std::sync::atomic::AtomicBool>,
}
364
365impl<T> CancelHandle<T>
366where
367 T: AsyncRead + AsyncWrite + Unpin,
368{
369 pub async fn cancel(&self) -> Result<(), CodecError> {
374 self.cancelling
376 .store(true, std::sync::atomic::Ordering::Release);
377
378 tracing::debug!("sending Attention packet for query cancellation");
379
380 let mut writer = self.writer.lock().await;
382
383 let header = PacketHeader::new(
385 PacketType::Attention,
386 PacketStatus::END_OF_MESSAGE,
387 PACKET_HEADER_SIZE as u16,
388 );
389 let packet = Packet::new(header, BytesMut::new());
390
391 writer.send(packet).await?;
392 writer.flush().await?;
393
394 Ok(())
395 }
396
397 pub async fn wait_cancelled(&self) {
402 if self.cancelling.load(std::sync::atomic::Ordering::Acquire) {
403 self.notify.notified().await;
404 }
405 }
406
407 #[must_use]
409 pub fn is_cancelling(&self) -> bool {
410 self.cancelling.load(std::sync::atomic::Ordering::Acquire)
411 }
412}
413
414impl<T> Clone for CancelHandle<T>
415where
416 T: AsyncRead + AsyncWrite,
417{
418 fn clone(&self) -> Self {
419 Self {
420 writer: Arc::clone(&self.writer),
421 notify: Arc::clone(&self.notify),
422 cancelling: Arc::clone(&self.cancelling),
423 }
424 }
425}
426
427impl<T> std::fmt::Debug for CancelHandle<T>
428where
429 T: AsyncRead + AsyncWrite + Unpin,
430{
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 f.debug_struct("CancelHandle")
433 .field("cancelling", &self.is_cancelling())
434 .finish_non_exhaustive()
435 }
436}
437
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// An empty payload must still be sent as exactly one header-only packet
    /// flagged END_OF_MESSAGE, with nothing following it.
    #[tokio::test]
    async fn test_send_empty_payload_emits_one_eom_packet() {
        use tokio::io::AsyncReadExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);

        conn.send_message(PacketType::SqlBatch, Bytes::new(), 4096)
            .await
            .expect("empty message should send");

        let mut header = [0u8; PACKET_HEADER_SIZE];
        server_io
            .read_exact(&mut header)
            .await
            .expect("one header-only packet must be sent");
        // Header byte 0 is the packet type, byte 1 the status.
        assert_eq!(header[0], PacketType::SqlBatch as u8);
        assert!(
            PacketStatus::from_bits_truncate(header[1]).contains(PacketStatus::END_OF_MESSAGE),
            "the single packet must be flagged END_OF_MESSAGE"
        );
        // Header bytes 2..4 hold the packet length, big-endian.
        let length = u16::from_be_bytes([header[2], header[3]]);
        assert_eq!(
            length as usize, PACKET_HEADER_SIZE,
            "length must be header-only (no payload)"
        );

        // Drop the client side so that read_to_end below can reach end of stream.
        drop(conn);
        let mut rest = Vec::new();
        server_io.read_to_end(&mut rest).await.expect("read rest");
        assert!(rest.is_empty(), "no second packet may follow");
    }

    /// With a 16-byte packet limit, a 12-byte payload is split into two
    /// packets. Only the first carries RESET_CONNECTION, and only the last
    /// carries END_OF_MESSAGE.
    #[tokio::test]
    async fn test_reset_flag_on_first_packet_only_across_multi_packet_send() {
        use tokio::io::AsyncReadExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);

        let payload = Bytes::from(vec![0xABu8; 12]);
        // 16-byte packets leave 8 payload bytes each (assuming an 8-byte header),
        // so the payload becomes an 8-byte chunk and a 4-byte chunk.
        conn.send_message_with_reset(PacketType::SqlBatch, payload, 16, true)
            .await
            .expect("multi-packet send should succeed");
        drop(conn);

        let mut all = Vec::new();
        server_io.read_to_end(&mut all).await.expect("read packets");

        // Status byte of the first packet.
        let s0 = PacketStatus::from_bits_truncate(all[1]);
        assert!(
            s0.contains(PacketStatus::RESET_CONNECTION),
            "first packet must carry RESET_CONNECTION"
        );
        assert!(
            !s0.contains(PacketStatus::END_OF_MESSAGE),
            "first packet of two must not be END_OF_MESSAGE"
        );

        // Status byte of the second packet, which starts after the full 16-byte first packet.
        let s1 = PacketStatus::from_bits_truncate(all[16 + 1]);
        assert!(
            !s1.contains(PacketStatus::RESET_CONNECTION),
            "RESET_CONNECTION must not repeat on later packets"
        );
        assert!(
            s1.contains(PacketStatus::END_OF_MESSAGE),
            "last packet must be END_OF_MESSAGE"
        );
        // First packet (16) + second header (8) + remaining payload (4).
        assert_eq!(all.len(), 16 + 8 + 4, "exactly two packets must be sent");
    }

    /// The Attention header as built by `CancelHandle::cancel` keeps its type,
    /// END_OF_MESSAGE status and header-only length.
    #[test]
    fn test_attention_packet_header() {
        let header = PacketHeader::new(
            PacketType::Attention,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        );

        assert_eq!(header.packet_type, PacketType::Attention);
        assert!(header.status.contains(PacketStatus::END_OF_MESSAGE));
        assert_eq!(header.length, PACKET_HEADER_SIZE as u16);
    }

    /// `payload_ends_with_attention_done` accepts a trailing DONE with the
    /// 0x0020 status bit, rejects one without it, and ignores 0xFD 0x20 bytes
    /// that appear before the trailing token.
    #[test]
    fn test_check_attention_done() {
        let header = PacketHeader::new(PacketType::TabularResult, PacketStatus::END_OF_MESSAGE, 0);

        // 13-byte DONE token with status 0x0020.
        let payload_with_attn = BytesMut::from(
            &[
                0xFD, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_with_attn = Packet::new(header, payload_with_attn);

        // Same token with an all-zero status.
        let payload_no_attn = BytesMut::from(
            &[
                0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ][..],
        );
        let packet_no_attn = Packet::new(header, payload_no_attn);

        assert!(
            Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_with_attn.payload
            )
        );
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(
                &packet_no_attn.payload
            )
        );

        // Leading bytes contain 0xFD 0x20, but the trailing DONE has no ATTN
        // bit, so the check must report false.
        let mut interior = vec![0xD1, 0x08, 0xFD, 0x20, 0xAA, 0xBB];
        interior.extend_from_slice(&[
            0xFD, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        assert!(
            !Connection::<tokio::io::DuplexStream>::payload_ends_with_attention_done(&interior)
        );
    }

    /// Builds one raw single-packet message around `payload`. The type byte
    /// is 0x04 (presumably TabularResult, which the cancel path expects) and
    /// the status byte is 0x01 (presumably END_OF_MESSAGE) — confirm against
    /// tds_protocol.
    fn raw_message(payload: &[u8]) -> Vec<u8> {
        let mut v = vec![0x04, 0x01];
        // Big-endian total length: payload plus an 8-byte header.
        v.extend_from_slice(&((payload.len() + 8) as u16).to_be_bytes());
        // Remaining four header bytes.
        v.extend_from_slice(&[0, 0, 1, 0]);
        v.extend_from_slice(payload);
        v
    }

    /// Builds a 13-byte DONE token (0xFD) carrying `status` little-endian in
    /// bytes 1..3. The remaining bytes are fixed filler.
    fn done_token(status: u16) -> [u8; 13] {
        let s = status.to_le_bytes();
        [
            0xFD, s[0], s[1], 0xC1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]
    }

    /// A read that is already parked when `cancel()` is called must discard the
    /// cancelled request's output and stop at the attention ack with
    /// `Cancelled`. It must also clear the flag and leave the following
    /// response for the next read.
    #[tokio::test]
    async fn test_cancel_mid_read_discards_cancelled_stream() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        let mut read_fut = Box::pin(conn.read_message());
        // Poll once by hand so the read is parked on the empty stream before
        // the cancel is issued.
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));

        cancel.cancel().await.expect("send attention");
        // Server output: leftover DONE of the cancelled request, the attention
        // ack (status 0x0020), then the next request's DONE (status 0x0010).
        server_io
            .write_all(&raw_message(&done_token(0x0002)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "parked read must consume the cancelled stream and report \
             Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling(), "cancel flag must be cleared");

        // The message after the ack must be delivered intact.
        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(message.payload[0], 0xFD);
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010,
            "next response must not be eaten by a stale drain"
        );
    }

    /// When the cancel is issued before any read, the first read drains up to
    /// the ack (here a DONE with status 0x0022, which includes the 0x0020
    /// bit), reports `Cancelled`, and the next read returns the following
    /// response.
    #[tokio::test]
    async fn test_cancel_before_read_drains_to_attention_ack() {
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        cancel.cancel().await.expect("send attention");
        server_io
            .write_all(&raw_message(&done_token(0x0022)))
            .await
            .unwrap();
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = conn.read_message().await;
        assert!(matches!(result, Err(CodecError::Cancelled)));
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        assert_eq!(
            u16::from_le_bytes([message.payload[1], message.payload[2]]),
            0x0010
        );
    }

    /// Row data that happens to contain 0xFD 0x20 must not be taken for the
    /// attention ack. The real ack (0x0020) must be consumed by the cancelled
    /// read, and the next read must see the following response (0x0010).
    #[tokio::test]
    async fn test_cancel_race_row_bytes_do_not_fake_the_attention_ack() {
        use std::task::{Context, Poll};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(4096);
        let mut conn = Connection::new(client_io);
        let cancel = conn.cancel_handle();

        let mut read_fut = Box::pin(conn.read_message());
        // Park the read on the empty stream before cancelling.
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending));
        cancel.cancel().await.expect("send attention");

        // A message whose interior bytes include 0xFD 0x20, ending in a DONE
        // without the ATTN bit.
        let mut row_data = vec![0xD1, 0x08];
        row_data.extend_from_slice(&[0xFD, 0x20, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
        row_data.extend_from_slice(&done_token(0x0001));
        server_io.write_all(&raw_message(&row_data)).await.unwrap();

        // The real attention ack.
        server_io
            .write_all(&raw_message(&done_token(0x0020)))
            .await
            .unwrap();

        // The next request's response.
        server_io
            .write_all(&raw_message(&done_token(0x0010)))
            .await
            .unwrap();

        let result = read_fut.await;
        assert!(
            matches!(result, Err(CodecError::Cancelled)),
            "cancelled read must end in Cancelled, got {result:?}"
        );
        assert!(!conn.is_cancelling());

        let message = conn
            .read_message()
            .await
            .expect("next read")
            .expect("next message");
        let status = u16::from_le_bytes([message.payload[1], message.payload[2]]);
        assert_eq!(
            status, 0x0010,
            "next request's response must come through intact; 0x0020 means \
             the interior row bytes were mistaken for the ack and the real \
             ack leaked into the next request"
        );
    }
}